@@ -269,6 +269,7 @@ def with_settings(
269269 rc_neg : bool | None = None ,
270270 min_af : float | Literal [False ] | None = None ,
271271 max_af : float | Literal [False ] | None = None ,
272+ var_fields : list [str ] | None = None ,
272273 ) -> Self :
273274 """Modify settings of the dataset, returning a new dataset without modifying the old one.
274275
@@ -291,6 +292,8 @@ def with_settings(
291292 max_af
292293 The maximum allele frequency to include in the dataset. If set to :code:`False`, disables this filter.
293294 If dataset is not backed by SVAR genotypes, this will raise an error.
295+ var_fields
296+ The variant fields to include in the dataset.
294297 """
295298 to_evolve = {}
296299
@@ -342,16 +345,26 @@ def with_settings(
342345 elif max_af is False :
343346 max_af = None
344347
345- haps = evolve (self ._seqs , min_af = min_af , max_af = max_af )
348+ haps = to_evolve .get ("_seqs" , self ._seqs )
349+ haps = evolve (haps , min_af = min_af , max_af = max_af )
346350 to_evolve ["_seqs" ] = haps
347351
352+ if var_fields is not None :
353+ missing = list (set (var_fields ) - set (self .available_var_fields ))
354+ if missing or not isinstance (self ._seqs , Haps ):
355+ raise ValueError (f"Missing variant fields: { missing } " )
356+ haps = to_evolve .get ("_seqs" , self ._seqs )
357+ haps = evolve (haps , var_fields = var_fields )
358+ to_evolve ["_seqs" ] = haps
359+
360+ if "_seqs" in to_evolve :
361+ haps = to_evolve ["_seqs" ]
348362 if isinstance (self ._recon , Haps ):
349363 recon = haps
364+ to_evolve ["_recon" ] = recon
350365 elif isinstance (self ._recon , HapsTracks ):
351366 recon = evolve (self ._recon , haps = haps )
352- else :
353- recon = self ._recon
354- to_evolve ["_recon" ] = recon
367+ to_evolve ["_recon" ] = recon
355368
356369 return evolve (self , ** to_evolve )
357370
@@ -728,6 +741,24 @@ def full_shape(self) -> tuple[int, int]:
728741 """Return the full shape of the dataset, ignoring any subsetting. :code:`(n_samples, n_regions)`"""
729742 return self ._idxer .full_shape
730743
@property
def available_var_fields(self) -> list[str]:
    """Variant fields that this dataset can expose.

    Returns
    -------
    list[str]
        The ``available_var_fields`` of the underlying sequence source when it
        is a :code:`Haps` instance; otherwise an empty list.
    """
    seqs = self._seqs
    # Only a Haps-backed sequence source carries variant fields.
    if isinstance(seqs, Haps):
        return seqs.available_var_fields
    return []
752+
@property
def active_var_fields(self) -> list[str]:
    """Variant fields currently selected on the reconstructor.

    Returns
    -------
    list[str]
        ``var_fields`` of the reconstructor when it is a :code:`Haps`, or of its
        ``haps`` member when it is a :code:`HapsTracks`; otherwise an empty list.
    """
    recon = self._recon
    # Haps is checked first, mirroring the order in which the two cases are tried.
    if isinstance(recon, Haps):
        return recon.var_fields
    if isinstance(recon, HapsTracks):
        # HapsTracks wraps a haplotype reconstructor in its ``haps`` attribute.
        return recon.haps.var_fields
    return []
761+
731762 @property
732763 def available_tracks (self ) -> list [str ] | None :
733764 """The available tracks in the dataset."""
@@ -1352,7 +1383,7 @@ def __getitem__(
13521383
13531384 if squeeze :
13541385 # (1 [p] l) -> ([p] l)
1355- recon = tuple (o .squeeze (0 ) for o in recon )
1386+ recon = tuple (o .squeeze (axis = 0 ) for o in recon )
13561387
13571388 if unlist :
13581389 recon = recon [0 ]
@@ -1538,28 +1569,32 @@ def with_seqs(
15381569 return super ().with_seqs (kind )
15391570
15401571 @overload
1541- def with_tracks (self , tracks : None = ... , kind : None = ... ) -> Self : ...
1572+ def with_tracks (self , tracks : None = None , kind : None = None ) -> Self : ...
15421573 @overload
15431574 def with_tracks (
1544- self , tracks : None = ... , kind : Literal ["tracks" ] = ...
1575+ self , * , tracks : None = None , kind : Literal ["tracks" ]
15451576 ) -> ArrayDataset [MaybeSEQ , NDArray [np .float32 ]]: ...
15461577 @overload
15471578 def with_tracks (
1548- self , tracks : None = ... , kind : Literal ["intervals" ] = ...
1579+ self , * , tracks : None = None , kind : Literal ["intervals" ]
15491580 ) -> ArrayDataset [MaybeSEQ , RaggedIntervals ]: ...
15501581 @overload
15511582 def with_tracks (
15521583 self ,
1553- tracks : Literal [False ] = ... ,
1554- kind : Literal ["tracks" , "intervals" ] | None = ... ,
1584+ tracks : Literal [False ],
1585+ kind : Literal ["tracks" , "intervals" ] | None = None ,
15551586 ) -> ArrayDataset [MaybeSEQ , None ]: ...
15561587 @overload
15571588 def with_tracks (
1558- self , tracks : str | list [str ] = ..., kind : Literal ["tracks" ] = ...
1589+ self , tracks : str | list [str ], kind : None = None
1590+ ) -> ArrayDataset [MaybeSEQ , MaybeTRK ]: ...
1591+ @overload
1592+ def with_tracks (
1593+ self , tracks : str | list [str ], kind : Literal ["tracks" ]
15591594 ) -> ArrayDataset [MaybeSEQ , NDArray [np .float32 ]]: ...
15601595 @overload
15611596 def with_tracks (
1562- self , tracks : str | list [str ] = ... , kind : Literal ["intervals" ] = ...
1597+ self , tracks : str | list [str ], kind : Literal ["intervals" ]
15631598 ) -> ArrayDataset [MaybeSEQ , RaggedIntervals ]: ...
15641599 def with_tracks (
15651600 self ,
@@ -1684,28 +1719,32 @@ def with_seqs(
16841719 return super ().with_seqs (kind )
16851720
16861721 @overload
1687- def with_tracks (self , tracks : None = ... , kind : None = ... ) -> Self : ...
1722+ def with_tracks (self , tracks : None = None , kind : None = None ) -> Self : ...
16881723 @overload
16891724 def with_tracks (
1690- self , tracks : None = ... , kind : Literal ["tracks" ] = ...
1725+ self , * , tracks : None = None , kind : Literal ["tracks" ]
16911726 ) -> RaggedDataset [MaybeRSEQ , RaggedTracks ]: ...
16921727 @overload
16931728 def with_tracks (
1694- self , tracks : None = ... , kind : Literal ["intervals" ] = ...
1729+ self , * , tracks : None = None , kind : Literal ["intervals" ]
16951730 ) -> RaggedDataset [MaybeRSEQ , RaggedIntervals ]: ...
16961731 @overload
16971732 def with_tracks (
16981733 self ,
1699- tracks : Literal [False ] = ... ,
1700- kind : Literal ["tracks" , "intervals" ] | None = ... ,
1734+ tracks : Literal [False ],
1735+ kind : Literal ["tracks" , "intervals" ] | None = None ,
17011736 ) -> RaggedDataset [MaybeRSEQ , None ]: ...
17021737 @overload
17031738 def with_tracks (
1704- self , tracks : str | list [str ] = ..., kind : Literal ["tracks" ] = ...
1739+ self , tracks : str | list [str ], kind : None = None
1740+ ) -> RaggedDataset [MaybeRSEQ , MaybeRTRK ]: ...
1741+ @overload
1742+ def with_tracks (
1743+ self , tracks : str | list [str ], kind : Literal ["tracks" ]
17051744 ) -> RaggedDataset [MaybeRSEQ , RaggedTracks ]: ...
17061745 @overload
17071746 def with_tracks (
1708- self , tracks : str | list [str ] = ... , kind : Literal ["intervals" ] = ...
1747+ self , tracks : str | list [str ], kind : Literal ["intervals" ]
17091748 ) -> RaggedDataset [MaybeRSEQ , RaggedIntervals ]: ...
17101749 def with_tracks (
17111750 self ,
0 commit comments