Skip to content

Commit f8ab948

Browse files
committed
feat: make RaggedVariants an Awkward Array subclass supporting arbitrary additional fields.
1 parent b5d202a commit f8ab948

File tree

12 files changed, with 475 additions and 353 deletions.

pixi.lock

Lines changed: 38 additions & 36 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,8 +56,8 @@ plink2 = "*"
5656
[pypi-dependencies]
5757
# genvarloader = { path = ".", editable = true }
5858
hirola = "==0.3"
59-
# seqpro = "==0.7.1"
60-
genoray = { git = "https://github.com/d-laub/genoray.git", branch = "dlaub/svar-af" }
59+
seqpro = "==0.8.2"
60+
genoray = "==0.16.0"
6161

6262
[feature.docs.dependencies]
6363
sphinx = ">=7.4.7"
@@ -91,9 +91,11 @@ python = "3.11.*"
9191
python = "3.12.*"
9292

9393
[tasks]
94-
install = "uv pip install -e . -e /carter/users/dlaub/projects/ML4GLand/SeqPro"
94+
install = "uv pip install -e ."
9595
pre-commit = "pre-commit install --hook-type commit-msg --hook-type pre-push"
96-
gen = { cmd = "python tests/data/generate_ground_truth.py", depends-on = ["install"] }
96+
gen = { cmd = "python tests/data/generate_ground_truth.py", depends-on = [
97+
"install",
98+
] }
9799
test = { cmd = "pytest tests && cargo test --release", depends-on = ["gen"] }
98100

99101
[feature.docs.tasks]
@@ -107,4 +109,4 @@ xarray = "*"
107109
ipykernel = "*"
108110

109111
[feature.demo.tasks]
110-
i-kernel = "ipython kernel install --user --name 'gvl-demo' --display-name 'GVL Demo'"
112+
i-kernel = "ipython kernel install --user --name 'gvl-demo' --display-name 'GVL Demo'"

pyproject.toml

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,8 +29,8 @@ dependencies = [
2929
"pooch",
3030
"awkward",
3131
"hirola>=0.3,<0.4",
32-
"seqpro>=0.7.1",
33-
"genoray>=0.15.0",
32+
"seqpro>=0.8.2",
33+
"genoray>=0.16.0",
3434
]
3535

3636
[project.urls]
@@ -84,7 +84,14 @@ version_scheme = "semver2"
8484
version_provider = "pep621"
8585
update_changelog_on_bump = true
8686
major_version_zero = true
87-
allowed_prefixes = ["Merge", "Revert", "Pull request", "fixup!", "squash!", "[pre-commit.ci]"]
87+
allowed_prefixes = [
88+
"Merge",
89+
"Revert",
90+
"Pull request",
91+
"fixup!",
92+
"squash!",
93+
"[pre-commit.ci]",
94+
]
8895

8996
[build-system]
9097
requires = ["maturin>=1.6,<2.0"]

python/genvarloader/_dataset/_impl.py

Lines changed: 58 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -269,6 +269,7 @@ def with_settings(
269269
rc_neg: bool | None = None,
270270
min_af: float | Literal[False] | None = None,
271271
max_af: float | Literal[False] | None = None,
272+
var_fields: list[str] | None = None,
272273
) -> Self:
273274
"""Modify settings of the dataset, returning a new dataset without modifying the old one.
274275
@@ -291,6 +292,8 @@ def with_settings(
291292
max_af
292293
The maximum allele frequency to include in the dataset. If set to :code:`False`, disables this filter.
293294
If dataset is not backed by SVAR genotypes, this will raise an error.
295+
var_fields
296+
The variant fields to include in the dataset.
294297
"""
295298
to_evolve = {}
296299

@@ -342,16 +345,26 @@ def with_settings(
342345
elif max_af is False:
343346
max_af = None
344347

345-
haps = evolve(self._seqs, min_af=min_af, max_af=max_af)
348+
haps = to_evolve.get("_seqs", self._seqs)
349+
haps = evolve(haps, min_af=min_af, max_af=max_af)
346350
to_evolve["_seqs"] = haps
347351

352+
if var_fields is not None:
353+
missing = list(set(var_fields) - set(self.available_var_fields))
354+
if missing or not isinstance(self._seqs, Haps):
355+
raise ValueError(f"Missing variant fields: {missing}")
356+
haps = to_evolve.get("_seqs", self._seqs)
357+
haps = evolve(haps, var_fields=var_fields)
358+
to_evolve["_seqs"] = haps
359+
360+
if "_seqs" in to_evolve:
361+
haps = to_evolve["_seqs"]
348362
if isinstance(self._recon, Haps):
349363
recon = haps
364+
to_evolve["_recon"] = recon
350365
elif isinstance(self._recon, HapsTracks):
351366
recon = evolve(self._recon, haps=haps)
352-
else:
353-
recon = self._recon
354-
to_evolve["_recon"] = recon
367+
to_evolve["_recon"] = recon
355368

356369
return evolve(self, **to_evolve)
357370

@@ -728,6 +741,24 @@ def full_shape(self) -> tuple[int, int]:
728741
"""Return the full shape of the dataset, ignoring any subsetting. :code:`(n_samples, n_regions)`"""
729742
return self._idxer.full_shape
730743

744+
@property
745+
def available_var_fields(self) -> list[str]:
746+
"""Available variant fields."""
747+
match self._seqs:
748+
case Haps():
749+
return self._seqs.available_var_fields
750+
case _:
751+
return []
752+
753+
@property
754+
def active_var_fields(self) -> list[str]:
755+
"""Active variant fields."""
756+
match self._recon:
757+
case (Haps() as haps) | HapsTracks(haps=haps):
758+
return haps.var_fields
759+
case _:
760+
return []
761+
731762
@property
732763
def available_tracks(self) -> list[str] | None:
733764
"""The available tracks in the dataset."""
@@ -1352,7 +1383,7 @@ def __getitem__(
13521383

13531384
if squeeze:
13541385
# (1 [p] l) -> ([p] l)
1355-
recon = tuple(o.squeeze(0) for o in recon)
1386+
recon = tuple(o.squeeze(axis=0) for o in recon)
13561387

13571388
if unlist:
13581389
recon = recon[0]
@@ -1538,28 +1569,32 @@ def with_seqs(
15381569
return super().with_seqs(kind)
15391570

15401571
@overload
1541-
def with_tracks(self, tracks: None = ..., kind: None = ...) -> Self: ...
1572+
def with_tracks(self, tracks: None = None, kind: None = None) -> Self: ...
15421573
@overload
15431574
def with_tracks(
1544-
self, tracks: None = ..., kind: Literal["tracks"] = ...
1575+
self, *, tracks: None = None, kind: Literal["tracks"]
15451576
) -> ArrayDataset[MaybeSEQ, NDArray[np.float32]]: ...
15461577
@overload
15471578
def with_tracks(
1548-
self, tracks: None = ..., kind: Literal["intervals"] = ...
1579+
self, *, tracks: None = None, kind: Literal["intervals"]
15491580
) -> ArrayDataset[MaybeSEQ, RaggedIntervals]: ...
15501581
@overload
15511582
def with_tracks(
15521583
self,
1553-
tracks: Literal[False] = ...,
1554-
kind: Literal["tracks", "intervals"] | None = ...,
1584+
tracks: Literal[False],
1585+
kind: Literal["tracks", "intervals"] | None = None,
15551586
) -> ArrayDataset[MaybeSEQ, None]: ...
15561587
@overload
15571588
def with_tracks(
1558-
self, tracks: str | list[str] = ..., kind: Literal["tracks"] = ...
1589+
self, tracks: str | list[str], kind: None = None
1590+
) -> ArrayDataset[MaybeSEQ, MaybeTRK]: ...
1591+
@overload
1592+
def with_tracks(
1593+
self, tracks: str | list[str], kind: Literal["tracks"]
15591594
) -> ArrayDataset[MaybeSEQ, NDArray[np.float32]]: ...
15601595
@overload
15611596
def with_tracks(
1562-
self, tracks: str | list[str] = ..., kind: Literal["intervals"] = ...
1597+
self, tracks: str | list[str], kind: Literal["intervals"]
15631598
) -> ArrayDataset[MaybeSEQ, RaggedIntervals]: ...
15641599
def with_tracks(
15651600
self,
@@ -1684,28 +1719,32 @@ def with_seqs(
16841719
return super().with_seqs(kind)
16851720

16861721
@overload
1687-
def with_tracks(self, tracks: None = ..., kind: None = ...) -> Self: ...
1722+
def with_tracks(self, tracks: None = None, kind: None = None) -> Self: ...
16881723
@overload
16891724
def with_tracks(
1690-
self, tracks: None = ..., kind: Literal["tracks"] = ...
1725+
self, *, tracks: None = None, kind: Literal["tracks"]
16911726
) -> RaggedDataset[MaybeRSEQ, RaggedTracks]: ...
16921727
@overload
16931728
def with_tracks(
1694-
self, tracks: None = ..., kind: Literal["intervals"] = ...
1729+
self, *, tracks: None = None, kind: Literal["intervals"]
16951730
) -> RaggedDataset[MaybeRSEQ, RaggedIntervals]: ...
16961731
@overload
16971732
def with_tracks(
16981733
self,
1699-
tracks: Literal[False] = ...,
1700-
kind: Literal["tracks", "intervals"] | None = ...,
1734+
tracks: Literal[False],
1735+
kind: Literal["tracks", "intervals"] | None = None,
17011736
) -> RaggedDataset[MaybeRSEQ, None]: ...
17021737
@overload
17031738
def with_tracks(
1704-
self, tracks: str | list[str] = ..., kind: Literal["tracks"] = ...
1739+
self, tracks: str | list[str], kind: None = None
1740+
) -> RaggedDataset[MaybeRSEQ, MaybeRTRK]: ...
1741+
@overload
1742+
def with_tracks(
1743+
self, tracks: str | list[str], kind: Literal["tracks"]
17051744
) -> RaggedDataset[MaybeRSEQ, RaggedTracks]: ...
17061745
@overload
17071746
def with_tracks(
1708-
self, tracks: str | list[str] = ..., kind: Literal["intervals"] = ...
1747+
self, tracks: str | list[str], kind: Literal["intervals"]
17091748
) -> RaggedDataset[MaybeRSEQ, RaggedIntervals]: ...
17101749
def with_tracks(
17111750
self,

0 commit comments

Comments
 (0)