Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions python/pyspark/sql/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,8 @@ def partitions(self) -> Sequence[InputPartition]:
partition value to read the data.

This method is called once during query planning. By default, it returns a
- single partition with the value ``None``. Subclasses can override this method
- to return multiple partitions.
+ single partition with the value `InputPartition(None)`. Subclasses can override
+ this method to return multiple partitions.

It's recommended to override this method for better performance when reading
large datasets.
Expand Down Expand Up @@ -626,10 +626,7 @@ def partitions(self) -> Sequence[InputPartition]:
>>> def partitions(self):
... return [RangeInputPartition(1, 3), RangeInputPartition(5, 10)]
"""
- raise PySparkNotImplementedError(
- errorClass="NOT_IMPLEMENTED",
- messageParameters={"feature": "partitions"},
- )
+ return [InputPartition(None)]

@abstractmethod
def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]:
Expand All @@ -643,7 +640,7 @@ def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["Re

Parameters
----------
- partition : object
+ partition : InputPartition
The partition to read. It must be one of the partition values returned by
:meth:`DataSourceReader.partitions`.

Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/test_python_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def partitions(self):
if partition_func is not None:
return partition_func()
else:
- raise NotImplementedError
+ return [InputPartition(None)]

def read(self, partition):
return read_func(self.schema, partition)
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def reader(self, schema) -> "DataSourceReader":
{"class_name": "TestJsonReader", "func_name": "partitions"},
),
(
- "TestJsonReader.read: None",
+ "TestJsonReader.read: InputPartition(value=None)",
{"class_name": "TestJsonReader", "func_name": "read"},
),
]
Expand Down Expand Up @@ -1151,7 +1151,7 @@ def reader(self, schema) -> "DataSourceReader":
{"class_name": "TestJsonReader", "func_name": "partitions"},
),
(
- "TestJsonReader.read: None",
+ "TestJsonReader.read: InputPartition(value=None)",
{"class_name": "TestJsonReader", "func_name": "read"},
),
]
Expand Down
47 changes: 22 additions & 25 deletions python/pyspark/sql/worker/plan_data_source_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,12 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
# Deserialize the partition value.
partition = pickleSer.loads(partition_bytes)

- assert partition is None or isinstance(partition, InputPartition), (
+ assert isinstance(partition, InputPartition), (
"Expected the partition value to be of type 'InputPartition', "
f"but found '{type(partition).__name__}'."
)

- output_iter = reader.read(partition) # type: ignore[arg-type]
+ output_iter = reader.read(partition)

# Validate the output iterator.
if not isinstance(output_iter, Iterator):
Expand All @@ -240,29 +240,26 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec

if not is_streaming:
# The partitioning of python batch source read is determined before query execution.
- try:
- partitions = reader.partitions() # type: ignore[call-arg]
- if not isinstance(partitions, list):
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "'partitions' to return a list",
- "actual": f"'{type(partitions).__name__}'",
- },
- )
- if not all(isinstance(p, InputPartition) for p in partitions):
- partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "elements in 'partitions' to be of type 'InputPartition'",
- "actual": partition_types,
- },
- )
- if len(partitions) == 0:
- partitions = [None] # type: ignore[list-item]
- except NotImplementedError:
- partitions = [None] # type: ignore[list-item]
+ partitions = reader.partitions() # type: ignore[call-arg]
+ if not isinstance(partitions, list):
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "'partitions' to return a list",
+ "actual": f"'{type(partitions).__name__}'",
+ },
+ )
+ if not all(isinstance(p, InputPartition) for p in partitions):
+ partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "elements in 'partitions' to be of type 'InputPartition'",
+ "actual": partition_types,
+ },
+ )
+ if len(partitions) == 0:
+ partitions = [InputPartition(None)]

# Return the serialized partition values.
write_int(len(partitions), outfile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
| return []
|
| def read(self, partition):
- | if partition is None:
+ | if partition.value is None:
| yield ("success", )
| else:
| yield ("failed", )
Expand Down