diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index ed1c602b0af4b..32f22c7b03b10 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -220,6 +220,11 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec return records_to_arrow_batches(output_iter, max_arrow_batch_size, return_type, data_source) + # Set the module name so UDF worker can recognize that this is a data source function. + # This is needed when simple worker is used because the __module__ will be set to + # __main__, which confuses the profiler logic. + data_source_read_func.__module__ = "pyspark.sql.worker.plan_data_source_read" + command = (data_source_read_func, return_type) pickleSer._write_with_length(command, outfile) diff --git a/python/pyspark/sql/worker/utils.py b/python/pyspark/sql/worker/utils.py index 8a99abe3e4e9a..406894fc275a6 100644 --- a/python/pyspark/sql/worker/utils.py +++ b/python/pyspark/sql/worker/utils.py @@ -70,7 +70,15 @@ def worker_run(main: Callable, infile: IO, outfile: IO) -> None: SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam ) - worker_module = main.__module__.split(".")[-1] + if main.__module__ == "__main__": + try: + worker_module = sys.modules["__main__"].__spec__.name # type: ignore[union-attr] + except Exception: + worker_module = "__main__" + else: + worker_module = main.__module__ + worker_module = worker_module.split(".")[-1] + if conf.profiler == "perf": with WorkerPerfProfiler(accumulator, worker_module): main(infile, outfile) diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index 111829bb7d58f..c808661fedb72 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -228,6 +228,11 @@ def batch_to_rows() -> Iterator[Row]: messages = 
pa.array([pickled]) yield pa.record_batch([messages], names=[return_col_name]) + # Set the module name so the UDF worker can recognize that this is a data source function. + # This is needed when the simple worker is used because the __module__ will be set to + # __main__, which confuses the profiler logic. + data_source_write_func.__module__ = "pyspark.sql.worker.write_into_data_source" + # Return the pickled write UDF. command = (data_source_write_func, return_type) pickleSer._write_with_length(command, outfile)