diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index bb73a7a9206b1..5547150349388 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -588,8 +588,8 @@ def partitions(self) -> Sequence[InputPartition]: partition value to read the data. This method is called once during query planning. By default, it returns a - single partition with the value ``None``. Subclasses can override this method - to return multiple partitions. + single partition with the value ``InputPartition(None)``. Subclasses can override + this method to return multiple partitions. It's recommended to override this method for better performance when reading large datasets. @@ -626,10 +626,7 @@ def partitions(self) -> Sequence[InputPartition]: >>> def partitions(self): ... return [RangeInputPartition(1, 3), RangeInputPartition(5, 10)] """ - raise PySparkNotImplementedError( - errorClass="NOT_IMPLEMENTED", - messageParameters={"feature": "partitions"}, - ) + return [InputPartition(None)] @abstractmethod def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]: @@ -643,7 +640,7 @@ def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["Re Parameters ---------- - partition : object + partition : InputPartition The partition to read. It must be one of the partition values returned by :meth:`DataSourceReader.partitions`. 
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 9d90082c654d7..e8776c2887209 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -162,7 +162,7 @@ def partitions(self): if partition_func is not None: return partition_func() else: - raise NotImplementedError + return [InputPartition(None)] def read(self, partition): return read_func(self.schema, partition) @@ -1040,7 +1040,7 @@ def reader(self, schema) -> "DataSourceReader": {"class_name": "TestJsonReader", "func_name": "partitions"}, ), ( - "TestJsonReader.read: None", + "TestJsonReader.read: InputPartition(value=None)", {"class_name": "TestJsonReader", "func_name": "read"}, ), ] @@ -1151,7 +1151,7 @@ def reader(self, schema) -> "DataSourceReader": {"class_name": "TestJsonReader", "func_name": "partitions"}, ), ( - "TestJsonReader.read: None", + "TestJsonReader.read: InputPartition(value=None)", {"class_name": "TestJsonReader", "func_name": "read"}, ), ] diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index cded42031b0cc..d736df6084c61 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -210,12 +210,12 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec # Deserialize the partition value. partition = pickleSer.loads(partition_bytes) - assert partition is None or isinstance(partition, InputPartition), ( + assert isinstance(partition, InputPartition), ( "Expected the partition value to be of type 'InputPartition', " f"but found '{type(partition).__name__}'." ) - output_iter = reader.read(partition) # type: ignore[arg-type] + output_iter = reader.read(partition) # Validate the output iterator. 
if not isinstance(output_iter, Iterator): @@ -240,29 +240,26 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec if not is_streaming: # The partitioning of python batch source read is determined before query execution. - try: - partitions = reader.partitions() # type: ignore[call-arg] - if not isinstance(partitions, list): - raise PySparkRuntimeError( - errorClass="DATA_SOURCE_TYPE_MISMATCH", - messageParameters={ - "expected": "'partitions' to return a list", - "actual": f"'{type(partitions).__name__}'", - }, - ) - if not all(isinstance(p, InputPartition) for p in partitions): - partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions]) - raise PySparkRuntimeError( - errorClass="DATA_SOURCE_TYPE_MISMATCH", - messageParameters={ - "expected": "elements in 'partitions' to be of type 'InputPartition'", - "actual": partition_types, - }, - ) - if len(partitions) == 0: - partitions = [None] # type: ignore[list-item] - except NotImplementedError: - partitions = [None] # type: ignore[list-item] + partitions = reader.partitions() # type: ignore[call-arg] + if not isinstance(partitions, list): + raise PySparkRuntimeError( + errorClass="DATA_SOURCE_TYPE_MISMATCH", + messageParameters={ + "expected": "'partitions' to return a list", + "actual": f"'{type(partitions).__name__}'", + }, + ) + if not all(isinstance(p, InputPartition) for p in partitions): + partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions]) + raise PySparkRuntimeError( + errorClass="DATA_SOURCE_TYPE_MISMATCH", + messageParameters={ + "expected": "elements in 'partitions' to be of type 'InputPartition'", + "actual": partition_types, + }, + ) + if len(partitions) == 0: + partitions = [InputPartition(None)] # Return the serialized partition values. 
write_int(len(partitions), outfile) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index bac9849381a3f..a712607f18e41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -533,7 +533,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { | return [] | | def read(self, partition): - | if partition is None: + | if partition.value is None: | yield ("success", ) | else: | yield ("failed", )