diff --git a/cellacdc/__init__.py b/cellacdc/__init__.py index 2b73f7b0..75919f66 100755 --- a/cellacdc/__init__.py +++ b/cellacdc/__init__.py @@ -525,11 +525,18 @@ def printl(*objects, pretty=False, is_decorator=False, idx=1, **kwargs): ] segment_anything_weights_filenames = [ - 'sam_vit_h_4b8939.pth', - 'sam_vit_l_0b3195.pth', + 'sam_vit_h_4b8939.pth', + 'sam_vit_l_0b3195.pth', 'sam_vit_b_01ec64.pth' ] +sam2_weights_filenames = [ + 'sam2.1_hiera_large.pt', + 'sam2.1_hiera_base_plus.pt', + 'sam2.1_hiera_small.pt', + 'sam2.1_hiera_tiny.pt' +] + deepsea_weights_filenames = [ 'segmentation.pth', 'tracker.pth' diff --git a/cellacdc/widgets.py b/cellacdc/widgets.py index d9da7f54..73933aba 100755 --- a/cellacdc/widgets.py +++ b/cellacdc/widgets.py @@ -11102,7 +11102,11 @@ def selectModel(self): model_name = win.model_name print(f'Importing promptable model {model_name}...') - + + # Download model weights, consistent with gui.py + downloadWin = apps.downloadModel(model_name, parent=self._parent) + downloadWin.download() + acdcPromptSegment = myutils.import_promptable_segment_module(model_name) init_argspecs, segment_argspecs = myutils.getModelArgSpec( acdcPromptSegment diff --git a/pyproject.toml b/pyproject.toml index 15dffc5b..d6e76ee3 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "boto3", "requests", "setuptools-scm", + "matplotlib" ] dynamic = [ "version", @@ -148,7 +149,6 @@ all = [ dev = [ "pytest", "pytest-cov", - "matplotlib", ] [project.scripts]