-
Notifications
You must be signed in to change notification settings - Fork 30
Pytorch update #280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pytorch update #280
Conversation
Refactor load_model to use factories for settings and model
niekdejonge
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@florian-huber Thanks a lot! That seems like quite some work to get saving to work as well. I assume we will have to upload new MS2DeepScore models, of the new version right? I can do that now.
One question, will the old loading function be able to handle the new model format? If not, I will make a new zenodo link, so people using older versions don't automatically download the new versions.
|
|
||
| if settings: | ||
| # coerce a copy against current defaults | ||
| settings = _coerce_settings_dict(settings, self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice addition, to automatically convert incorrect types.
|
|
||
| def save(self, filepath): | ||
|
|
||
| def save(self, filepath: Union[str, Path]) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really nice that you fixed this @florian-huber Thanks!
| except Exception as safe_err: | ||
| if not allow_legacy: | ||
| raise RuntimeError( | ||
| "Failed to load safely. If this is a trusted legacy file, call with allow_legacy=True." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice solution!
|
|
||
| # ---------- public API ---------- | ||
|
|
||
| def load_model( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My understanding is now like this:
The load by default works for the new way of saving, but if allow_legacy=True it will still be able to load the old versions. If this is correct, that is great!
tests/test_SettingsMS2deepscore.py
Outdated
| assert settings_false.use_fixed_set is False | ||
|
|
||
|
|
||
| #def test_coerce_invalid_string_to_bool(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these tests be there?
torch>=2.4(was <=2.6 before) --> closes Different way of saving models #236allow_legacy=True)models/__model_format__.py. The first setting here isms2deepscore.safe.v1(but any string can in the future be used).convert_legacy_checkpoint()function that allows converting "old" MS2DeepScore 2.0 models to the new style.