Skip to content

Commit 47d0f25

Browse files
add test_fine_tuning_deterministic
1 parent d1aa70e commit 47d0f25

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22

33
import anndata
4+
import lightning as L
45
import numpy as np
56
import pandas as pd
67
import pytest
@@ -238,6 +239,28 @@ def test_saved_model_identical(slide_key: str | None, scgpt_model_dir: str | Non
238239
assert torch.equal(param, new_model.state_dict()[name])
239240

240241

242+
def test_fine_tuning_deterministic():
243+
adata = novae.toy_dataset(n_panels=1, compute_spatial_neighbors=True, xmax=300)[0]
244+
245+
model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0")
246+
L.seed_everything(0)
247+
model.fine_tune(adata, max_epochs=1)
248+
model.compute_representations(adata)
249+
model.assign_domains(adata)
250+
251+
domains = adata.obs[Keys.LEAVES].copy()
252+
representations = adata.obsm[Keys.REPR].copy()
253+
254+
new_model = novae.Novae.from_pretrained("MICS-Lab/novae-human-0")
255+
L.seed_everything(0)
256+
new_model.fine_tune(adata, max_epochs=1)
257+
new_model.compute_representations(adata)
258+
new_model.assign_domains(adata)
259+
260+
assert (adata.obsm[Keys.REPR] == representations).all()
261+
assert domains.equals(adata.obs[Keys.LEAVES])
262+
263+
241264
def test_safetensors_parameters_names():
242265
from huggingface_hub import hf_hub_download
243266
from safetensors import safe_open

0 commit comments

Comments
 (0)