Skip to content

Commit a2feaa6

Browse files
committed
fix: track file format version
1 parent 27420c0 commit a2feaa6

File tree

5 files changed

+43
-11
lines changed

5 files changed

+43
-11
lines changed

pixi.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ plink2 = "*"
5858
hirola = "==0.3"
5959
seqpro = "==0.8.2"
6060
genoray = "==0.16.0"
61+
pydantic = ">=2,<3"
6162

6263
[feature.docs.dependencies]
6364
sphinx = ">=7.4.7"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"pysam",
2121
"pyarrow",
2222
"pyranges",
23+
"pydantic>=2,<3",
2324
"more-itertools",
2425
"tqdm",
2526
"pybigwig",

python/genvarloader/_dataset/_impl.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
from collections.abc import Iterable, Sequence
54
from pathlib import Path
65
from typing import Callable, Generic, Literal, TypeVar, cast, overload
@@ -34,6 +33,7 @@
3433
from ._reconstruct import Haps, HapsTracks, Ref, RefTracks, Tracks, TrackType
3534
from ._reference import Reference
3635
from ._utils import bed_to_regions, regions_to_bed
36+
from ._write import Metadata
3737

3838
if TORCH_AVAILABLE:
3939
import torch
@@ -143,11 +143,11 @@ def open(
143143

144144
# read metadata
145145
with _py_open(path / "metadata.json") as f:
146-
metadata = json.load(f)
147-
samples: list[str] = metadata["samples"]
148-
contigs: list[str] = metadata["contigs"]
149-
ploidy: int | None = metadata.get("ploidy", None)
150-
max_jitter: int = metadata.get("max_jitter", 0)
146+
metadata = Metadata.model_validate_json(f.read())
147+
samples = metadata.samples
148+
contigs = metadata.contigs
149+
ploidy = metadata.ploidy
150+
max_jitter = metadata.max_jitter
151151

152152
# read input regions and generate index map
153153
bed = pl.read_ipc(path / "input_regions.arrow")
@@ -179,6 +179,7 @@ def open(
179179
regions=regions,
180180
samples=samples,
181181
ploidy=ploidy,
182+
version=metadata.version,
182183
min_af=min_af,
183184
max_af=max_af,
184185
)

python/genvarloader/_dataset/_reconstruct.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from loguru import logger
2525
from numpy.typing import NDArray
26+
from packaging.version import Version
2627
from seqpro.rag import OFFSET_TYPE, Ragged
2728
from tqdm.auto import tqdm
2829
from typing_extensions import assert_never
@@ -210,6 +211,7 @@ def from_path(
210211
regions: NDArray[np.int32],
211212
samples: list[str],
212213
ploidy: int,
214+
version: Version | None,
213215
min_af: float | None = None,
214216
max_af: float | None = None,
215217
) -> Haps[RaggedVariants]:
@@ -246,7 +248,10 @@ def from_path(
246248
)
247249
else:
248250
logger.info("Loading variant data.")
249-
variants = _Variants.from_table(path / "genotypes" / "variants.arrow")
251+
variants = _Variants.from_table(
252+
path / "genotypes" / "variants.arrow",
253+
one_based=version is not None and version >= Version("0.18.0"),
254+
)
250255
v_idxs = np.memmap(
251256
path / "genotypes" / "variant_idxs.npy",
252257
dtype=V_IDX_TYPE,

python/genvarloader/_dataset/_write.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import json
33
import shutil
44
import warnings
5+
from importlib.metadata import version
56
from pathlib import Path
6-
from typing import cast
7+
from typing import Annotated, Any, cast
78

89
import awkward as ak
910
import numpy as np
@@ -16,6 +17,8 @@
1617
from more_itertools import mark_ends
1718
from natsort import natsorted
1819
from numpy.typing import NDArray
20+
from packaging.version import Version
21+
from pydantic import BaseModel, BeforeValidator, PlainSerializer, WithJsonSchema
1922
from seqpro.rag import OFFSET_TYPE
2023
from tqdm.auto import tqdm
2124

@@ -26,6 +29,27 @@
2629
from ._utils import splits_sum_le_value
2730

2831

32+
class Metadata(BaseModel, arbitrary_types_allowed=True):
33+
samples: list[str]
34+
contigs: list[str]
35+
n_regions: int
36+
ploidy: int | None = None
37+
max_jitter: int = 0
38+
version: (
39+
Annotated[
40+
Version,
41+
BeforeValidator(lambda v: Version(v) if isinstance(v, str) else v),
42+
PlainSerializer(lambda v: str(v), return_type=str),
43+
WithJsonSchema({"type": "string"}, mode="serialization"),
44+
]
45+
| None
46+
) = None
47+
48+
@property
49+
def n_samples(self) -> int:
50+
return len(self.samples)
51+
52+
2953
def write(
3054
path: str | Path,
3155
bed: str | Path | pl.DataFrame,
@@ -77,7 +101,7 @@ def write(
77101

78102
max_mem = parse_memory(max_mem)
79103

80-
metadata = {}
104+
metadata: dict[str, Any] = {"version": Version(version("genvarloader"))}
81105
path = Path(path)
82106
if path.exists() and overwrite:
83107
logger.info("Found existing GVL store, overwriting.")
@@ -147,7 +171,6 @@ def write(
147171

148172
logger.info(f"Using {len(samples)} samples.")
149173
metadata["samples"] = samples
150-
metadata["n_samples"] = len(samples)
151174
metadata["n_regions"] = gvl_bed.height
152175

153176
if variants is not None:
@@ -172,8 +195,9 @@ def write(
172195
for bw in bigwigs:
173196
_write_bigwigs(path, gvl_bed, bw, samples, max_mem)
174197

198+
_metadata = Metadata(**metadata)
175199
with open(path / "metadata.json", "w") as f:
176-
json.dump(metadata, f)
200+
json.dump(_metadata.model_dump(), f)
177201

178202
logger.info("Finished writing.")
179203
warnings.simplefilter("default")

0 commit comments

Comments
 (0)