diff --git a/pandera/api/pyspark/__init__.py b/pandera/api/pyspark/__init__.py
index efc22d416..12080384f 100644
--- a/pandera/api/pyspark/__init__.py
+++ b/pandera/api/pyspark/__init__.py
@@ -2,3 +2,4 @@
 
 from pandera.api.pyspark.components import Column
 from pandera.api.pyspark.container import DataFrameSchema
+from pandera.api.pyspark.model import DataFrameModel
diff --git a/pandera/typing/pyspark.py b/pandera/typing/pyspark.py
index dd141744a..dbe09e13c 100644
--- a/pandera/typing/pyspark.py
+++ b/pandera/typing/pyspark.py
@@ -1,14 +1,22 @@
 """Pandera type annotations for Pyspark Pandas."""
 
-from typing import TYPE_CHECKING, Generic, TypeVar
+import functools
+import json
+from typing import TYPE_CHECKING, Generic, TypeVar, Any, get_args
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+from pandera.engines import PYDANTIC_V2
+from pandera.errors import SchemaInitError
 
 from pandera.typing.common import (
     DataFrameBase,
     GenericDtype,
     IndexBase,
     SeriesBase,
+    _GenericAlias,
 )
-from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pandera.typing.pandas import DataFrameModel
 
 try:
     import pyspark.pandas as ps
@@ -39,6 +47,76 @@ def __class_getitem__(cls, item):
             """Define this to override's pyspark.pandas generic type."""
             return _GenericAlias(cls, item)
 
+        @classmethod
+        def pydantic_validate(cls, obj: Any, schema_model: T) -> ps.DataFrame:
+            """
+            Verify that the input can be converted into a pyspark.pandas
+            dataframe that meets all schema requirements.
+
+            This is for pydantic >= v2
+            """
+            try:
+                schema = schema_model.to_schema()  # type: ignore[attr-defined]
+            except SchemaInitError as exc:
+                error_message = (
+                    f"Cannot use {cls} as a pydantic type as its "
+                    "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                    f"Please revisit the model to address the following errors:"
+                    f"\n{exc}"
+                )
+                raise ValueError(error_message) from exc
+
+            validated_data = schema.validate(obj)
+
+            if validated_data.pandera.errors:
+                errors = json.dumps(
+                    dict(validated_data.pandera.errors), indent=4
+                )
+                raise ValueError(errors)
+
+            return validated_data
+
+        if PYDANTIC_V2:
+
+            @classmethod
+            def __get_pydantic_core_schema__(
+                cls, _source_type: Any, _handler: GetCoreSchemaHandler
+            ) -> core_schema.CoreSchema:
+                schema_model = get_args(_source_type)[0]
+                return core_schema.no_info_plain_validator_function(
+                    functools.partial(
+                        cls.pydantic_validate,
+                        schema_model=schema_model,
+                    ),
+                )
+
+        else:
+
+            @classmethod
+            def __get_validators__(cls):
+                yield cls._pydantic_validate
+
+            @classmethod
+            def _get_schema_model(cls, field):
+                if not field.sub_fields:
+                    raise TypeError(
+                        "Expected a typed pandera.typing.DataFrame,"
+                        " e.g. DataFrame[Schema]"
+                    )
+                schema_model = field.sub_fields[0].type_
+                return schema_model
+
+            @classmethod
+            def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+                """
+                Verify that the input can be converted into a pyspark.pandas
+                dataframe that meets all schema requirements.
+
+                This is for pydantic v1
+                """
+                schema_model = cls._get_schema_model(field)
+                return cls.pydantic_validate(obj, schema_model)
+
     # pylint:disable=too-few-public-methods,arguments-renamed
     class Series(SeriesBase, ps.Series, Generic[GenericDtype]):  # type: ignore [misc]  # noqa
         """Representation of pandas.Series, only used for type annotation.
diff --git a/pandera/typing/pyspark_sql.py b/pandera/typing/pyspark_sql.py
index 91cbcea35..b7a4a22fd 100644
--- a/pandera/typing/pyspark_sql.py
+++ b/pandera/typing/pyspark_sql.py
@@ -1,9 +1,16 @@
-"""Pandera type annotations for Pyspark."""
+"""Pandera type annotations for Pyspark SQL."""
 
-from typing import TypeVar, Union
+import functools
+import json
+from typing import Union, TypeVar, Any, get_args, Generic
 
-from pandera.typing.common import DataFrameBase
-from pandera.typing.pandas import DataFrameModel, _GenericAlias
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+from pandera.engines import pyspark_engine, PYDANTIC_V2
+from pandera.errors import SchemaInitError
+from pandera.typing.common import DataFrameBase, _GenericAlias
+from pandera.api.pyspark import DataFrameModel
 
 try:
     import pyspark.sql as ps
@@ -12,9 +19,9 @@
 except ImportError:  # pragma: no cover
     PYSPARK_SQL_INSTALLED = False
 
-if PYSPARK_SQL_INSTALLED:
-    from pandera.engines import pyspark_engine
+T = TypeVar("T", bound=DataFrameModel)
 
+if PYSPARK_SQL_INSTALLED:
     PysparkString = pyspark_engine.String
     PysparkInt = pyspark_engine.Int
     PysparkLongInt = pyspark_engine.BigInt
@@ -43,13 +50,6 @@
             PysparkBinary,  # type: ignore
         ],
     )
-    from typing import TYPE_CHECKING, Generic
-
-    # pylint:disable=invalid-name
-    if TYPE_CHECKING:
-        T = TypeVar("T")  # pragma: no cover
-    else:
-        T = DataFrameModel
 
 if PYSPARK_SQL_INSTALLED:
     # pylint: disable=too-few-public-methods,arguments-renamed
@@ -64,3 +64,75 @@ class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]):
         def __class_getitem__(cls, item):
             """Define this to override's pyspark.pandas generic type."""
             return _GenericAlias(cls, item)  # pragma: no cover
+
+        @classmethod
+        def pydantic_validate(
+            cls, obj: ps.DataFrame, schema_model: T
+        ) -> ps.DataFrame:
+            """
+            Verify that the input can be converted into a pyspark sql
+            dataframe that meets all schema requirements.
+
+            This is for pydantic V1 and V2.
+            """
+            try:
+                schema = schema_model.to_schema()
+            except SchemaInitError as exc:
+                error_message = (
+                    f"Cannot use {cls} as a pydantic type as its "
+                    "DataFrameModel cannot be converted to a DataFrameSchema.\n"
+                    f"Please revisit the model to address the following errors:"
+                    f"\n{exc}"
+                )
+                raise ValueError(error_message) from exc
+
+            validated_data = schema.validate(obj)
+
+            if validated_data.pandera.errors:
+                errors = json.dumps(
+                    dict(validated_data.pandera.errors), indent=4
+                )
+                raise ValueError(errors)
+
+            return validated_data
+
+        if PYDANTIC_V2:
+
+            @classmethod
+            def __get_pydantic_core_schema__(
+                cls, _source_type: Any, _handler: GetCoreSchemaHandler
+            ) -> core_schema.CoreSchema:
+                schema_model = get_args(_source_type)[0]
+                return core_schema.no_info_plain_validator_function(
+                    functools.partial(
+                        cls.pydantic_validate,
+                        schema_model=schema_model,
+                    ),
+                )
+
+        else:
+
+            @classmethod
+            def __get_validators__(cls):
+                yield cls._pydantic_validate
+
+            @classmethod
+            def _get_schema_model(cls, field):
+                if not field.sub_fields:
+                    raise TypeError(
+                        "Expected a typed pandera.typing.DataFrame,"
+                        " e.g. DataFrame[Schema]"
+                    )
+                schema_model = field.sub_fields[0].type_
+                return schema_model
+
+            @classmethod
+            def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
+                """
+                Verify that the input can be converted into a pyspark sql
+                dataframe that meets all schema requirements.
+
+                This is for pydantic v1
+                """
+                schema_model = cls._get_schema_model(field)
+                return cls.pydantic_validate(obj, schema_model)
diff --git a/tests/pyspark/conftest.py b/tests/pyspark/conftest.py
index 4ae06ea50..09f129fdc 100644
--- a/tests/pyspark/conftest.py
+++ b/tests/pyspark/conftest.py
@@ -17,6 +17,7 @@ def spark() -> SparkSession:
     creates spark session
     """
     spark: SparkSession = SparkSession.builder.getOrCreate()
+    spark.conf.set("spark.sql.ansi.enabled", False)
     yield spark
     spark.stop()
 
@@ -29,6 +30,7 @@ def spark_connect() -> SparkSession:
     # Set location of localhost Spark Connect server
    os.environ["SPARK_LOCAL_REMOTE"] = "sc://localhost"
     spark: SparkSession = SparkSession.builder.getOrCreate()
+    spark.conf.set("spark.sql.ansi.enabled", False)
     yield spark
     spark.stop()
 
diff --git a/tests/pyspark/test_pyspark_accessor.py b/tests/pyspark/test_pyspark_accessor.py
index cf2fbc4b0..2f5e7d070 100644
--- a/tests/pyspark/test_pyspark_accessor.py
+++ b/tests/pyspark/test_pyspark_accessor.py
@@ -12,6 +12,7 @@
 from pandera.pyspark import pyspark_sql_accessor
 
 spark = SparkSession.builder.getOrCreate()
+spark.conf.set("spark.sql.ansi.enabled", False)
 
 
 @pytest.mark.parametrize(
diff --git a/tests/pyspark/test_pyspark_pydantic_integration.py b/tests/pyspark/test_pyspark_pydantic_integration.py
new file mode 100644
index 000000000..0db32240d
--- /dev/null
+++ b/tests/pyspark/test_pyspark_pydantic_integration.py
@@ -0,0 +1,81 @@
+"""Tests for the integration between PySpark and Pydantic."""
+
+import pytest
+from pydantic import BaseModel, ValidationError
+from pyspark.testing.utils import assertDataFrameEqual
+import pyspark.sql.types as T
+
+import pandera.pyspark as pa
+from pandera.typing.pyspark_sql import DataFrame as PySparkSQLDataFrame
+from pandera.typing.pyspark import DataFrame as PySparkDataFrame
+from pandera.pyspark import DataFrameModel
+
+
+@pytest.fixture
+def sample_schema_model():
+    class SampleSchema(DataFrameModel):
+        """
+        Sample schema model with data checks.
+        """
+
+        product: T.StringType() = pa.Field()
+        price: T.IntegerType() = pa.Field()
+
+    return SampleSchema
+
+
+@pytest.fixture(
+    params=[PySparkDataFrame, PySparkSQLDataFrame],
+    ids=["pyspark", "pyspark_sql"],
+)
+def pydantic_container(request, sample_schema_model):
+    TypingClass = request.param
+
+    class PydanticContainer(BaseModel):
+        """
+        Pydantic container with a DataFrameModel as a field.
+        """
+
+        data: TypingClass[sample_schema_model]
+
+    return PydanticContainer
+
+
+@pytest.fixture
+def correct_data(spark, sample_data, sample_spark_schema):
+    """
+    Correct data that should pass validation.
+    """
+    return spark.createDataFrame(sample_data, sample_spark_schema)
+
+
+@pytest.fixture
+def incorrect_data(spark):
+    """
+    Incorrect data that should fail validation.
+    """
+    data = [
+        (1, "Apples"),
+        (2, "Bananas"),
+    ]
+    return spark.createDataFrame(data, ["product", "price"])
+
+
+def test_pydantic_model_instantiates_with_correct_data(
+    correct_data, pydantic_container
+):
+    """
+    Test that a Pydantic model can be instantiated with a DataFrameModel when data is valid.
+    """
+    my_container = pydantic_container(data=correct_data)
+    assertDataFrameEqual(my_container.data, correct_data)
+
+
+def test_pydantic_model_throws_validation_error_with_incorrect_data(
+    incorrect_data, pydantic_container
+):
+    """
+    Test that a Pydantic model throws a ValidationError when data is invalid.
+ """ + with pytest.raises(ValidationError): + pydantic_container(data=incorrect_data) diff --git a/tests/pyspark/test_schemas_on_pyspark_pandas.py b/tests/pyspark/test_schemas_on_pyspark_pandas.py index c39ad3981..e0a5581fe 100644 --- a/tests/pyspark/test_schemas_on_pyspark_pandas.py +++ b/tests/pyspark/test_schemas_on_pyspark_pandas.py @@ -278,8 +278,10 @@ def test_index_dtypes( not in { pandas_engine.Engine.dtype(pandas_engine.BOOL), pandas_engine.DateTime(tz="UTC"), # type: ignore[call-arg] + pandas_engine.Engine.dtype(pa.dtypes.Timedelta), # type: ignore[call-arg] } ], + ids=lambda x: str(x) ) @hypothesis.given(st.data()) def test_nullable(