1 change: 1 addition & 0 deletions pixi.toml
@@ -58,6 +58,7 @@ plink2 = "*"
 hirola = "==0.3"
 seqpro = "==0.8.2"
 genoray = "==0.16.0"
+pydantic = ">=2,<3"
 
 [feature.docs.dependencies]
 sphinx = ">=7.4.7"
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
     "pysam",
     "pyarrow",
     "pyranges",
+    "pydantic>=2,<3",
     "more-itertools",
     "tqdm",
     "pybigwig",
13 changes: 7 additions & 6 deletions python/genvarloader/_dataset/_impl.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import json
 from collections.abc import Iterable, Sequence
 from pathlib import Path
 from typing import Callable, Generic, Literal, TypeVar, cast, overload
@@ -34,6 +33,7 @@
 from ._reconstruct import Haps, HapsTracks, Ref, RefTracks, Tracks, TrackType
 from ._reference import Reference
 from ._utils import bed_to_regions, regions_to_bed
+from ._write import Metadata
 
 if TORCH_AVAILABLE:
     import torch
@@ -143,11 +143,11 @@ def open(
 
         # read metadata
         with _py_open(path / "metadata.json") as f:
-            metadata = json.load(f)
-            samples: list[str] = metadata["samples"]
-            contigs: list[str] = metadata["contigs"]
-            ploidy: int | None = metadata.get("ploidy", None)
-            max_jitter: int = metadata.get("max_jitter", 0)
+            metadata = Metadata.model_validate_json(f.read())
+            samples = metadata.samples
+            contigs = metadata.contigs
+            ploidy = metadata.ploidy
+            max_jitter = metadata.max_jitter
 
         # read input regions and generate index map
         bed = pl.read_ipc(path / "input_regions.arrow")
@@ -179,6 +179,7 @@ def open(
             regions=regions,
             samples=samples,
            ploidy=ploidy,
+            version=metadata.version,
             min_af=min_af,
             max_af=max_af,
         )
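For context, a minimal sketch (not part of the diff) of what validating with the new `Metadata` model means for existing stores: a pre-existing `metadata.json` that lacks the newer keys, and still carries the dropped `n_samples` entry, should parse cleanly, since pydantic ignores extra keys by default and the model's defaults match the old `dict.get()` fallbacks.

```python
from genvarloader._dataset._write import Metadata

# metadata.json from an older store: no "version", "ploidy", or "max_jitter",
# plus an extra "n_samples" key that pydantic ignores by default.
raw = '{"samples": ["s1", "s2"], "contigs": ["chr1"], "n_regions": 100, "n_samples": 2}'
meta = Metadata.model_validate_json(raw)
assert meta.ploidy is None    # default, was metadata.get("ploidy", None)
assert meta.max_jitter == 0   # default, was metadata.get("max_jitter", 0)
assert meta.version is None   # older stores recorded no version
assert meta.n_samples == 2    # now a property recomputed from samples
```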
7 changes: 6 additions & 1 deletion python/genvarloader/_dataset/_reconstruct.py
@@ -23,6 +23,7 @@
 )
 from loguru import logger
 from numpy.typing import NDArray
+from packaging.version import Version
 from seqpro.rag import OFFSET_TYPE, Ragged
 from tqdm.auto import tqdm
 from typing_extensions import assert_never
@@ -210,6 +211,7 @@ def from_path(
         regions: NDArray[np.int32],
         samples: list[str],
         ploidy: int,
+        version: Version | None,
         min_af: float | None = None,
         max_af: float | None = None,
     ) -> Haps[RaggedVariants]:
@@ -246,7 +248,10 @@
             )
         else:
             logger.info("Loading variant data.")
-            variants = _Variants.from_table(path / "genotypes" / "variants.arrow")
+            variants = _Variants.from_table(
+                path / "genotypes" / "variants.arrow",
+                one_based=version is not None and version >= Version("0.18.0"),
+            )
             v_idxs = np.memmap(
                 path / "genotypes" / "variant_idxs.npy",
                 dtype=V_IDX_TYPE,
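The `one_based` flag (presumably marking whether the stored variant table uses 1-based coordinates) is gated on the recorded store version using `packaging.version.Version`, which compares release segments numerically rather than lexicographically; stores with no recorded version predate version tracking and keep the old interpretation. A small sketch of the comparison this relies on, with a hypothetical `is_one_based` helper standing in for the inline expression:

```python
from packaging.version import Version

# Per-segment numeric comparison: a plain string compare would get this wrong.
assert Version("0.18.0") > Version("0.9.1")
assert "0.18.0" < "0.9.1"  # lexicographic comparison disagrees

def is_one_based(version: Version | None) -> bool:
    # Same expression as the diff: legacy stores (None) and pre-0.18.0 stores
    # fall back to the old reading of the variant table.
    return version is not None and version >= Version("0.18.0")

assert not is_one_based(None)
assert not is_one_based(Version("0.17.2"))
assert is_one_based(Version("0.18.0"))
```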
32 changes: 28 additions & 4 deletions python/genvarloader/_dataset/_write.py
@@ -2,8 +2,9 @@
 import json
 import shutil
 import warnings
+from importlib.metadata import version
 from pathlib import Path
-from typing import cast
+from typing import Annotated, Any, cast
 
 import awkward as ak
 import numpy as np
@@ -16,6 +17,8 @@
 from more_itertools import mark_ends
 from natsort import natsorted
 from numpy.typing import NDArray
+from packaging.version import Version
+from pydantic import BaseModel, BeforeValidator, PlainSerializer, WithJsonSchema
 from seqpro.rag import OFFSET_TYPE
 from tqdm.auto import tqdm
 
@@ -26,6 +29,27 @@
 from ._utils import splits_sum_le_value
 
 
+class Metadata(BaseModel, arbitrary_types_allowed=True):
+    samples: list[str]
+    contigs: list[str]
+    n_regions: int
+    ploidy: int | None = None
+    max_jitter: int = 0
+    version: (
+        Annotated[
+            Version,
+            BeforeValidator(lambda v: Version(v) if isinstance(v, str) else v),
+            PlainSerializer(lambda v: str(v), return_type=str),
+            WithJsonSchema({"type": "string"}, mode="serialization"),
+        ]
+        | None
+    ) = None
+
+    @property
+    def n_samples(self) -> int:
+        return len(self.samples)
+
+
 def write(
     path: str | Path,
     bed: str | Path | pl.DataFrame,
@@ -77,7 +101,7 @@ def write(
 
     max_mem = parse_memory(max_mem)
 
-    metadata = {}
+    metadata: dict[str, Any] = {"version": Version(version("genvarloader"))}
     path = Path(path)
     if path.exists() and overwrite:
         logger.info("Found existing GVL store, overwriting.")
@@ -147,7 +171,6 @@ def write(
 
     logger.info(f"Using {len(samples)} samples.")
     metadata["samples"] = samples
-    metadata["n_samples"] = len(samples)
     metadata["n_regions"] = gvl_bed.height
 
     if variants is not None:
@@ -172,8 +195,9 @@
         for bw in bigwigs:
             _write_bigwigs(path, gvl_bed, bw, samples, max_mem)
 
+    _metadata = Metadata(**metadata)
     with open(path / "metadata.json", "w") as f:
-        json.dump(metadata, f)
+        json.dump(_metadata.model_dump(), f)
 
     logger.info("Finished writing.")
     warnings.simplefilter("default")
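The `Annotated` machinery on `Metadata.version` is what lets a `packaging.version.Version` live in a JSON file: `BeforeValidator` parses incoming strings back into `Version`, `PlainSerializer` emits a plain string on dump, and `WithJsonSchema` gives schema generation a concrete type for a class pydantic cannot introspect on its own. A minimal round-trip sketch with illustrative values:

```python
from packaging.version import Version
from genvarloader._dataset._write import Metadata

m = Metadata(samples=["s1"], contigs=["chr1"], n_regions=10, version=Version("0.18.0"))

# PlainSerializer turns the Version into a plain string for json.dump...
dumped = m.model_dump()
assert dumped["version"] == "0.18.0"

# ...and BeforeValidator parses it back into a Version on the way in, so the
# type survives a round trip through metadata.json.
again = Metadata.model_validate(dumped)
assert again.version == Version("0.18.0")
```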