diff --git a/python/genvarloader/_dataset/_impl.py b/python/genvarloader/_dataset/_impl.py index 8136edb..da53fcc 100644 --- a/python/genvarloader/_dataset/_impl.py +++ b/python/genvarloader/_dataset/_impl.py @@ -734,12 +734,12 @@ def ploidy(self) -> int | None: @property def shape(self) -> tuple[int, int]: - """Return the shape of the dataset. :code:`(n_samples, n_regions)`""" + """Return the shape of the dataset. :code:`(n_regions, n_samples)`""" return self.n_regions, self.n_samples @property def full_shape(self) -> tuple[int, int]: - """Return the full shape of the dataset, ignoring any subsetting. :code:`(n_samples, n_regions)`""" + """Return the full shape of the dataset, ignoring any subsetting. :code:`(n_regions, n_samples)`""" return self._idxer.full_shape @property diff --git a/python/genvarloader/_dataset/_rag_variants.py b/python/genvarloader/_dataset/_rag_variants.py index 15a2cfe..ee85ffe 100644 --- a/python/genvarloader/_dataset/_rag_variants.py +++ b/python/genvarloader/_dataset/_rag_variants.py @@ -31,7 +31,7 @@ class RaggedVariant(ak.Record): class RaggedVariants(ak.Array): """An awkward record array, typically with shape (batch, ploidy, ~variants). - Guaranteed to at least have the field "alts" and "v_starts" and one of "refs" or "ilens".""" + Guaranteed to at least have the field "alt" and "start" and one of "ref" or "ilen".""" def __init__( self, @@ -59,6 +59,25 @@ def __init__( super().__init__(arr) + @classmethod + def from_ak(cls, arr: ak.Array) -> RaggedVariants: + """Create a RaggedVariants object from an awkward array. + + Parameters + ---------- + arr + The awkward array to create a RaggedVariants object from. + """ + fields = set(arr.fields) + + if missing := {"alt", "start"} - fields: + raise ValueError(f"Missing required fields: {missing}") + + if {"ref", "ilen"}.isdisjoint(fields): + raise ValueError("Must have one of ref or ilen.") + + return ak.with_parameter(arr, "__record__", RaggedVariants.__name__) + @property def alt(self) -> ak.Array: """Alternative alleles.""" @@ -73,7 +92,7 @@ def start(self) -> Ragged[POS_TYPE]: def ilen(self) -> Ragged[np.int32]: """Indel lengths. Infallible.""" if "ilen" not in self.fields: - ilen = ak.num(self.alt, -1) - ak.num(self.ref, -1) + ilen = ak.str.length(self.alt) - ak.str.length(self.ref) # type: ignore ilen = Ragged(ilen) return ilen diff --git a/python/genvarloader/_dataset/_reconstruct.py b/python/genvarloader/_dataset/_reconstruct.py index 33f70ba..1a328c4 100644 --- a/python/genvarloader/_dataset/_reconstruct.py +++ b/python/genvarloader/_dataset/_reconstruct.py @@ -80,7 +80,7 @@ def from_table(cls, path: str | Path, one_based: bool = True): info = { k: variants[k].to_numpy() for k, v in variants.schema.items() - if v.is_numeric() + if v.is_numeric() and k not in {"POS", "ILEN"} } ref = (