From b0cb60f7793a123df1f603a3fe916fbefadec1ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ezequiel=20Leonardo=20Casta=C3=B1o?=
 <14986783+ELC@users.noreply.github.com>
Date: Sat, 24 May 2025 17:12:33 +0000
Subject: [PATCH 1/8] feat(pyspark): add Pydantic integration tests with
 PySpark
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Implement tests for the integration between PySpark and Pydantic.
- Create sample schema models and validate data using Pydantic.

Signed-off-by: Ezequiel Leonardo Castaño <14986783+ELC@users.noreply.github.com>
---
 pandera/api/pyspark/__init__.py           |  1 +
 pandera/typing/pyspark.py                 | 82 +++++++++++++++-
 pandera/typing/pyspark_sql.py             | 98 ++++++++++++++++---
 .../test_pyspark_pydantic_integration.py  | 81 +++++++++++++++
 4 files changed, 247 insertions(+), 15 deletions(-)
 create mode 100644 tests/pyspark/test_pyspark_pydantic_integration.py

diff --git a/pandera/api/pyspark/__init__.py b/pandera/api/pyspark/__init__.py
index efc22d416..12080384f 100644
--- a/pandera/api/pyspark/__init__.py
+++ b/pandera/api/pyspark/__init__.py
@@ -2,3 +2,4 @@
 
 from pandera.api.pyspark.components import Column
 from pandera.api.pyspark.container import DataFrameSchema
+from pandera.api.pyspark.model import DataFrameModel
diff --git a/pandera/typing/pyspark.py b/pandera/typing/pyspark.py
index dd141744a..dbe09e13c 100644
--- a/pandera/typing/pyspark.py
+++ b/pandera/typing/pyspark.py
@@ -1,14 +1,22 @@
 """Pandera type annotations for Pyspark Pandas."""
 
-from typing import TYPE_CHECKING, Generic, TypeVar
+import functools
+import json
+from typing import TYPE_CHECKING, Generic, TypeVar, Any, get_args
 
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+from pandera.engines import PYDANTIC_V2
+from pandera.errors import SchemaInitError
 from pandera.typing.common import (
     DataFrameBase,
     GenericDtype,
     IndexBase,
     SeriesBase,
+    _GenericAlias,
 )
-from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pandera.typing.pandas import DataFrameModel
 
 try:
     import pyspark.pandas as ps
@@ -39,6 +47,76 @@ def __class_getitem__(cls, item):
             """Define this to override's pyspark.pandas generic type."""
             return _GenericAlias(cls, item)
 
+        @classmethod
+        def pydantic_validate(cls, obj: Any, schema_model: T) -> ps.DataFrame:
+            """
+            Verify that the input can be converted into a pyspark.pandas
+            dataframe that meets all schema requirements.
+
+            This is for pydantic >= v2
+            """
+            try:
+                schema = schema_model.to_schema()  # type: ignore[attr-defined]
+            except SchemaInitError as exc:
+                error_message = (
+                    f"Cannot use {cls} as a pydantic type as its "
+                    "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                    f"Please revisit the model to address the following errors:"
+                    f"\n{exc}"
+                )
+                raise ValueError(error_message) from exc
+
+            validated_data = schema.validate(obj)
+
+            if validated_data.pandera.errors:
+                errors = json.dumps(
+                    dict(validated_data.pandera.errors), indent=4
+                )
+                raise ValueError(errors)
+
+            return validated_data
+
+        if PYDANTIC_V2:
+
+            @classmethod
+            def __get_pydantic_core_schema__(
+                cls, _source_type: Any, _handler: GetCoreSchemaHandler
+            ) -> core_schema.CoreSchema:
+                schema_model = get_args(_source_type)[0]
+                return core_schema.no_info_plain_validator_function(
+                    functools.partial(
+                        cls.pydantic_validate,
+                        schema_model=schema_model,
+                    ),
+                )
+
+        else:
+
+            @classmethod
+            def __get_validators__(cls):
+                yield cls._pydantic_validate
+
+            @classmethod
+            def _get_schema_model(cls, field):
+                if not field.sub_fields:
+                    raise TypeError(
+                        "Expected a typed pandera.typing.DataFrame,"
+                        " e.g. DataFrame[Schema]"
+                    )
+                schema_model = field.sub_fields[0].type_
+                return schema_model
+
+            @classmethod
+            def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+                """
+                Verify that the input can be converted into a pyspark.pandas
+                dataframe that meets all schema requirements.
+
+                This is for pydantic v1
+                """
+                schema_model = cls._get_schema_model(field)
+                return cls.pydantic_validate(obj, schema_model)
+
     # pylint:disable=too-few-public-methods,arguments-renamed
     class Series(SeriesBase, ps.Series, Generic[GenericDtype]):  # type: ignore [misc]  # noqa
         """Representation of pandas.Series, only used for type annotation.
diff --git a/pandera/typing/pyspark_sql.py b/pandera/typing/pyspark_sql.py
index 91cbcea35..b7a4a22fd 100644
--- a/pandera/typing/pyspark_sql.py
+++ b/pandera/typing/pyspark_sql.py
@@ -1,9 +1,16 @@
-"""Pandera type annotations for Pyspark."""
+"""Pandera type annotations for Pyspark SQL."""
 
-from typing import TypeVar, Union
+import functools
+import json
+from typing import Union, TypeVar, Any, get_args, Generic
 
-from pandera.typing.common import DataFrameBase
-from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+from pandera.engines import pyspark_engine, PYDANTIC_V2
+from pandera.errors import SchemaInitError
+from pandera.typing.common import DataFrameBase, _GenericAlias
+from pandera.api.pyspark import DataFrameModel
 
 try:
     import pyspark.sql as ps
@@ -12,9 +19,9 @@
 except ImportError:  # pragma: no cover
     PYSPARK_SQL_INSTALLED = False
 
-if PYSPARK_SQL_INSTALLED:
-    from pandera.engines import pyspark_engine
+T = TypeVar("T", bound=DataFrameModel)
 
+if PYSPARK_SQL_INSTALLED:
     PysparkString = pyspark_engine.String
     PysparkInt = pyspark_engine.Int
     PysparkLongInt = pyspark_engine.BigInt
@@ -43,13 +50,6 @@
             PysparkBinary,  # type: ignore
         ],
     )
-    from typing import TYPE_CHECKING, Generic
-
-    # pylint:disable=invalid-name
-    if TYPE_CHECKING:
-        T = TypeVar("T")  # pragma: no cover
-    else:
-        T = DataFrameModel
 
 if PYSPARK_SQL_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed
@@ -64,3 +64,75 @@ class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
         def __class_getitem__(cls, item):
             """Define this to override's pyspark.pandas generic type."""
             return _GenericAlias(cls, item)  # pragma: no cover
+
+        @classmethod
+        def pydantic_validate(
+            cls, obj: ps.DataFrame, schema_model: T
+        ) -> ps.DataFrame:
+            """
+            Verify that the input can be converted into a pyspark.sql
+            dataframe that meets all schema requirements.
+
+            This is for pydantic V1 and V2.
+            """
+            try:
+                schema = schema_model.to_schema()
+            except SchemaInitError as exc:
+                error_message = (
+                    f"Cannot use {cls} as a pydantic type as its "
+                    "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                    f"Please revisit the model to address the following errors:"
+                    f"\n{exc}"
+                )
+                raise ValueError(error_message) from exc
+
+            validated_data = schema.validate(obj)
+
+            if validated_data.pandera.errors:
+                errors = json.dumps(
+                    dict(validated_data.pandera.errors), indent=4
+                )
+                raise ValueError(errors)
+
+            return validated_data
+
+        if PYDANTIC_V2:
+
+            @classmethod
+            def __get_pydantic_core_schema__(
+                cls, _source_type: Any, _handler: GetCoreSchemaHandler
+            ) -> core_schema.CoreSchema:
+                schema_model = get_args(_source_type)[0]
+                return core_schema.no_info_plain_validator_function(
+                    functools.partial(
+                        cls.pydantic_validate,
+                        schema_model=schema_model,
+                    ),
+                )
+
+        else:
+
+            @classmethod
+            def __get_validators__(cls):
+                yield cls._pydantic_validate
+
+            @classmethod
+            def _get_schema_model(cls, field):
+                if not field.sub_fields:
+                    raise TypeError(
+                        "Expected a typed pandera.typing.DataFrame,"
+                        " e.g. DataFrame[Schema]"
+                    )
+                schema_model = field.sub_fields[0].type_
+                return schema_model
+
+            @classmethod
+            def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+                """
+                Verify that the input can be converted into a pyspark.sql
+                dataframe that meets all schema requirements.
+
+                This is for pydantic v1
+                """
+                schema_model = cls._get_schema_model(field)
+                return cls.pydantic_validate(obj, schema_model)
diff --git a/tests/pyspark/test_pyspark_pydantic_integration.py b/tests/pyspark/test_pyspark_pydantic_integration.py
new file mode 100644
index 000000000..0db32240d
--- /dev/null
+++ b/tests/pyspark/test_pyspark_pydantic_integration.py
@@ -0,0 +1,81 @@
+"""Tests for the integration between PySpark and Pydantic."""
+
+import pytest
+from pydantic import BaseModel, ValidationError
+from pyspark.testing.utils import assertDataFrameEqual
+import pyspark.sql.types as T
+
+import pandera.pyspark as pa
+from pandera.typing.pyspark_sql import DataFrame as PySparkSQLDataFrame
+from pandera.typing.pyspark import DataFrame as PySparkDataFrame
+from pandera.pyspark import DataFrameModel
+
+
+@pytest.fixture
+def sample_schema_model():
+    class SampleSchema(DataFrameModel):
+        """
+        Sample schema model with data checks.
+        """
+
+        product: T.StringType() = pa.Field()
+        price: T.IntegerType() = pa.Field()
+
+    return SampleSchema
+
+
+@pytest.fixture(
+    params=[PySparkDataFrame, PySparkSQLDataFrame],
+    ids=["pyspark", "pyspark_sql"],
+)
+def pydantic_container(request, sample_schema_model):
+    TypingClass = request.param
+
+    class PydanticContainer(BaseModel):
+        """
+        Pydantic container with a DataFrameModel as a field.
+        """
+
+        data: TypingClass[sample_schema_model]
+
+    return PydanticContainer
+
+
+@pytest.fixture
+def correct_data(spark, sample_data, sample_spark_schema):
+    """
+    Correct data that should pass validation.
+    """
+    return spark.createDataFrame(sample_data, sample_spark_schema)
+
+
+@pytest.fixture
+def incorrect_data(spark):
+    """
+    Incorrect data that should fail validation.
+    """
+    data = [
+        (1, "Apples"),
+        (2, "Bananas"),
+    ]
+    return spark.createDataFrame(data, ["product", "price"])
+
+
+def test_pydantic_model_instantiates_with_correct_data(
+    correct_data, pydantic_container
+):
+    """
+    Test that a Pydantic model can be instantiated with a DataFrameModel when data is valid.
+    """
+    my_container = pydantic_container(data=correct_data)
+    assertDataFrameEqual(my_container.data, correct_data)
+
+
+def test_pydantic_model_throws_validation_error_with_incorrect_data(
+    incorrect_data, pydantic_container
+):
+    """
+    Test that a Pydantic model throws a ValidationError when data is invalid.
+    """
+    with pytest.raises(ValidationError):
+        pydantic_container(data=incorrect_data)

From 86c4d40c77c3f85268e5a9fe105e8183bef275d2 Mon Sep 17 00:00:00 2001
From: cosmicBboy
Date: Mon, 26 May 2025 20:10:42 -0400
Subject: [PATCH 2/8] remove unused txt file

Signed-off-by: cosmicBboy
---
 foo.txt | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 foo.txt

diff --git a/foo.txt b/foo.txt
deleted file mode 100644
index e69de29bb..000000000

From c594ea1a6f923f949818aad037879e247b594b75 Mon Sep 17 00:00:00 2001
From: Deepyaman Datta
Date: Tue, 27 May 2025 07:43:53 -0600
Subject: [PATCH 3/8] Delete previously-added foo.txt and new_example.py
 (#2013)

* Delete foo.txt

Signed-off-by: Deepyaman Datta

* Delete new_example.py

Signed-off-by: Deepyaman Datta

---------

Signed-off-by: Deepyaman Datta
---
 new_example.py | 44 --------------------------------------------
 1 file changed, 44 deletions(-)
 delete mode 100644 new_example.py

diff --git a/new_example.py b/new_example.py
deleted file mode 100644
index 941ee4c67..000000000
--- a/new_example.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""New example."""
-
-import pandas as pd
-import pandera.pandas as pa
-
-# data to validate
-df = pd.DataFrame(
-    {
-        "column1": [1, 2, 3],
-        "column2": [1.1, 1.2, 1.3],
-        "column3": ["a", "b", "c"],
-    }
-)
-
-schema = pa.DataFrameSchema(
-    {
-        "column1": pa.Column(int, pa.Check.ge(0)),
-        "column2": pa.Column(float, pa.Check.lt(10)),
-        "column3": pa.Column(
-            str,
-            [
-                pa.Check.isin([*"abc"]),
-                pa.Check(lambda series: series.str.len() == 1),
-            ],
-        ),
-    }
-)
-
-print(schema.validate(df))
-
-
-# define DataFrameModel Schema
-class Schema(pa.DataFrameModel):
-    column1: int = pa.Field(ge=0)
-    column2: float = pa.Field(lt=10)
-    column3: str = pa.Field(isin=[*"abc"])
-
-    @pa.check("column3")
-    @classmethod
-    def custom_check(cls, series: pd.Series) -> pd.Series:
-        return series.str.len() == 1
-
-
-print(Schema.validate(df))

From bbff287ddebee935d272dde0dacf814fbae9ad48 Mon Sep 17 00:00:00 2001
From: Deepyaman Datta
Date: Tue, 27 May 2025 12:05:02 -0600
Subject: [PATCH 4/8] Pin PySpark due to test failures/incompatibilities
 (#2010)

Signed-off-by: Deepyaman Datta
---
 environment.yml  | 2 +-
 pyproject.toml   | 4 ++--
 requirements.txt | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/environment.yml b/environment.yml
index 34f55c0b1..322808699 100644
--- a/environment.yml
+++ b/environment.yml
@@ -23,7 +23,7 @@ dependencies:
   - pandas-stubs
 
   # pyspark extra
-  - pyspark[connect] >= 3.2.0
+  - pyspark[connect] >= 3.2.0, < 4.0.0
 
   # polars extra
   - polars >= 0.20.0
diff --git a/pyproject.toml b/pyproject.toml
index ac933496a..b27b9bc25 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,7 @@ geopandas = [
     "geopandas",
     "shapely",
 ]
-pyspark = ["pyspark[connect] >= 3.2.0"]
+pyspark = ["pyspark[connect] >= 3.2.0, < 4.0.0"]
 modin = [
     "modin",
     "ray",
@@ -92,7 +92,7 @@ all = [
     "pyyaml >= 5.1",
     "black",
     "frictionless <= 4.40.8",
-    "pyspark[connect] >= 3.2.0",
+    "pyspark[connect] >= 3.2.0, < 4.0.0",
     "modin",
     "ray",
     "dask[dataframe]",
diff --git a/requirements.txt b/requirements.txt
index 98742c707..e11dcc310 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ pyarrow
 pydantic
 scipy
 pandas-stubs
-pyspark[connect] >= 3.2.0
+pyspark[connect] >= 3.2.0, < 4.0.0
 polars >= 0.20.0
 modin
 protobuf

From 08ab1fdc82bfdae0fb36e5d8add72ee6d29ac5ee Mon Sep 17 00:00:00 2001
From: Deepyaman Datta
Date: Tue, 27 May 2025 12:05:11 -0600
Subject: [PATCH 5/8] Temporarily pin `polars` due to test failure in CI
 (#2011)

Signed-off-by: Deepyaman Datta
---
 noxfile.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/noxfile.py b/noxfile.py
index 3161555e3..107d741b6 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -152,12 +152,12 @@ def _testing_requirements(
         if req.startswith("numpy") and _numpy is not None:
             print("adding numpy constraint <2")
             req = f"{req}, {_numpy}"
-        if (
-            req == "polars"
-            or req.startswith("polars ")
-            and sys.platform == "darwin"
-        ):
-            req = "polars-lts-cpu"
+        if req == "polars" or req.startswith("polars "):
+            # TODO(deepyaman): Support latest Polars.
+            if sys.platform == "darwin":
+                req = "polars-lts-cpu < 1.30.0"
+            else:
+                req = "polars < 1.30.0"
         # for some reason uv will try to install an old version of dask,
         # have to specifically pin dask[dataframe] to a higher version
         if (

From 59b58ea194c48740d4ff8b39e732c8bd43e1c42f Mon Sep 17 00:00:00 2001
From: Deepyaman Datta
Date: Tue, 27 May 2025 19:09:11 -0600
Subject: [PATCH 6/8] Replace `event_loop` removed in pytest-asyncio 1.0
 (#2014)

Signed-off-by: Deepyaman Datta
---
 tests/pandas/test_decorators.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/pandas/test_decorators.py b/tests/pandas/test_decorators.py
index 06944ef37..72d079b93 100644
--- a/tests/pandas/test_decorators.py
+++ b/tests/pandas/test_decorators.py
@@ -1,8 +1,8 @@
 """Testing the Decorators that check a functions input or output."""
 
+import asyncio
 import pickle
 import typing
-from asyncio import AbstractEventLoop
 
 import numpy as np
 import pandas as pd
@@ -1095,7 +1095,7 @@ def star_args_kwargs(
     pd.testing.assert_frame_equal(expected, actual)
 
 
-def test_coroutines(event_loop: AbstractEventLoop) -> None:
+def test_coroutines() -> None:
     # pylint: disable=missing-class-docstring,too-few-public-methods,missing-function-docstring
     class Schema(DataFrameModel):
         col1: Series[int]
@@ -1192,7 +1192,7 @@ async def check_coros() -> None:
         with pytest.raises(errors.SchemaError):
             await coro(bad_df)
 
-    event_loop.run_until_complete(check_coros())
+    asyncio.get_event_loop().run_until_complete(check_coros())
 
 
 class Schema(DataFrameModel):

From a2b22389b9ffb79bf8477013db867bbb411a4d56 Mon Sep 17 00:00:00 2001
From: Ahmet Zamanis
Date: Thu, 29 May 2025 23:01:09 +0300
Subject: [PATCH 7/8] Fix typehint in unique_values_eq (issue #1492) (#2015)

* Fix typehint in unique_values_eq

Signed-off-by: Ahmet Zamanis

* Fix typo in unique_values_eq docstring

Signed-off-by: Ahmet Zamanis

---------

Signed-off-by: Ahmet Zamanis
---
 pandera/api/checks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pandera/api/checks.py b/pandera/api/checks.py
index 215e7dd05..ae2788b3d 100644
--- a/pandera/api/checks.py
+++ b/pandera/api/checks.py
@@ -531,14 +531,14 @@ def str_length(
         )
 
     @classmethod
-    def unique_values_eq(cls, values: str, **kwargs) -> "Check":
+    def unique_values_eq(cls, values: Iterable, **kwargs) -> "Check":
         """Ensure that unique values in the data object contain all values.
 
         .. note::
             In contrast with :func:`isin`, this check makes sure that all the
             items in the ``values`` iterable are contained within the series.
 
-        :param values: The set of values that must be present. Maybe any iterable.
+        :param values: The set of values that must be present. May be any iterable.
""" try: values_mod = frozenset(values) From c942ba54695bb4f9df57a5c87b948f5ff3f68af4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ezequiel=20Leonardo=20Casta=C3=B1o?= <14986783+ELC@users.noreply.github.com> Date: Sun, 1 Jun 2025 03:01:28 +0000 Subject: [PATCH 8/8] fix(tests): disable ANSI SQL checks in Spark sessions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Set `spark.sql.ansi.enabled` to False in Spark session fixtures to prevent SQL compatibility issues during tests. Signed-off-by: Ezequiel Leonardo Castaño <14986783+ELC@users.noreply.github.com> --- tests/pyspark/conftest.py | 2 ++ tests/pyspark/test_pyspark_accessor.py | 1 + tests/pyspark/test_schemas_on_pyspark_pandas.py | 2 ++ 3 files changed, 5 insertions(+) diff --git a/tests/pyspark/conftest.py b/tests/pyspark/conftest.py index 4ae06ea50..09f129fdc 100644 --- a/tests/pyspark/conftest.py +++ b/tests/pyspark/conftest.py @@ -17,6 +17,7 @@ def spark() -> SparkSession: creates spark session """ spark: SparkSession = SparkSession.builder.getOrCreate() + spark.conf.set("spark.sql.ansi.enabled", False) yield spark spark.stop() @@ -29,6 +30,7 @@ def spark_connect() -> SparkSession: # Set location of localhost Spark Connect server os.environ["SPARK_LOCAL_REMOTE"] = "sc://localhost" spark: SparkSession = SparkSession.builder.getOrCreate() + spark.conf.set("spark.sql.ansi.enabled", False) yield spark spark.stop() diff --git a/tests/pyspark/test_pyspark_accessor.py b/tests/pyspark/test_pyspark_accessor.py index cf2fbc4b0..2f5e7d070 100644 --- a/tests/pyspark/test_pyspark_accessor.py +++ b/tests/pyspark/test_pyspark_accessor.py @@ -12,6 +12,7 @@ from pandera.pyspark import pyspark_sql_accessor spark = SparkSession.builder.getOrCreate() +spark.conf.set("spark.sql.ansi.enabled", False) @pytest.mark.parametrize( diff --git a/tests/pyspark/test_schemas_on_pyspark_pandas.py b/tests/pyspark/test_schemas_on_pyspark_pandas.py index c39ad3981..e0a5581fe 100644 --- a/tests/pyspark/test_schemas_on_pyspark_pandas.py +++ b/tests/pyspark/test_schemas_on_pyspark_pandas.py @@ -278,8 +278,10 @@ def test_index_dtypes( not in { pandas_engine.Engine.dtype(pandas_engine.BOOL), pandas_engine.DateTime(tz="UTC"), # type: ignore[call-arg] + pandas_engine.Engine.dtype(pa.dtypes.Timedelta), # type: ignore[call-arg] } ], + ids=lambda x: str(x) ) @hypothesis.given(st.data()) def test_nullable(