Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chap_core/adaptors/command_line_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _get_dataclass(estimator):

def _read_model_config(model_config_path):
if model_config_path is not None:
with open(model_config_path, "r") as file:
with open(model_config_path) as file:
model_config = yaml.safe_load(file)
# model_config = model_template.get_config_class().parse_file(model_config_path)
else:
Expand Down
2 changes: 1 addition & 1 deletion chap_core/alarms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from collections.abc import Iterable

import numpy as np
from pydantic import BaseModel
Expand Down
7 changes: 3 additions & 4 deletions chap_core/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import logging
from typing import List, Optional

from .assessment.forecast import forecast as do_forecast
from .datatypes import (
Expand Down Expand Up @@ -38,8 +37,8 @@ class PredictionData:
health_data: DataSet[HealthData] | None = None
climate_data: DataSet[ClimateData] | None = None
population_data: DataSet[HealthPopulationData] | None = None
disease_id: Optional[str] = None
features: List[object] | None = None
disease_id: str | None = None
features: list[object] | None = None


def extract_disease_name(health_data: dict) -> str:
Expand All @@ -50,7 +49,7 @@ def forecast(
model_name: str,
dataset_name: DataSetType,
n_months: int,
model_path: Optional[str] = None,
model_path: str | None = None,
):
logging.basicConfig(level=logging.INFO)
dataset = datasets[dataset_name].load()
Expand Down
16 changes: 8 additions & 8 deletions chap_core/api_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@


class FeatureModel(_FeatureModel):
id: Optional[str] = None
properties: Optional[dict[str, Any]] = Field(default_factory=dict) # type: ignore[assignment]
geometry: Union[
PointModel, MultiPointModel, LineStringModel, MultiLineStringModel, PolygonModel, MultiPolygonModel, None
] = None
id: str | None = None
properties: dict[str, Any] | None = Field(default_factory=dict) # type: ignore[assignment]
geometry: (
PointModel | MultiPointModel | LineStringModel | MultiLineStringModel | PolygonModel | MultiPolygonModel | None
) = None


class FeatureCollectionModel(_FeatureCollectionModel):
Expand All @@ -35,7 +35,7 @@ class FeatureCollectionModel(_FeatureCollectionModel):
class DataElement(BaseModel):
pe: str
ou: str
value: Optional[float]
value: float | None


class DataList(BaseModel):
Expand All @@ -47,7 +47,7 @@ class DataList(BaseModel):
class DataElementV2(BaseModel):
period: str
orgUnit: str
value: Optional[float]
value: float | None


class DataListV2(BaseModel):
Expand Down Expand Up @@ -92,7 +92,7 @@ class RunConfig(BaseModel):

ignore_environment: bool = False
debug: bool = False
log_file: Optional[str] = None
log_file: str | None = None
run_directory_type: Literal["latest", "timestamp", "use_existing"] = "timestamp"
is_chapkit_model: bool = False

Expand Down
19 changes: 10 additions & 9 deletions chap_core/assessment/backtest_plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,20 @@ def plot(
- disease_cases: float - Historical observed disease cases
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict, Optional, Type, Union

import altair as alt
import pandas as pd

from chap_core.database.tables import BackTest

# Type alias for Altair chart types that plots can return
ChartType = Union[alt.Chart, alt.VConcatChart, alt.FacetChart, alt.LayerChart, alt.HConcatChart]
ChartType = alt.Chart | alt.VConcatChart | alt.FacetChart | alt.LayerChart | alt.HConcatChart

# Global registry for backtest plots
_backtest_plots_registry: Dict[str, Type["BacktestPlotBase"]] = {}
_backtest_plots_registry: dict[str, type[BacktestPlotBase]] = {}


class BacktestPlotBase(ABC):
Expand All @@ -86,7 +87,7 @@ def plot(
self,
observations: pd.DataFrame,
forecasts: pd.DataFrame,
historical_observations: Optional[pd.DataFrame] = None,
historical_observations: pd.DataFrame | None = None,
) -> ChartType:
"""
Generate the visualization from flat DataFrames.
Expand Down Expand Up @@ -143,7 +144,7 @@ def plot(self, observations, forecasts, historical_observations=None):
...
"""

def decorator(cls: Type[BacktestPlotBase]) -> Type[BacktestPlotBase]:
def decorator(cls: type[BacktestPlotBase]) -> type[BacktestPlotBase]:
if not issubclass(cls, BacktestPlotBase):
raise TypeError(f"{cls.__name__} must inherit from BacktestPlotBase")

Expand All @@ -158,7 +159,7 @@ def decorator(cls: Type[BacktestPlotBase]) -> Type[BacktestPlotBase]:
return decorator


def get_backtest_plots_registry() -> Dict[str, Type[BacktestPlotBase]]:
def get_backtest_plots_registry() -> dict[str, type[BacktestPlotBase]]:
"""
Get the registry of all registered backtest plots.

Expand All @@ -170,7 +171,7 @@ def get_backtest_plots_registry() -> Dict[str, Type[BacktestPlotBase]]:
return _backtest_plots_registry.copy()


def get_backtest_plot(plot_id: str) -> Optional[Type[BacktestPlotBase]]:
def get_backtest_plot(plot_id: str) -> type[BacktestPlotBase] | None:
"""
Get a specific backtest plot class by ID.

Expand Down Expand Up @@ -247,7 +248,7 @@ def create_plot_from_backtest(plot_id: str, backtest: BackTest) -> ChartType:
forecasts_df: pd.DataFrame = flat_data.forecasts # type: ignore[assignment]

# Get historical observations if the plot needs them
historical_df: Optional[pd.DataFrame] = None
historical_df: pd.DataFrame | None = None
if plot_cls.needs_historical and flat_data.historical_observations is not None:
historical_df = flat_data.historical_observations # type: ignore[assignment]

Expand Down Expand Up @@ -292,7 +293,7 @@ def create_plot_from_evaluation(plot_id: str, evaluation) -> ChartType:
forecasts_df: pd.DataFrame = flat_data.forecasts # type: ignore[assignment]

# Get historical observations if the plot needs them
historical_df: Optional[pd.DataFrame] = None
historical_df: pd.DataFrame | None = None
if plot_cls.needs_historical and flat_data.historical_observations is not None:
historical_df = flat_data.historical_observations # type: ignore[assignment]

Expand Down
2 changes: 1 addition & 1 deletion chap_core/assessment/backtest_plots/evaluation_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def plot(
self,
observations: pd.DataFrame,
forecasts: pd.DataFrame,
historical_observations: Optional[pd.DataFrame] = None,
historical_observations: pd.DataFrame | None = None,
) -> ChartType:
"""
Generate and return the evaluation visualization.
Expand Down
2 changes: 1 addition & 1 deletion chap_core/assessment/backtest_plots/metrics_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def plot(
self,
observations: pd.DataFrame,
forecasts: pd.DataFrame,
historical_observations: Optional[pd.DataFrame] = None,
historical_observations: pd.DataFrame | None = None,
) -> ChartType:
"""
Generate and return the dashboard visualization.
Expand Down
2 changes: 1 addition & 1 deletion chap_core/assessment/backtest_plots/sample_bias_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def plot(
self,
observations: pd.DataFrame,
forecasts: pd.DataFrame,
historical_observations: Optional[pd.DataFrame] = None,
historical_observations: pd.DataFrame | None = None,
) -> ChartType:
"""
Generate and return the dashboard visualization.
Expand Down
11 changes: 5 additions & 6 deletions chap_core/assessment/data_representation_transforming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools
from collections import defaultdict
from typing import Dict, List

from chap_core.assessment.evaluator import Evaluator
from chap_core.assessment.representations import (
Expand All @@ -18,7 +17,7 @@
from chap_core.database.tables import BackTestForecast


def convert_to_multi_location_forecast(backTestList: List[BackTestForecast]) -> Dict[str, MultiLocationForecast]:
def convert_to_multi_location_forecast(backTestList: list[BackTestForecast]) -> dict[str, MultiLocationForecast]:
# Group samples by location
all_splitpoint_timeseries = {}
backTestList = sorted(backTestList, key=lambda x: x.last_seen_period)
Expand All @@ -29,8 +28,8 @@ def convert_to_multi_location_forecast(backTestList: List[BackTestForecast]) ->
return all_splitpoint_timeseries


def convert_single_splitpoint_to_multi_location_forecast(backTestList: List[BackTestForecast]) -> MultiLocationForecast:
location_forecasts: Dict[str, List[Samples]] = defaultdict(list)
def convert_single_splitpoint_to_multi_location_forecast(backTestList: list[BackTestForecast]) -> MultiLocationForecast:
location_forecasts: dict[str, list[Samples]] = defaultdict(list)
for forecast in backTestList:
location_key = str(forecast.org_unit) # Or use forecast.backtest.location if available

Expand All @@ -49,8 +48,8 @@ def convert_single_splitpoint_to_multi_location_forecast(backTestList: List[Back
return MultiLocationForecast(timeseries=timeseries)


def convert_to_multi_location_timeseries(obs: List[ObservationBase]) -> MultiLocationDiseaseTimeSeries:
grouped: defaultdict[str, List[DiseaseObservation]] = defaultdict(list)
def convert_to_multi_location_timeseries(obs: list[ObservationBase]) -> MultiLocationDiseaseTimeSeries:
grouped: defaultdict[str, list[DiseaseObservation]] = defaultdict(list)

for ob in obs:
if ob.feature_name == "disease_cases" and ob.value is not None:
Expand Down
15 changes: 8 additions & 7 deletions chap_core/assessment/dataset_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
slides forward.
"""

from typing import Iterable, Iterator, Optional, Protocol, Type
from collections.abc import Iterable, Iterator
from typing import Protocol

from chap_core.climate_predictor import FutureWeatherFetcher
from chap_core.datatypes import ClimateData
Expand All @@ -23,9 +24,9 @@ class IsTimeDelta(Protocol):
def split_test_train_on_period(
data_set: DataSet,
split_points: Iterable[TimePeriod],
future_length: Optional[IsTimeDelta] = None,
future_length: IsTimeDelta | None = None,
include_future_weather: bool = False,
future_weather_class: Type[ClimateData] = ClimateData,
future_weather_class: type[ClimateData] = ClimateData,
):
"""Generate train/test splits at each split point.

Expand Down Expand Up @@ -63,7 +64,7 @@ def split_test_train_on_period(
def train_test_split(
data_set: DataSet,
prediction_start_period: TimePeriod,
extension: Optional[IsTimeDelta] = None,
extension: IsTimeDelta | None = None,
restrict_test=True,
):
"""Split a dataset into train and test sets at a single split point.
Expand Down Expand Up @@ -105,7 +106,7 @@ def train_test_generator(
prediction_length: int,
n_test_sets: int = 1,
stride: int = 1,
future_weather_provider: Optional[FutureWeatherFetcher] = None,
future_weather_provider: FutureWeatherFetcher | None = None,
) -> tuple[DataSet, Iterator[tuple[DataSet, DataSet, DataSet]]]:
"""Generate expanding-window train/test splits for backtesting.

Expand Down Expand Up @@ -182,8 +183,8 @@ def train_test_generator(
def train_test_split_with_weather(
data_set: DataSet,
prediction_start_period: TimePeriod,
extension: Optional[IsTimeDelta] = None,
future_weather_class: Type[ClimateData] = ClimateData,
extension: IsTimeDelta | None = None,
future_weather_class: type[ClimateData] = ClimateData,
):
train_set, test_set = train_test_split(data_set, prediction_start_period, extension)
future_weather = test_set.remove_field("disease_cases")
Expand Down
Loading