From 6c8e8550cf95e19fb3b4bdeef5960b5a4d0715ea Mon Sep 17 00:00:00 2001 From: Parth Chandra Date: Mon, 9 Feb 2026 16:54:01 +0530 Subject: [PATCH 1/3] enable ignored 4.0 tests, enable ansi mode --- dev/diffs/4.0.1.diff | 215 ++++++------------------------------------- 1 file changed, 28 insertions(+), 187 deletions(-) diff --git a/dev/diffs/4.0.1.diff b/dev/diffs/4.0.1.diff index d6694e827f..91119f9322 100644 --- a/dev/diffs/4.0.1.diff +++ b/dev/diffs/4.0.1.diff @@ -1,5 +1,5 @@ diff --git a/pom.xml b/pom.xml -index 22922143fc3..7c56e5e8641 100644 +index 2bf6ba60fdf..568e1f12f81 100644 --- a/pom.xml +++ b/pom.xml @@ -148,6 +148,8 @@ @@ -11,7 +11,7 @@ index 22922143fc3..7c56e5e8641 100644 + 10.16.1.1 + 1.15.2 +- 2.1.3 ++ 2.1.4 + shaded-protobuf + 11.0.24 + 5.0.0 @@ -148,6 +148,8 @@ 4.0.3 2.5.3 @@ -11,7 +2209,29 @@ index 2bf6ba60fdf..568e1f12f81 100644 org.apache.datasketches +@@ -3150,6 +3177,10 @@ + com.google.common + ${spark.shade.packageName}.guava + ++ ++ com.google.thirdparty ++ ${spark.shade.packageName}.guava.thirdparty ++ + + org.dmg.pmml + ${spark.shade.packageName}.dmg.pmml +diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala +index cded163e81f..c484fef8516 100644 +--- a/project/SparkBuild.scala ++++ b/project/SparkBuild.scala +@@ -364,7 +364,8 @@ object SparkBuild extends PomBuild { + /* Enable shared settings on all projects */ + (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) + .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ +- ExcludedDependencies.settings ++ Checkstyle.settings ++ ExcludeShims.settings)) ++ ExcludedDependencies.settings ++ (if (noLintOnCompile) Nil else Checkstyle.settings) ++ ++ ExcludeShims.settings)) + + /* Enable tests settings for all projects except examples, assembly and tools */ + (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) +@@ -1471,7 +1472,7 @@ object Unidoc { + ) ++ ( + // Add links to sources when generating Scaladoc for a non-snapshot release + if (!isSnapshot.value) { +- Opts.doc.sourceUrl(unidocSourceBase.value + "€{FILE_PATH}.scala") ++ Opts.doc.sourceUrl(unidocSourceBase.value + "€{FILE_PATH_EXT}") + } else { + Seq() + } +diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py +index da4d25cc908..661ba5a8a7e 100755 +--- a/python/packaging/classic/setup.py ++++ b/python/packaging/classic/setup.py +@@ -344,7 +344,7 @@ try: + license="http://www.apache.org/licenses/LICENSE-2.0", + # Don't forget to update python/docs/source/getting_started/install.rst + # if you're updating the versions or dependencies. +- install_requires=["py4j==0.10.9.9"], ++ install_requires=["py4j>=0.10.9.7,<0.10.9.10"], + extras_require={ + "ml": ["numpy>=%s" % _minimum_numpy_version], + "mllib": ["numpy>=%s" % _minimum_numpy_version], +diff --git a/python/pyspark/ml/connect/feature.py b/python/pyspark/ml/connect/feature.py +index a0e5b6a943d..e08b37337c6 100644 +--- a/python/pyspark/ml/connect/feature.py ++++ b/python/pyspark/ml/connect/feature.py +@@ -15,11 +15,11 @@ + # limitations under the License. 
+ # + +-import pickle + from typing import Any, Union, List, Tuple, Callable, Dict, Optional + + import numpy as np + import pandas as pd ++import pyarrow as pa + + from pyspark import keyword_only + from pyspark.sql import DataFrame +@@ -132,27 +132,29 @@ class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol, ParamsReadWrite, CoreM + return transform_fn + + def _get_core_model_filename(self) -> str: +- return self.__class__.__name__ + ".sklearn.pkl" ++ return self.__class__.__name__ + ".arrow.parquet" + + def _save_core_model(self, path: str) -> None: +- from sklearn.preprocessing import MaxAbsScaler as sk_MaxAbsScaler +- +- sk_model = sk_MaxAbsScaler() +- sk_model.scale_ = self.scale_values +- sk_model.max_abs_ = self.max_abs_values +- sk_model.n_features_in_ = len(self.max_abs_values) # type: ignore[arg-type] +- sk_model.n_samples_seen_ = self.n_samples_seen +- +- with open(path, "wb") as fp: +- pickle.dump(sk_model, fp) ++ import pyarrow.parquet as pq ++ ++ table = pa.Table.from_arrays( ++ [ ++ pa.array([self.scale_values], pa.list_(pa.float64())), ++ pa.array([self.max_abs_values], pa.list_(pa.float64())), ++ pa.array([self.n_samples_seen], pa.int64()), ++ ], ++ names=["scale", "max_abs", "n_samples"], ++ ) ++ pq.write_table(table, path) + + def _load_core_model(self, path: str) -> None: +- with open(path, "rb") as fp: +- sk_model = pickle.load(fp) ++ import pyarrow.parquet as pq ++ ++ table = pq.read_table(path) + +- self.max_abs_values = sk_model.max_abs_ +- self.scale_values = sk_model.scale_ +- self.n_samples_seen = sk_model.n_samples_seen_ ++ self.max_abs_values = np.array(table.column("scale")[0].as_py()) ++ self.scale_values = np.array(table.column("max_abs")[0].as_py()) ++ self.n_samples_seen = table.column("n_samples")[0].as_py() + + + class StandardScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite): +@@ -251,29 +253,31 @@ class StandardScalerModel(Model, HasInputCol, HasOutputCol, ParamsReadWrite, Cor + return transform_fn + + def _get_core_model_filename(self) -> str: +- return self.__class__.__name__ + ".sklearn.pkl" ++ return self.__class__.__name__ + ".arrow.parquet" + + def _save_core_model(self, path: str) -> None: +- from sklearn.preprocessing import StandardScaler as sk_StandardScaler +- +- sk_model = sk_StandardScaler(with_mean=True, with_std=True) +- sk_model.scale_ = self.scale_values +- sk_model.var_ = self.std_values * self.std_values # type: ignore[operator] +- sk_model.mean_ = self.mean_values +- sk_model.n_features_in_ = len(self.std_values) # type: ignore[arg-type] +- sk_model.n_samples_seen_ = self.n_samples_seen +- +- with open(path, "wb") as fp: +- pickle.dump(sk_model, fp) ++ import pyarrow.parquet as pq ++ ++ table = pa.Table.from_arrays( ++ [ ++ pa.array([self.scale_values], pa.list_(pa.float64())), ++ pa.array([self.mean_values], pa.list_(pa.float64())), ++ pa.array([self.std_values], pa.list_(pa.float64())), ++ pa.array([self.n_samples_seen], pa.int64()), ++ ], ++ names=["scale", "mean", "std", "n_samples"], ++ ) ++ pq.write_table(table, path) + + def _load_core_model(self, path: str) -> None: +- with open(path, "rb") as fp: +- sk_model = pickle.load(fp) ++ import pyarrow.parquet as pq ++ ++ table = pq.read_table(path) + +- self.std_values = np.sqrt(sk_model.var_) +- self.scale_values = sk_model.scale_ +- self.mean_values = sk_model.mean_ +- self.n_samples_seen = sk_model.n_samples_seen_ ++ self.scale_values = np.array(table.column("scale")[0].as_py()) ++ self.mean_values = np.array(table.column("mean")[0].as_py()) ++ self.std_values = 
np.array(table.column("std")[0].as_py()) ++ self.n_samples_seen = table.column("n_samples")[0].as_py() + + + class ArrayAssembler( +diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +index 6812db77845..96f153b7b1b 100644 +--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py ++++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py +@@ -17,7 +17,6 @@ + # + + import os +-import pickle + import tempfile + import unittest + +@@ -85,12 +84,6 @@ class FeatureTestsMixin: + np.testing.assert_allclose(model.max_abs_values, loaded_model.max_abs_values) + assert model.n_samples_seen == loaded_model.n_samples_seen + +- # Test loading core model as scikit-learn model +- with open(os.path.join(model_path, "MaxAbsScalerModel.sklearn.pkl"), "rb") as f: +- sk_model = pickle.load(f) +- sk_result = sk_model.transform(np.stack(list(local_df1.features))) +- np.testing.assert_allclose(sk_result, expected_result) +- + def test_standard_scaler(self): + df1 = self.spark.createDataFrame( + [ +@@ -141,12 +134,6 @@ class FeatureTestsMixin: + np.testing.assert_allclose(model.scale_values, loaded_model.scale_values) + assert model.n_samples_seen == loaded_model.n_samples_seen + +- # Test loading core model as scikit-learn model +- with open(os.path.join(model_path, "StandardScalerModel.sklearn.pkl"), "rb") as f: +- sk_model = pickle.load(f) +- sk_result = sk_model.transform(np.stack(list(local_df1.features))) +- np.testing.assert_allclose(sk_result, expected_result) +- + def test_array_assembler(self): + spark_df = self.spark.createDataFrame( + [ +diff --git a/python/pyspark/pandas/tests/io/test_feather.py b/python/pyspark/pandas/tests/io/test_feather.py +index 74fa6bc7d7b..10638d915c0 100644 +--- a/python/pyspark/pandas/tests/io/test_feather.py ++++ b/python/pyspark/pandas/tests/io/test_feather.py +@@ -17,8 +17,10 @@ + import unittest + + import pandas as pd ++import sys + + from pyspark import pandas as ps ++from pyspark.loose_version import LooseVersion + from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils + + +@@ -34,6 +36,16 @@ class FeatherMixin: + def psdf(self): + return ps.from_pandas(self.pdf) + ++ has_arrow_21_or_below = False ++ try: ++ import pyarrow as pa ++ ++ if LooseVersion(pa.__version__) < LooseVersion("22.0.0"): ++ has_arrow_21_or_below = True ++ except ImportError: ++ pass ++ ++ @unittest.skipIf(not has_arrow_21_or_below, "SPARK-54068") + def test_to_feather(self): + with self.temp_dir() as dirpath: + path1 = f"{dirpath}/file1.feather" +diff --git a/python/pyspark/pandas/tests/io/test_stata.py b/python/pyspark/pandas/tests/io/test_stata.py +index 6fe7cf13513..3cdf2cdb150 100644 +--- a/python/pyspark/pandas/tests/io/test_stata.py ++++ b/python/pyspark/pandas/tests/io/test_stata.py +@@ -14,6 +14,7 @@ + # See the License for the specific language governing permissions and + # limitations under the License. 
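The pyarrow-based `_save_core_model`/`_load_core_model` rewrite above persists each scaler model as a one-row Parquet table instead of a pickled scikit-learn object. A minimal, self-contained sketch of that round trip, with hypothetical values and a temporary path (not part of the patch):

import os
import tempfile

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

# Hypothetical model state, mirroring the fields MaxAbsScalerModel stores.
scale_values = np.array([1.0, 2.5, 4.0])
max_abs_values = np.array([1.0, 2.5, 4.0])
n_samples_seen = 3

# One row per model; list-typed columns hold the per-feature vectors.
table = pa.Table.from_arrays(
    [
        pa.array([scale_values.tolist()], pa.list_(pa.float64())),
        pa.array([max_abs_values.tolist()], pa.list_(pa.float64())),
        pa.array([n_samples_seen], pa.int64()),
    ],
    names=["scale", "max_abs", "n_samples"],
)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "MaxAbsScalerModel.arrow.parquet")
    pq.write_table(table, path)
    loaded = pq.read_table(path)

# Read each field back from the column it was written to.
assert np.allclose(np.array(loaded.column("scale")[0].as_py()), scale_values)
assert np.allclose(np.array(loaded.column("max_abs")[0].as_py()), max_abs_values)
assert loaded.column("n_samples")[0].as_py() == n_samples_seen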
+ # ++import os + import unittest + + import pandas as pd +@@ -33,6 +34,9 @@ class StataMixin: + def psdf(self): + return ps.from_pandas(self.pdf) + ++ @unittest.skipIf( ++ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54486: To be reenabled" ++ ) + def test_to_feather(self): + with self.temp_dir() as dirpath: + path1 = f"{dirpath}/file1.dta" +diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py +index cac9aaf193a..afed59660d7 100644 +--- a/python/pyspark/pandas/tests/test_typedef.py ++++ b/python/pyspark/pandas/tests/test_typedef.py +@@ -15,6 +15,7 @@ + # limitations under the License. + # + ++import os + import sys + import unittest + import datetime +@@ -313,7 +314,6 @@ class TypeHintTestsMixin: + def test_as_spark_type_pandas_on_spark_dtype(self): + type_mapper = { + # binary +- np.character: (np.character, BinaryType()), + np.bytes_: (np.bytes_, BinaryType()), + bytes: (np.bytes_, BinaryType()), + # integer +@@ -348,6 +348,10 @@ class TypeHintTestsMixin: + ), + } + ++ if LooseVersion(np.__version__) < LooseVersion("2.3"): ++ # binary ++ type_mapper.update({np.character: (np.character, BinaryType())}) ++ + for numpy_or_python_type, (dtype, spark_type) in type_mapper.items(): + self.assertEqual(as_spark_type(numpy_or_python_type), spark_type) + self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type)) +diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py +index 48545d124b2..a4ed9f996fe 100644 +--- a/python/pyspark/pandas/typedef/typehints.py ++++ b/python/pyspark/pandas/typedef/typehints.py +@@ -342,7 +342,7 @@ def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.Dat + try: + dtype = pandas_dtype(tpe) + spark_type = as_spark_type(dtype) +- except TypeError: ++ except (TypeError, ValueError): + spark_type = as_spark_type(tpe) + dtype = spark_type_to_pandas_dtype(spark_type) + return dtype, spark_type +diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py +index bf6d60df635..952258e8db4 100644 +--- a/python/pyspark/sql/connect/window.py ++++ b/python/pyspark/sql/connect/window.py +@@ -18,7 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies + + check_dependencies(__name__) + +-from typing import TYPE_CHECKING, Union, Sequence, List, Optional, Tuple, cast, Iterable ++from typing import TYPE_CHECKING, Any, Union, Sequence, List, Optional, Tuple, cast, Iterable + + from pyspark.sql.column import Column + from pyspark.sql.window import ( +@@ -69,6 +69,9 @@ class WindowSpec(ParentWindowSpec): + self.__init__(partitionSpec, orderSpec, frame) # type: ignore[misc] + return self + ++ def __getnewargs__(self) -> Tuple[Any, ...]: ++ return (self._partitionSpec, self._orderSpec, self._frame) ++ + def __init__( + self, + partitionSpec: Sequence[Expression], +diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py +index cd06b3fa3ee..a3f8bc7a0f0 100644 +--- a/python/pyspark/sql/dataframe.py ++++ b/python/pyspark/sql/dataframe.py +@@ -852,7 +852,6 @@ class DataFrame: + + Notes + ----- +- - Unlike `count()`, this method does not trigger any computation. + - An empty DataFrame has no rows. It may have columns, but no data. 
+ + Examples +diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py +index d2f9f0957e0..45ca818d7ae 100644 +--- a/python/pyspark/sql/streaming/query.py ++++ b/python/pyspark/sql/streaming/query.py +@@ -283,7 +283,10 @@ class StreamingQuery: + + >>> sq.stop() + """ +- return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()] ++ return [ ++ StreamingQueryProgress.fromJson(json.loads(p.json())) ++ for p in self._jsq.recentProgress() ++ ] + + @property + def lastProgress(self) -> Optional[StreamingQueryProgress]: +@@ -314,7 +317,7 @@ class StreamingQuery: + """ + lastProgress = self._jsq.lastProgress() + if lastProgress: +- return StreamingQueryProgress.fromJObject(lastProgress) ++ return StreamingQueryProgress.fromJson(json.loads(lastProgress.json())) + else: + return None + +diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py +index f0637056ab8..bf51c0839f6 100755 +--- a/python/pyspark/sql/tests/connect/test_connect_basic.py ++++ b/python/pyspark/sql/tests/connect/test_connect_basic.py +@@ -145,6 +145,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): + cdf2 = loads(data) + self.assertEqual(cdf.collect(), cdf2.collect()) + ++ def test_window_spec_serialization(self): ++ from pyspark.sql.connect.window import Window ++ from pyspark.serializers import CPickleSerializer ++ ++ pickle_ser = CPickleSerializer() ++ w = Window.partitionBy("some_string").orderBy("value") ++ b = pickle_ser.dumps(w) ++ w2 = pickle_ser.loads(b) ++ self.assertEqual(str(w), str(w2)) ++ + def test_df_getattr_behavior(self): + cdf = self.connect.range(10) + sdf = self.spark.range(10) +diff --git a/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py b/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py +index c6ef9810c68..c3b50341bbd 100644 +--- a/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py ++++ b/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py +@@ -19,7 +19,10 @@ import os + import unittest + + from pyspark.tests.test_memory_profiler import MemoryProfiler2TestsMixin, _do_computation +-from pyspark.testing.connectutils import ReusedConnectTestCase ++from pyspark.testing.connectutils import ( ++ ReusedConnectTestCase, ++ skip_if_server_version_is_greater_than_or_equal_to, ++) + + + class MemoryProfilerParityTests(MemoryProfiler2TestsMixin, ReusedConnectTestCase): +@@ -27,6 +30,14 @@ class MemoryProfilerParityTests(MemoryProfiler2TestsMixin, ReusedConnectTestCase + super().setUp() + self.spark._profiler_collector._value = None + ++ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") ++ def test_memory_profiler_pandas_udf_iterator_not_supported(self): ++ super().test_memory_profiler_pandas_udf_iterator_not_supported() ++ ++ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") ++ def test_memory_profiler_map_in_pandas_not_supported(self): ++ super().test_memory_profiler_map_in_pandas_not_supported() ++ + + class MemoryProfilerWithoutPlanCacheParityTests(MemoryProfilerParityTests): + @classmethod +diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py +index 5c46130c5b5..11bc4ef8384 100644 +--- a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py ++++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py +@@ -22,7 +22,10 @@ from pyspark.sql.tests.test_udf_profiler import ( + UDFProfiler2TestsMixin, + 
_do_computation, + ) +-from pyspark.testing.connectutils import ReusedConnectTestCase ++from pyspark.testing.connectutils import ( ++ ReusedConnectTestCase, ++ skip_if_server_version_is_greater_than_or_equal_to, ++) + from pyspark.testing.utils import have_flameprof + + +@@ -31,6 +34,14 @@ class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase): + super().setUp() + self.spark._profiler_collector._value = None + ++ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") ++ def test_perf_profiler_pandas_udf_iterator_not_supported(self): ++ super().test_perf_profiler_pandas_udf_iterator_not_supported() ++ ++ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") ++ def test_perf_profiler_map_in_pandas_not_supported(self): ++ super().test_perf_profiler_map_in_pandas_not_supported() ++ + + class UDFProfilerWithoutPlanCacheParityTests(UDFProfilerParityTests): + @classmethod +diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +index 1f953235267..3a6ab9c98eb 100644 +--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py ++++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +@@ -262,7 +262,7 @@ class CogroupedApplyInPandasTestsMixin: + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + self._test_merge_error( +- fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}), ++ fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["test_string"]}), + output_schema="id long, k double", + errorClass=PythonException, + error_message_regex=expected, +diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +index 4ef334549ef..d60e31d8879 100644 +--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py ++++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +@@ -17,6 +17,7 @@ + + import datetime + import unittest ++import os + + from collections import OrderedDict + from decimal import Decimal +@@ -288,28 +289,20 @@ class GroupedApplyInPandasTestsMixin: + ): + self._test_apply_in_pandas(lambda key, pdf: key) + +- @staticmethod +- def stats_with_column_names(key, pdf): +- # order of column can be different to applyInPandas schema when column names are given +- return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) +- +- @staticmethod +- def stats_with_no_column_names(key, pdf): +- # columns must be in order of applyInPandas schema when no columns given +- return pd.DataFrame([key + (pdf.v.mean(),)]) +- + def test_apply_in_pandas_returning_column_names(self): +- self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_column_names) ++ self._test_apply_in_pandas( ++ lambda key, pdf: pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) ++ ) + + def test_apply_in_pandas_returning_no_column_names(self): +- self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_no_column_names) ++ self._test_apply_in_pandas(lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)])) + + def test_apply_in_pandas_returning_column_names_sometimes(self): + def stats(key, pdf): + if key[0] % 2: +- return GroupedApplyInPandasTestsMixin.stats_with_column_names(key, pdf) ++ return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) + else: +- return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf) ++ return pd.DataFrame([key + (pdf.v.mean(),)]) + + self._test_apply_in_pandas(stats) + +@@ -343,9 +336,15 @@ class 
GroupedApplyInPandasTestsMixin: + lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())]) + ) + ++ @unittest.skipIf( ++ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled" ++ ) + def test_apply_in_pandas_returning_empty_dataframe(self): + self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame()) + ++ @unittest.skipIf( ++ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled" ++ ) + def test_apply_in_pandas_returning_incompatible_type(self): + with self.quiet(): + self.check_apply_in_pandas_returning_incompatible_type() +@@ -846,7 +845,7 @@ class GroupedApplyInPandasTestsMixin: + + def stats(key, pdf): + if key[0] % 2 == 0: +- return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf) ++ return pd.DataFrame([key + (pdf.v.mean(),)]) + return empty_df + + result = ( +diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py +index 692f9705411..e5d0b56be69 100644 +--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py ++++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py +@@ -251,16 +251,17 @@ class MapInPandasTestsMixin: + self.check_dataframes_with_incompatible_types() + + def check_dataframes_with_incompatible_types(self): +- def func(iterator): +- for pdf in iterator: +- yield pdf.assign(id=pdf["id"].apply(str)) +- + for safely in [True, False]: + with self.subTest(convertToArrowArraySafely=safely), self.sql_conf( + {"spark.sql.execution.pandas.convertToArrowArraySafely": safely} + ): + # sometimes we see ValueErrors + with self.subTest(convert="string to double"): ++ ++ def func(iterator): ++ for pdf in iterator: ++ yield pdf.assign(id="test_string") ++ + expected = ( + r"ValueError: Exception thrown when converting pandas.Series " + r"\(object\) with name 'id' to Arrow Array \(double\)." +@@ -279,18 +280,31 @@ class MapInPandasTestsMixin: + .collect() + ) + +- # sometimes we see TypeErrors +- with self.subTest(convert="double to string"): +- with self.assertRaisesRegex( +- PythonException, +- r"TypeError: Exception thrown when converting pandas.Series " +- r"\(float64\) with name 'id' to Arrow Array \(string\).\n", +- ): +- ( +- self.spark.range(10, numPartitions=3) +- .select(col("id").cast("double")) +- .mapInPandas(self.identity_dataframes_iter("id"), "id string") +- .collect() ++ with self.subTest(convert="float to int precision loss"): ++ ++ def func(iterator): ++ for pdf in iterator: ++ yield pdf.assign(id=pdf["id"] + 0.1) ++ ++ df = ( ++ self.spark.range(10, numPartitions=3) ++ .select(col("id").cast("double")) ++ .mapInPandas(func, "id int") ++ ) ++ if safely: ++ expected = ( ++ r"ValueError: Exception thrown when converting pandas.Series " ++ r"\(float64\) with name 'id' to Arrow Array \(int32\)." ++ " It can be caused by overflows or other " ++ "unsafe conversions warned by Arrow. Arrow safe type check " ++ "can be disabled by using SQL config " ++ "`spark.sql.execution.pandas.convertToArrowArraySafely`." 
++ ) ++ with self.assertRaisesRegex(PythonException, expected + "\n"): ++ df.collect() ++ else: ++ self.assertEqual( ++ df.collect(), self.spark.range(10, numPartitions=3).collect() + ) + + def test_empty_iterator(self): +diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +index fe027875880..ae62124153c 100644 +--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py ++++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +@@ -1601,6 +1601,49 @@ class TransformWithStateInPandasTestsMixin: + check_exception=check_exception, + ) + ++ def test_transform_with_state_in_pandas_large_values(self): ++ """Test large state values (512KB) to validate readFully fix for SPARK-53870""" ++ ++ def check_results(batch_df, batch_id): ++ batch_df.collect() ++ target_size_bytes = 512 * 1024 ++ large_string = "a" * target_size_bytes ++ expected_list_elements = ",".join( ++ [large_string, large_string + "b", large_string + "c"] ++ ) ++ expected_map_result = f"large_string_key:{large_string}" ++ ++ assert set(batch_df.sort("id").collect()) == { ++ Row( ++ id="0", ++ valueStateResult=large_string, ++ listStateResult=expected_list_elements, ++ mapStateResult=expected_map_result, ++ ), ++ Row( ++ id="1", ++ valueStateResult=large_string, ++ listStateResult=expected_list_elements, ++ mapStateResult=expected_map_result, ++ ), ++ } ++ ++ output_schema = StructType( ++ [ ++ StructField("id", StringType(), True), ++ StructField("valueStateResult", StringType(), True), ++ StructField("listStateResult", StringType(), True), ++ StructField("mapStateResult", StringType(), True), ++ ] ++ ) ++ ++ self._test_transform_with_state_in_pandas_basic( ++ PandasLargeValueStatefulProcessor(), ++ check_results, ++ single_batch=True, ++ output_schema=output_schema, ++ ) ++ + + class SimpleStatefulProcessorWithInitialState(StatefulProcessor): + # this dict is the same as input initial state dataframe +@@ -2374,6 +2417,46 @@ class PandasStatefulProcessorCompositeType(StatefulProcessor): + pass + + ++class PandasLargeValueStatefulProcessor(StatefulProcessor): ++ """Test processor for large state values (512KB) to validate readFully fix""" ++ ++ def init(self, handle: StatefulProcessorHandle): ++ value_state_schema = StructType([StructField("value", StringType(), True)]) ++ self.value_state = handle.getValueState("valueState", value_state_schema) ++ ++ list_state_schema = StructType([StructField("value", StringType(), True)]) ++ self.list_state = handle.getListState("listState", list_state_schema) ++ ++ self.map_state = handle.getMapState("mapState", "key string", "value string") ++ ++ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: ++ target_size_bytes = 512 * 1024 ++ large_string = "a" * target_size_bytes ++ ++ self.value_state.update((large_string,)) ++ value_retrieved = self.value_state.get()[0] ++ ++ self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)]) ++ list_retrieved = list(self.list_state.get()) ++ list_elements = ",".join([elem[0] for elem in list_retrieved]) ++ ++ map_key = ("large_string_key",) ++ self.map_state.updateValue(map_key, (large_string,)) ++ map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}" ++ ++ yield pd.DataFrame( ++ { ++ "id": key, ++ "valueStateResult": [value_retrieved], ++ "listStateResult": [list_elements], ++ "mapStateResult": [map_retrieved], ++ } ++ ) ++ ++ def close(self) -> None: ++ pass ++ ++ 
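PandasLargeValueStatefulProcessor and the 512 KB test above exercise the readFully behaviour referenced as SPARK-53870; the fix itself is not part of this hunk. As a rough illustration of the pattern those payload sizes stress (helper name and stream are hypothetical), a single read() may legally return fewer bytes than requested, so the reader loops until the full value is consumed:

import io


def read_fully(stream, num_bytes: int) -> bytes:
    """Read exactly num_bytes, looping because one read() can return a short chunk."""
    chunks = []
    remaining = num_bytes
    while remaining > 0:
        chunk = stream.read(remaining)
        if not chunk:
            raise EOFError(f"stream ended with {remaining} bytes still missing")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


# The tests above use 512 KB values, well beyond a typical single-read chunk size.
payload = b"a" * (512 * 1024)
assert read_fully(io.BytesIO(payload), len(payload)) == payload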
+ class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): + pass + +diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py +index 423a717e8ab..b4573d5fb5c 100644 +--- a/python/pyspark/testing/connectutils.py ++++ b/python/pyspark/testing/connectutils.py +@@ -16,12 +16,12 @@ + # + import shutil + import tempfile +-import typing + import os + import functools + import unittest + import uuid + import contextlib ++from typing import Callable, Optional + + from pyspark.testing import ( + grpc_requirement_message, +@@ -36,6 +36,7 @@ from pyspark.testing import ( + should_test_connect, + ) + from pyspark import Row, SparkConf ++from pyspark.loose_version import LooseVersion + from pyspark.util import is_remote_only + from pyspark.testing.utils import PySparkErrorTestUtils + from pyspark.testing.sqlutils import ( +@@ -197,3 +198,28 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti + return QuietTest(self._legacy_sc) + else: + return contextlib.nullcontext() ++ ++ ++def skip_if_server_version_is( ++ cond: Callable[[LooseVersion], bool], reason: Optional[str] = None ++) -> Callable: ++ def decorator(f: Callable) -> Callable: ++ @functools.wraps(f) ++ def wrapper(self, *args, **kwargs): ++ version = self.spark.version ++ if cond(LooseVersion(version)): ++ raise unittest.SkipTest( ++ f"Skipping test {f.__name__} because server version is {version}" ++ + (f" ({reason})" if reason else "") ++ ) ++ return f(self, *args, **kwargs) ++ ++ return wrapper ++ ++ return decorator ++ ++ ++def skip_if_server_version_is_greater_than_or_equal_to( ++ version: str, reason: Optional[str] = None ++) -> Callable: ++ return skip_if_server_version_is(lambda v: v >= LooseVersion(version), reason) +diff --git a/python/pyspark/version.py b/python/pyspark/version.py +index bfcc501ff93..41148c646f7 100644 +--- a/python/pyspark/version.py ++++ b/python/pyspark/version.py +@@ -16,4 +16,4 @@ + # See the License for the specific language governing permissions and + # limitations under the License. 
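The skip_if_server_version_is* helpers added to connectutils.py above gate a test on the connected server's version via LooseVersion. A small usage sketch mirroring how the profiler parity tests in this patch apply it (class and test names here are hypothetical):

import unittest

from pyspark.testing.connectutils import (
    ReusedConnectTestCase,
    skip_if_server_version_is_greater_than_or_equal_to,
)


class ExampleParityTests(ReusedConnectTestCase):
    @skip_if_server_version_is_greater_than_or_equal_to("4.1.0", reason="behavior changed")
    def test_only_meaningful_on_older_servers(self):
        # The wrapper raises unittest.SkipTest when self.spark.version >= 4.1.0.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()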
+ +-__version__: str = "4.0.1" ++__version__: str = "4.0.3.dev0" +diff --git a/repl/pom.xml b/repl/pom.xml +index 02ed999e9b9..8f962239689 100644 +--- a/repl/pom.xml ++++ b/repl/pom.xml +@@ -21,7 +21,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../pom.xml + + +diff --git a/repl/src/test/resources/IntSumUdf.class b/repl/src/test/resources/IntSumUdf.class +new file mode 100644 +index 00000000000..75a41446cfc +Binary files /dev/null and b/repl/src/test/resources/IntSumUdf.class differ +diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml +index f3bace4ec6a..19f19273f6b 100644 +--- a/resource-managers/kubernetes/core/pom.xml ++++ b/resource-managers/kubernetes/core/pom.xml +@@ -20,7 +20,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../../pom.xml + + +diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml +index 5c31a10641b..ce77018ff85 100644 +--- a/resource-managers/kubernetes/integration-tests/pom.xml ++++ b/resource-managers/kubernetes/integration-tests/pom.xml +@@ -20,7 +20,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../../pom.xml + + +diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml +index 8a9437a04f6..581762e4bef 100644 +--- a/resource-managers/yarn/pom.xml ++++ b/resource-managers/yarn/pom.xml +@@ -20,7 +20,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../pom.xml + + +diff --git a/sql/api/pom.xml b/sql/api/pom.xml +index 09d458bdc5a..db17f3a5f5d 100644 +--- a/sql/api/pom.xml ++++ b/sql/api/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../pom.xml + + +diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala +index 0f219725523..b90d9f8013d 100644 +--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala ++++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala +@@ -55,7 +55,7 @@ object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) { + * server (driver & executors) very tricky. As a workaround a user can define their own Codec + * which internalizes the Kryo configuration. 
+ */ +-object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) { ++object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) with Serializable { + private lazy val kryoCodecConstructor: MethodHandle = { + val cls = SparkClassUtils.classForName( + "org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl") +diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +index dd8ca26c524..044100c9226 100644 +--- a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala ++++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +@@ -93,7 +93,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa + case _ => false + } + +- override def catalogString: String = sqlType.simpleString ++ override def catalogString: String = sqlType.catalogString + } + + private[spark] object UserDefinedType { +diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml +index 3b3e2a07b0c..bfc482e581c 100644 +--- a/sql/catalyst/pom.xml ++++ b/sql/catalyst/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../pom.xml + + +diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java +index 47662dc97cc..268fa577b29 100644 +--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java ++++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java +@@ -36,6 +36,13 @@ public interface SupportsTriggerAvailableNow extends SupportsAdmissionControl { + * the query). The source will behave as if there is no new data coming in after the target + * offset, i.e., the source will not return an offset higher than the target offset when + * {@link #latestOffset(Offset, ReadLimit) latestOffset} is called. ++ *

++ * Note that there is an exception on the first uncommitted batch after a restart, where the end ++ * offset is not derived from the current latest offset. Sources need to take special ++ * considerations if wanting to assert such relation. One possible way is to have an internal ++ * flag in the source to indicate whether it is Trigger.AvailableNow, set the flag in this method, ++ * and record the target offset in the first call of ++ * {@link #latestOffset(Offset, ReadLimit) latestOffset}. + */ + void prepareForTriggerAvailableNow(); + } +diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +index ac05981da5a..b14cd3429e4 100644 +--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java ++++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +@@ -164,6 +164,7 @@ public final class ColumnarRow extends InternalRow { + + @Override + public Object get(int ordinal, DataType dataType) { ++ if (isNullAt(ordinal)) return null; + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +index 492ea741236..9dcaba8c2bc 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{expressions => exprs} + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} + import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec} + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder} +-import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder} ++import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder} + import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} + import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} + import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils} +@@ -241,19 +241,12 @@ object DeserializerBuildHelper { + val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName) + // Assumes we are deserializing the first column of a row. 
+ val input = GetColumnByOrdinal(0, enc.dataType) +- enc match { +- case AgnosticEncoders.RowEncoder(fields) => +- val children = fields.zipWithIndex.map { case (f, i) => +- createDeserializer(f.enc, GetStructField(input, i), walkedTypePath) +- } +- CreateExternalRow(children, enc.schema) +- case _ => +- val deserializer = createDeserializer( +- enc, +- upCastToExpectedType(input, enc.dataType, walkedTypePath), +- walkedTypePath) +- expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) +- } ++ val deserializer = createDeserializer( ++ enc, ++ upCastToExpectedType(input, enc.dataType, walkedTypePath), ++ walkedTypePath, ++ isTopLevel = true) ++ expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) + } + + /** +@@ -265,11 +258,13 @@ object DeserializerBuildHelper { + * external representation. + * @param path The expression which can be used to extract serialized value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. ++ * @param isTopLevel true if we are creating a deserializer for the top level value. + */ + private def createDeserializer( + enc: AgnosticEncoder[_], + path: Expression, +- walkedTypePath: WalkedTypePath): Expression = enc match { ++ walkedTypePath: WalkedTypePath, ++ isTopLevel: Boolean = false): Expression = enc match { + case ae: AgnosticExpressionPathEncoder[_] => + ae.fromCatalyst(path) + case _ if isNativeEncoder(enc) => +@@ -408,13 +403,12 @@ object DeserializerBuildHelper { + NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter)) + + case AgnosticEncoders.RowEncoder(fields) => +- val isExternalRow = !path.dataType.isInstanceOf[StructType] + val convertedFields = fields.zipWithIndex.map { case (f, i) => + val newTypePath = walkedTypePath.recordField( + f.enc.clsTag.runtimeClass.getName, + f.name) + val deserializer = createDeserializer(f.enc, GetStructField(path, i), newTypePath) +- if (isExternalRow) { ++ if (!isTopLevel) { + exprs.If( + Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil), + exprs.Literal.create(null, externalDataTypeFor(f.enc)), +@@ -459,8 +453,8 @@ object DeserializerBuildHelper { + Invoke( + Literal.create(provider(), ObjectType(classOf[Codec[_, _]])), + "decode", +- ObjectType(tag.runtimeClass), +- createDeserializer(encoder, path, walkedTypePath) :: Nil) ++ dataTypeForClass(tag.runtimeClass), ++ createDeserializer(encoder, path, walkedTypePath, isTopLevel) :: Nil) + } + + private def deserializeArray( +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +index 5c4e9d4bddc..b568722c38a 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +@@ -756,7 +756,7 @@ object CatalogTable { + props.get(key).orElse { + if (props.exists { case (mapKey, _) => mapKey.startsWith(key) }) { + props.get(s"$key.numParts") match { +- case None => throw QueryCompilationErrors.insufficientTablePropertyError(key) ++ case None => None + case Some(numParts) => + val parts = (0 until numParts.toInt).map { index => + val keyPart = s"$key.part.$index" +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +index 8f717795605..16d5adb064d 100644 +--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +@@ -152,6 +152,12 @@ object EncoderUtils { + VariantType -> classOf[VariantVal] + ) + ++ def dataTypeForClass(c: Class[_]): DataType = ++ javaClassToPrimitiveType.get(c).getOrElse(ObjectType(c)) ++ ++ private val javaClassToPrimitiveType: Map[Class[_], DataType] = ++ typeJavaMapping.iterator.filter(_._2.isPrimitive).map(_.swap).toMap ++ + val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +index 784bea899c4..e3ff7c5f05f 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch + import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType} + import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} + import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper ++import org.apache.spark.sql.catalyst.optimizer.ScalarSubqueryReference + import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE + import org.apache.spark.sql.types._ + import org.apache.spark.util.sketch.BloomFilter +@@ -58,6 +59,7 @@ case class BloomFilterMightContain( + case GetStructField(subquery: PlanExpression[_], _, _) + if !subquery.containsPattern(OUTER_REFERENCE) => + TypeCheckResult.TypeCheckSuccess ++ case _: ScalarSubqueryReference => TypeCheckResult.TypeCheckSuccess + case _ => + DataTypeMismatch( + errorSubClass = "BLOOM_FILTER_BINARY_OP_WRONG_TYPE", +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +index cbc8a8f273e..d3165e3a3e6 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +@@ -328,7 +328,8 @@ case class HllUnionAgg( + union.update(sketch) + Some(union) + } catch { +- case _: SketchesArgumentException | _: java.lang.Error => ++ case _: SketchesArgumentException | _: java.lang.Error ++ | _: ArrayIndexOutOfBoundsException => + throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) + } + case _ => +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala +index a4ac0bdbb11..1880d71e7d5 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala +@@ -56,7 +56,8 @@ case class HllSketchEstimate(child: Expression) + try { + 
Math.round(HllSketch.heapify(Memory.wrap(buffer)).getEstimate) + } catch { +- case _: SketchesArgumentException | _: java.lang.Error => ++ case _: SketchesArgumentException | _: java.lang.Error ++ | _: ArrayIndexOutOfBoundsException => + throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) + } + } +@@ -108,13 +109,15 @@ case class HllUnion(first: Expression, second: Expression, third: Expression) + val sketch1 = try { + HllSketch.heapify(Memory.wrap(value1.asInstanceOf[Array[Byte]])) + } catch { +- case _: SketchesArgumentException | _: java.lang.Error => ++ case _: SketchesArgumentException | _: java.lang.Error ++ | _: ArrayIndexOutOfBoundsException => + throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) + } + val sketch2 = try { + HllSketch.heapify(Memory.wrap(value2.asInstanceOf[Array[Byte]])) + } catch { +- case _: SketchesArgumentException | _: java.lang.Error => ++ case _: SketchesArgumentException | _: java.lang.Error ++ | _: ArrayIndexOutOfBoundsException => + throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) + } + val allowDifferentLgConfigK = value3.asInstanceOf[Boolean] +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +index 9db2ac7f9b0..0f74389a9a5 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +@@ -1562,7 +1562,7 @@ abstract class RoundBase(child: Expression, scale: Expression, + val decimal = input1.asInstanceOf[Decimal] + if (_scale >= 0) { + // Overflow cannot happen, so no need to control nullOnOverflow +- decimal.toPrecision(decimal.precision, s, mode) ++ decimal.toPrecision(p, s, mode) + } else { + Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s) + } +@@ -1634,10 +1634,9 @@ abstract class RoundBase(child: Expression, scale: Expression, + case DecimalType.Fixed(p, s) => + if (_scale >= 0) { + s""" +- ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, +- Decimal.$modeStr(), true, null); ++ ${ev.value} = ${ce.value}.toPrecision($p, $s, Decimal.$modeStr(), true, null); + ${ev.isNull} = ${ev.value} == null;""" +- } else { ++ } else { + s""" + ${ev.value} = new Decimal().set(${ce.value}.toBigDecimal() + .setScale(${_scale}, Decimal.$modeStr()), $p, $s); +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala +index 46815969e7e..d36a71b0439 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala +@@ -26,12 +26,29 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, WINDOW} + * Inserts a `WindowGroupLimit` below `Window` if the `Window` has rank-like functions + * and the function results are further filtered by limit-like predicates. 
Example query: + * {{{ +- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn = 5 +- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 = rn +- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn < 5 +- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 > rn +- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn <= 5 +- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 >= rn ++ * SELECT * FROM ( ++ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 ++ * ) WHERE rn = 5; ++ * ++ * SELECT * FROM ( ++ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 ++ * ) WHERE 5 = rn; ++ * ++ * SELECT * FROM ( ++ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 ++ * ) WHERE rn < 5; ++ * ++ * SELECT * FROM ( ++ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 ++ * ) WHERE 5 > rn; ++ * ++ * SELECT * FROM ( ++ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 ++ * ) WHERE rn <= 5; ++ * ++ * SELECT * FROM ( ++ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 ++ * ) WHERE 5 >= rn; + * }}} + */ + object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper { +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +index aa972c81559..7a8deb10f1a 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +@@ -357,6 +357,15 @@ abstract class Optimizer(catalogManager: CatalogManager) + case other => other + } + } ++ ++ private def optimizeSubquery(s: SubqueryExpression): SubqueryExpression = { ++ val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s)) ++ // At this point we have an optimized subquery plan that we are going to attach ++ // to this subquery expression. Here we can safely remove any top level sort ++ // in the plan as tuples produced by a subquery are un-ordered. ++ s.withNewPlan(removeTopLevelSort(newPlan)) ++ } ++ + def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(PLAN_EXPRESSION), ruleId) { + // Do not optimize DPP subquery, as it was created from optimized plan and we should not +@@ -411,12 +420,23 @@ abstract class Optimizer(catalogManager: CatalogManager) + s.withNewPlan( + if (needTopLevelProject) newPlan else newPlan.child + ) ++ case s: Exists => ++ // For an EXISTS join, the subquery might be written as "SELECT * FROM ...". ++ // If we optimize the subquery directly, column pruning may not be applied ++ // effectively. To address this, we add an extra Project node that selects ++ // only the columns referenced in the EXISTS join condition. ++ // This ensures that column pruning can be performed correctly ++ // during subquery optimization. 
++ val selectedRefrences = ++ s.plan.output.filter(s.joinCond.flatMap(_.references).contains) ++ val newPlan = if (selectedRefrences.nonEmpty) { ++ s.withNewPlan(Project(selectedRefrences, s.plan)) ++ } else { ++ s ++ } ++ optimizeSubquery(newPlan) + case s: SubqueryExpression => +- val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s)) +- // At this point we have an optimized subquery plan that we are going to attach +- // to this subquery expression. Here we can safely remove any top level sort +- // in the plan as tuples produced by a subquery are un-ordered. +- s.withNewPlan(removeTopLevelSort(newPlan)) ++ optimizeSubquery(s) + } + } + +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +index f8c1b2a9014..94d69fa2179 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +@@ -250,7 +250,7 @@ case class ReplaceData( + write: Option[Write] = None) extends RowLevelWrite { + + override val isByName: Boolean = false +- override val stringArgs: Iterator[Any] = Iterator(table, query, write) ++ override def stringArgs: Iterator[Any] = Iterator(table, query, write) + + override lazy val references: AttributeSet = query.outputSet + +@@ -332,7 +332,7 @@ case class WriteDelta( + write: Option[DeltaWrite] = None) extends RowLevelWrite { + + override val isByName: Boolean = false +- override val stringArgs: Iterator[Any] = Iterator(table, query, write) ++ override def stringArgs: Iterator[Any] = Iterator(table, query, write) + + override lazy val references: AttributeSet = query.outputSet + +@@ -1654,12 +1654,19 @@ case class Call( + } + + override def simpleString(maxFields: Int): String = { +- val name = procedure match { ++ procedure match { + case ResolvedProcedure(catalog, ident, _) => +- s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" ++ val name = s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" ++ simpleString(name, maxFields) + case UnresolvedProcedure(nameParts) => +- nameParts.quoted ++ val name = nameParts.quoted ++ simpleString(name, maxFields) ++ case _ => ++ super.simpleString(maxFields) + } ++ } ++ ++ private def simpleString(name: String, maxFields: Int): String = { + val argsString = truncatedString(args, ", ", maxFields) + s"Call $name($argsString)" + } +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +index 038105f9bfd..dc66b6f30e5 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +@@ -899,10 +899,13 @@ case class KeyGroupedShuffleSpec( + } + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { +- val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map { +- case (c, e: TransformExpression) => TransformExpression( +- e.function, Seq(c), e.numBucketsOpt) +- case (c, _) => c ++ assert(clustering.size == distribution.clustering.size, ++ "Required distributions of join legs should be the same size.") ++ ++ val newExpressions = partitioning.expressions.zip(keyPositions).map { ++ case (te: TransformExpression, positionSet) => 
++ te.copy(children = te.children.map(_ => clustering(positionSet.head))) ++ case (_, positionSet) => clustering(positionSet.head) + } + KeyGroupedPartitioning(newExpressions, + partitioning.numPartitions, +diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +index b24ad30e071..72a8c8539bd 100644 +--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala ++++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +@@ -18,6 +18,7 @@ + package org.apache.spark.sql.catalyst.util + + import scala.collection.mutable.ArrayBuffer ++import scala.util.{Failure, Success, Try} + + import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} + import org.apache.spark.internal.{Logging, MDC} +@@ -368,27 +369,33 @@ object ResolveDefaultColumns extends QueryErrorsBase + val defaultSQL = field.metadata.getString(EXISTS_DEFAULT_COLUMN_METADATA_KEY) + + // Parse the expression. +- val expr = Literal.fromSQL(defaultSQL) match { +- // EXISTS_DEFAULT will have a cast from analyze() due to coerceDefaultValue +- // hence we need to add timezone to the cast if necessary +- case c: Cast if c.child.resolved && c.needsTimeZone => +- c.withTimeZone(SQLConf.get.sessionLocalTimeZone) +- case e: Expression => e +- } ++ val resolvedExpr = Try(Literal.fromSQL(defaultSQL)) match { ++ case Success(literal) => ++ val expr = literal match { ++ // EXISTS_DEFAULT will have a cast from analyze() due to coerceDefaultValue ++ // hence we need to add timezone to the cast if necessary ++ case c: Cast if c.child.resolved && c.needsTimeZone => ++ c.withTimeZone(SQLConf.get.sessionLocalTimeZone) ++ case e: Expression => e ++ } + +- // Check invariants +- if (expr.containsPattern(PLAN_EXPRESSION)) { +- throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions( +- "", field.name, defaultSQL) +- } ++ // Check invariants ++ if (expr.containsPattern(PLAN_EXPRESSION)) { ++ throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions( ++ "", field.name, defaultSQL) ++ } ++ ++ expr match { ++ case _: ExprLiteral => expr ++ case c: Cast if c.resolved => expr ++ case _ => ++ fallbackResolveExistenceDefaultValue(field) ++ } + +- val resolvedExpr = expr match { +- case _: ExprLiteral => expr +- case c: Cast if c.resolved => expr +- case _ => ++ case Failure(_) => ++ // If Literal.fromSQL fails, use fallback resolution + fallbackResolveExistenceDefaultValue(field) + } +- + coerceDefaultValue(resolvedExpr, field.dataType, "", field.name, defaultSQL) + } + +diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +index 616c6d65636..0d26b390643 100644 +--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala ++++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +@@ -612,6 +612,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes + provider, + nullable = true)) + .resolveAndBind() ++ assert(encoder.isInstanceOf[Serializable]) + assert(encoder.schema == new StructType().add("value", BinaryType)) + val toRow = encoder.createSerializer() + val fromRow = encoder.createDeserializer() +@@ 
-659,6 +660,22 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes + assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x"))) + } + ++ test("SPARK-52614: transforming encoder row encoder in product encoder") { ++ val schema = new StructType().add("a", LongType).add("b", StringType) ++ val wrapperEncoder = TransformingEncoder( ++ classTag[Wrapper[Row]], ++ RowEncoder.encoderFor(schema), ++ new WrapperCodecProvider[Row]) ++ val encoder = ExpressionEncoder(ProductEncoder( ++ classTag[V[Wrapper[Row]]], ++ Seq(EncoderField("v", wrapperEncoder, nullable = false, Metadata.empty)), ++ None)) ++ .resolveAndBind() ++ val toRow = encoder.createSerializer() ++ val fromRow = encoder.createDeserializer() ++ assert(fromRow(toRow(V(new Wrapper(Row(9L, "x"))))) == V(new Wrapper(Row(9L, "x")))) ++ } ++ + // below tests are related to SPARK-49960 and TransformingEncoder usage + test("""Encoder with OptionEncoder of transformation""".stripMargin) { + type T = Option[V[V[Int]]] +@@ -749,6 +766,24 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes + testDataTransformingEnc(enc, data) + } + ++ test("SPARK-52601 TransformingEncoder from primitive to timestamp") { ++ val enc: AgnosticEncoder[Long] = ++ TransformingEncoder[Long, java.sql.Timestamp]( ++ classTag, ++ TimestampEncoder(true), ++ () => ++ new Codec[Long, java.sql.Timestamp] with Serializable { ++ override def encode(in: Long): Timestamp = Timestamp.from(microsToInstant(in)) ++ override def decode(out: Timestamp): Long = instantToMicros(out.toInstant) ++ } ++ ) ++ val data: Seq[Long] = Seq(0L, 1L, 2L) ++ ++ assert(enc.dataType === TimestampType) ++ ++ testDataTransformingEnc(enc, data) ++ } ++ + val longEncForTimestamp: AgnosticEncoder[V[Long]] = + TransformingEncoder[V[Long], java.sql.Timestamp]( + classTag, +diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala +index 0841702cc51..0f7f5ca54be 100644 +--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala ++++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala +@@ -108,4 +108,49 @@ class DatasketchesHllSketchSuite extends SparkFunSuite { + + assert(HllSketch.heapify(Memory.wrap(binary3.asInstanceOf[Array[Byte]])).getLgConfigK == 12) + } ++ ++ test("HllUnionAgg throws proper error for invalid binary input causing ArrayIndexOutOfBounds") { ++ val aggFunc = new HllUnionAgg(BoundReference(0, BinaryType, nullable = true), true) ++ val union = aggFunc.createAggregationBuffer() ++ ++ // Craft a byte array that passes initial size checks but has an invalid CurMode ordinal. ++ // HLL preamble layout: ++ // Byte 0: preInts (preamble size in ints) ++ // Byte 1: serVer (must be 1) ++ // Byte 2: famId (must be 7 for HLL) ++ // Byte 3: lgK (4-21) ++ // Byte 5: flags ++ // Byte 7: modeByte - bits 0-1 contain curMode ordinal (0=LIST, 1=SET, 2=HLL) ++ // ++ // Setting bits 0-1 of byte 7 to 0b11 (=3) causes CurMode.fromOrdinal(3) to throw ++ // ArrayIndexOutOfBoundsException since CurMode only has ordinals 0, 1, 2. ++ // This happens in PreambleUtil.extractCurMode() before other validations run. 
++ val invalidBinary = Array[Byte]( ++ 2, // byte 0: preInts = 2 (LIST_PREINTS, passes check) ++ 1, // byte 1: serVer = 1 (valid) ++ 7, // byte 2: famId = 7 (HLL family) ++ 12, // byte 3: lgK = 12 (valid range 4-21) ++ 0, // byte 4: unused ++ 0, // byte 5: flags = 0 ++ 0, // byte 6: unused ++ 3 // byte 7: modeByte with bits 0-1 = 0b11 = 3 (INVALID curMode ordinal!) ++ ) ++ ++ val exception = intercept[Exception] { ++ aggFunc.update(union, InternalRow(invalidBinary)) ++ } ++ ++ // Verify that ArrayIndexOutOfBoundsException is properly caught and converted ++ // to the user-friendly HLL_INVALID_INPUT_SKETCH_BUFFER error ++ assert( ++ !exception.isInstanceOf[ArrayIndexOutOfBoundsException], ++ s"ArrayIndexOutOfBoundsException should be caught and converted to " + ++ s"HLL_INVALID_INPUT_SKETCH_BUFFER error, but got: ${exception.getClass.getName}" ++ ) ++ assert( ++ exception.getMessage.contains("HLL_INVALID_INPUT_SKETCH_BUFFER"), ++ s"Expected HLL_INVALID_INPUT_SKETCH_BUFFER error, " + ++ s"but got: ${exception.getClass.getName}: ${exception.getMessage}" ++ ) ++ } + } +diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +index 5dd45d3d449..42579f6cc6e 100644 +--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala ++++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +@@ -856,6 +856,13 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { + "CAST(CURRENT_TIMESTAMP AS BIGINT)") + .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, + "CAST(CURRENT_TIMESTAMP AS BIGINT)") ++ .build()), ++ StructField("c3", StringType, true, ++ new MetadataBuilder() ++ .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, ++ "CONCAT(YEAR(CURRENT_DATE), LPAD(WEEKOFYEAR(CURRENT_DATE), 2, '0'))") ++ .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, ++ "CONCAT(YEAR(CURRENT_DATE), LPAD(WEEKOFYEAR(CURRENT_DATE), 2, '0'))") + .build()))) + val res = ResolveDefaultColumns.existenceDefaultValues(source) + assert(res(0) == null) +@@ -864,5 +871,9 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { + val res2Wrapper = new LongWrapper + assert(res(2).asInstanceOf[UTF8String].toLong(res2Wrapper)) + assert(res2Wrapper.value > 0) ++ ++ val res3Wrapper = new LongWrapper ++ assert(res(3).asInstanceOf[UTF8String].toLong(res3Wrapper)) ++ assert(res3Wrapper.value > 0) + } + } +diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala +index 04b090d7001..2f58e722c05 100644 +--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala ++++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala +@@ -17,6 +17,7 @@ + + package org.apache.spark.sql.types + ++import org.apache.spark.sql.Row + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions.GenericInternalRow + import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +@@ -132,3 +133,22 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] + + override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] + } ++ ++ ++class ExampleIntRowUDT(cols: Int) extends UserDefinedType[Row] { ++ override def sqlType: DataType = { ++ StructType((0 until cols).map(i => ++ StructField(s"col$i", IntegerType, nullable = false))) ++ } ++ 
++ override def serialize(obj: Row): InternalRow = { ++ InternalRow.fromSeq(obj.toSeq) ++ } ++ ++ override def deserialize(datum: Any): Row = { ++ val internalRow = datum.asInstanceOf[InternalRow] ++ Row.fromSeq(internalRow.toSeq(sqlType.asInstanceOf[StructType])) ++ } ++ ++ override def userClass: Class[Row] = classOf[Row] ++} +diff --git a/sql/connect/client/jvm/pom.xml b/sql/connect/client/jvm/pom.xml +index 3de1cf368f8..bd586e86adc 100644 +--- a/sql/connect/client/jvm/pom.xml ++++ b/sql/connect/client/jvm/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../../../pom.xml + + +diff --git a/sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar b/sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar +new file mode 100644 +index 00000000000..6dee8fcd9c9 +Binary files /dev/null and b/sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar differ +diff --git a/sql/connect/client/jvm/src/test/resources/udf2.13.jar b/sql/connect/client/jvm/src/test/resources/udf2.13.jar +new file mode 100644 +index 00000000000..c89830f127c +Binary files /dev/null and b/sql/connect/client/jvm/src/test/resources/udf2.13.jar differ +diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +index a548ec7007d..e19f1eacfd8 100644 +--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala ++++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +@@ -3390,12 +3390,24 @@ class PlanGenerationTestSuite + fn.typedLit(java.time.Duration.ofSeconds(200L)), + fn.typedLit(java.time.Period.ofDays(100)), + fn.typedLit(new CalendarInterval(2, 20, 100L)), ++ fn.typedLit( ++ ( ++ java.time.LocalDate.of(2020, 10, 10), ++ java.time.Instant.ofEpochMilli(1677155519808L), ++ new java.sql.Timestamp(12345L), ++ java.time.LocalDateTime.of(2023, 2, 23, 20, 36), ++ java.sql.Date.valueOf("2023-02-23"), ++ java.time.Duration.ofSeconds(200L), ++ java.time.Period.ofDays(100), ++ new CalendarInterval(2, 20, 100L))), + + // Handle parameterized scala types e.g.: List, Seq and Map. 
+ fn.typedLit(Some(1)), + fn.typedLit(Array(1, 2, 3)), ++ fn.typedLit[Array[Integer]](Array(null, null)), + fn.typedLit(Seq(1, 2, 3)), +- fn.typedLit(Map("a" -> 1, "b" -> 2)), ++ fn.typedLit(mutable.LinkedHashMap("a" -> 1, "b" -> 2)), ++ fn.typedLit(mutable.LinkedHashMap[String, Integer]("a" -> null, "b" -> null)), + fn.typedLit(("a", 2, 1.0)), + fn.typedLit[Option[Int]](None), + fn.typedLit[Array[Option[Int]]](Array(Some(1))), +diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala +index 3b6dd090caf..afc2b1db023 100644 +--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala ++++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala +@@ -1681,6 +1681,13 @@ class ClientE2ETestSuite + assert(df.count() == 100) + } + } ++ ++ test("SPARK-53553: null value handling in literals") { ++ val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col")) ++ val result = df.collect() ++ assert(result.length === 1) ++ assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null)) ++ } + } + + private[sql] case class ClassData(a: String, b: Int) +diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala +index 1d022489b70..4c0073cad56 100644 +--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala ++++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala +@@ -16,7 +16,7 @@ + */ + package org.apache.spark.sql.connect + +-import java.util.concurrent.ForkJoinPool ++import java.util.concurrent.Executors + + import scala.collection.mutable + import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} +@@ -146,7 +146,7 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession { + // global ExecutionContext has only 2 threads in Apache Spark CI + // create own thread pool for four Futures used in this test + val numThreads = 4 +- val fpool = new ForkJoinPool(numThreads) ++ val fpool = Executors.newFixedThreadPool(numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + val q1 = Future { +diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +index cbaa4f5ea07..8afa28b1f38 100644 +--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala ++++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +@@ -234,6 +234,8 @@ object CheckConnectJvmClientCompatibility { + "org.apache.spark.sql.artifact.ArtifactManager$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"), ++ ProblemFilters.exclude[MissingClassProblem]( ++ "org.apache.spark.sql.artifact.ArtifactManager$StateCleanupRunner"), + + // ColumnNode conversions + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession"), +diff --git a/sql/connect/common/pom.xml b/sql/connect/common/pom.xml +index 
1966bf4b303..58441cde7b3 100644 +--- a/sql/connect/common/pom.xml ++++ b/sql/connect/common/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../../pom.xml + + +diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +index 1f3496fa898..d64f5d7cdf2 100644 +--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala ++++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +@@ -163,6 +163,14 @@ object LiteralValueProtoConverter { + } + + (literal, dataType) match { ++ case (v: Option[_], _: DataType) => ++ if (v.isDefined) { ++ toLiteralProtoBuilder(v.get) ++ } else { ++ builder.setNull(toConnectProtoType(dataType)) ++ } ++ case (null, _) => ++ builder.setNull(toConnectProtoType(dataType)) + case (v: mutable.ArraySeq[_], ArrayType(_, _)) => + toLiteralProtoBuilder(v.array, dataType) + case (v: immutable.ArraySeq[_], ArrayType(_, _)) => +@@ -175,12 +183,6 @@ object LiteralValueProtoConverter { + builder.setMap(mapBuilder(v, keyType, valueType)) + case (v, structType: StructType) => + builder.setStruct(structBuilder(v, structType)) +- case (v: Option[_], _: DataType) => +- if (v.isDefined) { +- toLiteralProtoBuilder(v.get) +- } else { +- builder.setNull(toConnectProtoType(dataType)) +- } + case _ => toLiteralProtoBuilder(literal) + } + } +@@ -296,8 +298,8 @@ object LiteralValueProtoConverter { + } + } + +- private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { +- if (dataType.hasShort) { v => ++ private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { ++ val converter: proto.Expression.Literal => Any = if (dataType.hasShort) { v => + v.getShort.toShort + } else if (dataType.hasInteger) { v => + v.getInteger +@@ -316,15 +318,15 @@ object LiteralValueProtoConverter { + } else if (dataType.hasBinary) { v => + v.getBinary.toByteArray + } else if (dataType.hasDate) { v => +- v.getDate ++ SparkDateTimeUtils.toJavaDate(v.getDate) + } else if (dataType.hasTimestamp) { v => +- v.getTimestamp ++ SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) + } else if (dataType.hasTimestampNtz) { v => +- v.getTimestampNtz ++ SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) + } else if (dataType.hasDayTimeInterval) { v => +- v.getDayTimeInterval ++ SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) + } else if (dataType.hasYearMonthInterval) { v => +- v.getYearMonthInterval ++ SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) + } else if (dataType.hasDecimal) { v => + Decimal(v.getDecimal.getValue) + } else if (dataType.hasCalendarInterval) { v => +@@ -339,6 +341,7 @@ object LiteralValueProtoConverter { + } else { + throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)") + } ++ v => if (v.hasNull) null else converter(v) + } + + def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = { +@@ -354,7 +357,7 @@ object LiteralValueProtoConverter { + builder.result() + } + +- makeArrayData(getConverter(array.getElementType)) ++ makeArrayData(getScalaConverter(array.getElementType)) + } + + def toCatalystMap(map: proto.Expression.Literal.Map): mutable.Map[_, _] = { +@@ -373,7 +376,7 @@ object LiteralValueProtoConverter { + builder + } + +- makeMapData(getConverter(map.getKeyType), 
getConverter(map.getValueType)) ++ makeMapData(getScalaConverter(map.getKeyType), getScalaConverter(map.getValueType)) + } + + def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = { +@@ -392,7 +395,7 @@ object LiteralValueProtoConverter { + val structData = elements + .zip(dataTypes) + .map { case (element, dataType) => +- getConverter(dataType)(element) ++ getScalaConverter(dataType)(element) + } + .asInstanceOf[scala.collection.Seq[Object]] + .toSeq +diff --git a/sql/connect/common/src/test/resources/artifact-tests/Hello.class b/sql/connect/common/src/test/resources/artifact-tests/Hello.class +new file mode 100644 +index 00000000000..56725764de2 +Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/Hello.class differ +diff --git a/sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar b/sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar +new file mode 100755 +index 00000000000..6da55d8b852 +Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar differ +diff --git a/sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class b/sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class +new file mode 100755 +index 00000000000..e796030e471 +Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class differ +diff --git a/sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class b/sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class +new file mode 100755 +index 00000000000..e796030e471 +Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class differ +diff --git a/sql/connect/common/src/test/resources/artifact-tests/smallJar.jar b/sql/connect/common/src/test/resources/artifact-tests/smallJar.jar +new file mode 100755 +index 00000000000..3c4930e8e95 +Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/smallJar.jar differ +diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain +index 6d854da250f..a566430136f 100644 +--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain ++++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain +@@ -1,2 +1,2 @@ +-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 
18 more fields] ++Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 21 more fields] + +- LocalRelation , [id#0L, a#0, b#0] +diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +index e56b6e1f3ee..456033244a9 100644 +--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json ++++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +@@ -77,7 +77,8 @@ + }, { + "literal": { + "null": { +- "null": { ++ "string": { ++ "collation": "UTF8_BINARY" + } + } + }, +@@ -652,6 +653,114 @@ + } + } + } ++ }, { ++ "literal": { ++ "struct": { ++ "structType": { ++ "struct": { ++ "fields": [{ ++ "name": "_1", ++ "dataType": { ++ "date": { ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_2", ++ "dataType": { ++ "timestamp": { ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_3", ++ "dataType": { ++ "timestamp": { ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_4", ++ "dataType": { ++ "timestampNtz": { ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_5", ++ "dataType": { ++ "date": { ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_6", ++ "dataType": { ++ "dayTimeInterval": { ++ "startField": 0, ++ "endField": 3 ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_7", ++ "dataType": { ++ "yearMonthInterval": { ++ "startField": 0, ++ "endField": 1 ++ } ++ }, ++ "nullable": true ++ }, { ++ "name": "_8", ++ "dataType": { ++ "calendarInterval": { ++ } ++ }, ++ "nullable": true ++ }] ++ } ++ }, ++ "elements": [{ ++ "date": 18545 ++ }, { ++ "timestamp": "1677155519808000" ++ }, { ++ "timestamp": "12345000" ++ }, { ++ "timestampNtz": "1677184560000000" ++ }, { ++ "date": 19411 ++ }, { ++ "dayTimeInterval": "200000000" ++ }, { ++ "yearMonthInterval": 0 ++ }, { ++ "calendarInterval": { ++ "months": 2, ++ "days": 20, ++ "microseconds": "100" ++ } ++ }] ++ } ++ }, ++ "common": { ++ "origin": { ++ "jvmOrigin": { ++ "stackTrace": [{ ++ "classLoaderName": "app", ++ "declaringClass": "org.apache.spark.sql.functions$", ++ "methodName": "typedLit", ++ "fileName": "functions.scala" ++ }, { ++ "classLoaderName": "app", ++ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", ++ "methodName": "~~trimmed~anonfun~~", ++ "fileName": "PlanGenerationTestSuite.scala" ++ }] ++ } ++ } ++ } + }, { + "literal": { + "integer": 1 +@@ -706,6 +815,43 @@ + } + } + } ++ }, { ++ "literal": { ++ "array": { ++ "elementType": { ++ "integer": { ++ } ++ }, ++ "elements": [{ ++ "null": { ++ "integer": { ++ } ++ } ++ }, { ++ "null": { ++ "integer": { ++ } ++ } ++ }] ++ } ++ }, ++ "common": { ++ "origin": { ++ "jvmOrigin": { ++ "stackTrace": [{ ++ "classLoaderName": "app", ++ "declaringClass": "org.apache.spark.sql.functions$", ++ "methodName": 
"typedLit", ++ "fileName": "functions.scala" ++ }, { ++ "classLoaderName": "app", ++ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", ++ "methodName": "~~trimmed~anonfun~~", ++ "fileName": "PlanGenerationTestSuite.scala" ++ }] ++ } ++ } ++ } + }, { + "literal": { + "array": { +@@ -780,6 +926,53 @@ + } + } + } ++ }, { ++ "literal": { ++ "map": { ++ "keyType": { ++ "string": { ++ "collation": "UTF8_BINARY" ++ } ++ }, ++ "valueType": { ++ "integer": { ++ } ++ }, ++ "keys": [{ ++ "string": "a" ++ }, { ++ "string": "b" ++ }], ++ "values": [{ ++ "null": { ++ "integer": { ++ } ++ } ++ }, { ++ "null": { ++ "integer": { ++ } ++ } ++ }] ++ } ++ }, ++ "common": { ++ "origin": { ++ "jvmOrigin": { ++ "stackTrace": [{ ++ "classLoaderName": "app", ++ "declaringClass": "org.apache.spark.sql.functions$", ++ "methodName": "typedLit", ++ "fileName": "functions.scala" ++ }, { ++ "classLoaderName": "app", ++ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", ++ "methodName": "~~trimmed~anonfun~~", ++ "fileName": "PlanGenerationTestSuite.scala" ++ }] ++ } ++ } ++ } + }, { + "literal": { + "struct": { +diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin +index 38a6ce63005..749da55007d 100644 +Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ +diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml +index d4b98aaf26d..ab9470eeeef 100644 +--- a/sql/connect/server/pom.xml ++++ b/sql/connect/server/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../../pom.xml + + +diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +index 3a707495ff3..785b254d7af 100644 +--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala ++++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +@@ -263,7 +263,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( + timeoutNs = Math.min(progressTimeout * NANOS_PER_MILLIS, timeoutNs) + } + logTrace(s"Wait for response to become available with timeout=$timeoutNs ns.") +- executionObserver.responseLock.wait(timeoutNs / NANOS_PER_MILLIS) ++ executionObserver.responseLock.wait(Math.max(1, timeoutNs / NANOS_PER_MILLIS)) + enqueueProgressMessage(force = true) + logTrace(s"Reacquired executionObserver lock after waiting.") + sleepEnd = System.nanoTime() +@@ -384,7 +384,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( + val timeoutNs = Math.max(1, deadlineTimeNs - System.nanoTime()) + var sleepStart = System.nanoTime() + logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeoutNs ns.") +- grpcCallObserverReadySignal.wait(timeoutNs / NANOS_PER_MILLIS) ++ grpcCallObserverReadySignal.wait(Math.max(1, timeoutNs / NANOS_PER_MILLIS)) + logTrace(s"Reacquired grpcCallObserverReadySignal lock after waiting.") + sleepEnd = System.nanoTime() + } +diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +index bf1b6e7e00e..d5b81223707 100644 +--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala ++++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +@@ -32,7 +32,7 @@ import io.grpc.{Context, Status, StatusRuntimeException} + import io.grpc.stub.StreamObserver + import org.apache.commons.lang3.exception.ExceptionUtils + +-import org.apache.spark.{SparkEnv, TaskContext} ++import org.apache.spark.{SparkEnv, SparkException, TaskContext} + import org.apache.spark.annotation.{DeveloperApi, Since} + import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} + import org.apache.spark.connect.proto +@@ -44,7 +44,7 @@ import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase + import org.apache.spark.internal.{Logging, LogKeys, MDC} + import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} + import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} +-import org.apache.spark.sql.{Column, Encoders, ForeachWriter, Observation, Row} ++import org.apache.spark.sql.{AnalysisException, Column, Encoders, ForeachWriter, Observation, Row} + import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} + import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose} + import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder} +@@ -1091,9 +1091,20 @@ class SparkConnectPlanner( + // for backward compatibility + rel.getRenameColumnsMapMap.asScala.toSeq.unzip + } +- Project( +- Seq(UnresolvedStarWithColumnsRenames(existingNames = colNames, newNames = newColNames)), +- transformRelation(rel.getInput)) ++ ++ val child = transformRelation(rel.getInput) ++ try { ++ // Try the eager analysis first. ++ Dataset ++ .ofRows(session, child) ++ .withColumnsRenamed(colNames, newColNames) ++ .logicalPlan ++ } catch { ++ case _: AnalysisException | _: SparkException => ++ Project( ++ Seq(UnresolvedStarWithColumnsRenames(existingNames = colNames, newNames = newColNames)), ++ child) ++ } + } + + private def transformWithColumns(rel: proto.WithColumns): LogicalPlan = { +@@ -1113,13 +1124,23 @@ class SparkConnectPlanner( + (alias.getName(0), transformExpression(alias.getExpr), metadata) + }.unzip3 + +- Project( +- Seq( +- UnresolvedStarWithColumns( +- colNames = colNames, +- exprs = exprs, +- explicitMetadata = Some(metadata))), +- transformRelation(rel.getInput)) ++ val child = transformRelation(rel.getInput) ++ try { ++ // Try the eager analysis first. 
++ Dataset ++ .ofRows(session, child) ++ .withColumns(colNames, exprs.map(expr => Column(expr)), metadata) ++ .logicalPlan ++ } catch { ++ case _: AnalysisException | _: SparkException => ++ Project( ++ Seq( ++ UnresolvedStarWithColumns( ++ colNames = colNames, ++ exprs = exprs, ++ explicitMetadata = Some(metadata))), ++ child) ++ } + } + + private def transformWithWatermark(rel: proto.WithWatermark): LogicalPlan = { +diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +index 5e887256916..c6daa92e973 100644 +--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala ++++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +@@ -193,10 +193,11 @@ class SparkConnectServiceSuite + } + + override def onCompleted(): Unit = { ++ verifyEvents.onCompleted(Some(100)) + done = true + } + }) +- verifyEvents.onCompleted(Some(100)) ++ verifyEvents.assertClosed() + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + +@@ -294,10 +295,11 @@ class SparkConnectServiceSuite + } + + override def onCompleted(): Unit = { ++ verifyEvents.onCompleted(Some(6)) + done = true + } + }) +- verifyEvents.onCompleted(Some(6)) ++ verifyEvents.assertClosed() + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + +@@ -530,10 +532,11 @@ class SparkConnectServiceSuite + } + + override def onCompleted(): Unit = { ++ verifyEvents.onCompleted(producedNumRows) + done = true + } + }) +- verifyEvents.onCompleted(producedNumRows) ++ verifyEvents.assertClosed() + // The current implementation is expected to be blocking. + // This is here to make sure it is. 
+ assert(done) +@@ -621,7 +624,7 @@ class SparkConnectServiceSuite + } + }) + thread.join() +- verifyEvents.onCompleted() ++ verifyEvents.assertClosed() + } + } + +@@ -684,7 +687,7 @@ class SparkConnectServiceSuite + } + }) + assert(failures.isEmpty, s"this should have no failures but got $failures") +- verifyEvents.onCompleted() ++ verifyEvents.assertClosed() + } + } + +@@ -883,9 +886,6 @@ class SparkConnectServiceSuite + } + } + def onNext(v: proto.ExecutePlanResponse): Unit = { +- if (v.hasSchema) { +- assert(executeHolder.eventsManager.status == ExecuteStatus.Analyzed) +- } + if (v.hasMetrics) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Finished) + } +@@ -896,6 +896,8 @@ class SparkConnectServiceSuite + } + def onCompleted(producedRowCount: Option[Long] = None): Unit = { + assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) ++ } ++ def assertClosed(): Unit = { + // The eventsManager is closed asynchronously + Eventually.eventually(EVENT_WAIT_TIMEOUT) { + assert( +diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml +index 236d1624bfa..ad4d88bf293 100644 +--- a/sql/connect/shims/pom.xml ++++ b/sql/connect/shims/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../../pom.xml + + diff --git a/sql/core/pom.xml b/sql/core/pom.xml -index 6e73c154fcc..642d9b444e5 100644 +index dcf6223a98b..642d9b444e5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml +@@ -22,7 +22,7 @@ + + org.apache.spark + spark-parent_2.13 +- 4.0.1 ++ 4.0.3-SNAPSHOT + ../../pom.xml + + @@ -90,6 +90,10 @@ org.apache.spark spark-tags_${scala.binary.version} @@ -52,6 +4456,33 @@ index 6e73c154fcc..642d9b444e5 100644 - 10.16.1.1 - 1.15.2 -- 2.1.3 -+ 2.1.4 - shaded-protobuf - 11.0.24 - 5.0.0 @@ -148,6 +148,8 @@ 4.0.3 2.5.3 @@ -2209,29 +11,7 @@ index 22922143fc3..568e1f12f81 100644 org.apache.datasketches -@@ -3150,6 +3177,10 @@ - com.google.common - ${spark.shade.packageName}.guava - -+ -+ com.google.thirdparty -+ ${spark.shade.packageName}.guava.thirdparty -+ - - org.dmg.pmml - ${spark.shade.packageName}.dmg.pmml -diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala -index cded163e81f..c484fef8516 100644 ---- a/project/SparkBuild.scala -+++ b/project/SparkBuild.scala -@@ -364,7 +364,8 @@ object SparkBuild extends PomBuild { - /* Enable shared settings on all projects */ - (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ -- ExcludedDependencies.settings ++ Checkstyle.settings ++ ExcludeShims.settings)) -+ ExcludedDependencies.settings ++ (if (noLintOnCompile) Nil else Checkstyle.settings) ++ -+ ExcludeShims.settings)) - - /* Enable tests settings for all projects except examples, assembly and tools */ - (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) -@@ -1471,7 +1472,7 @@ object Unidoc { - ) ++ ( - // Add links to sources when generating Scaladoc for a non-snapshot release - if (!isSnapshot.value) { -- Opts.doc.sourceUrl(unidocSourceBase.value + "€{FILE_PATH}.scala") -+ Opts.doc.sourceUrl(unidocSourceBase.value + "€{FILE_PATH_EXT}") - } else { - Seq() - } -diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py -index da4d25cc908..661ba5a8a7e 100755 ---- a/python/packaging/classic/setup.py -+++ b/python/packaging/classic/setup.py -@@ -344,7 +344,7 @@ try: - license="http://www.apache.org/licenses/LICENSE-2.0", - # Don't forget 
to update python/docs/source/getting_started/install.rst - # if you're updating the versions or dependencies. -- install_requires=["py4j==0.10.9.9"], -+ install_requires=["py4j>=0.10.9.7,<0.10.9.10"], - extras_require={ - "ml": ["numpy>=%s" % _minimum_numpy_version], - "mllib": ["numpy>=%s" % _minimum_numpy_version], -diff --git a/python/pyspark/ml/connect/feature.py b/python/pyspark/ml/connect/feature.py -index a0e5b6a943d..e08b37337c6 100644 ---- a/python/pyspark/ml/connect/feature.py -+++ b/python/pyspark/ml/connect/feature.py -@@ -15,11 +15,11 @@ - # limitations under the License. - # - --import pickle - from typing import Any, Union, List, Tuple, Callable, Dict, Optional - - import numpy as np - import pandas as pd -+import pyarrow as pa - - from pyspark import keyword_only - from pyspark.sql import DataFrame -@@ -132,27 +132,29 @@ class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol, ParamsReadWrite, CoreM - return transform_fn - - def _get_core_model_filename(self) -> str: -- return self.__class__.__name__ + ".sklearn.pkl" -+ return self.__class__.__name__ + ".arrow.parquet" - - def _save_core_model(self, path: str) -> None: -- from sklearn.preprocessing import MaxAbsScaler as sk_MaxAbsScaler -- -- sk_model = sk_MaxAbsScaler() -- sk_model.scale_ = self.scale_values -- sk_model.max_abs_ = self.max_abs_values -- sk_model.n_features_in_ = len(self.max_abs_values) # type: ignore[arg-type] -- sk_model.n_samples_seen_ = self.n_samples_seen -- -- with open(path, "wb") as fp: -- pickle.dump(sk_model, fp) -+ import pyarrow.parquet as pq -+ -+ table = pa.Table.from_arrays( -+ [ -+ pa.array([self.scale_values], pa.list_(pa.float64())), -+ pa.array([self.max_abs_values], pa.list_(pa.float64())), -+ pa.array([self.n_samples_seen], pa.int64()), -+ ], -+ names=["scale", "max_abs", "n_samples"], -+ ) -+ pq.write_table(table, path) - - def _load_core_model(self, path: str) -> None: -- with open(path, "rb") as fp: -- sk_model = pickle.load(fp) -+ import pyarrow.parquet as pq -+ -+ table = pq.read_table(path) - -- self.max_abs_values = sk_model.max_abs_ -- self.scale_values = sk_model.scale_ -- self.n_samples_seen = sk_model.n_samples_seen_ -+ self.max_abs_values = np.array(table.column("scale")[0].as_py()) -+ self.scale_values = np.array(table.column("max_abs")[0].as_py()) -+ self.n_samples_seen = table.column("n_samples")[0].as_py() - - - class StandardScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite): -@@ -251,29 +253,31 @@ class StandardScalerModel(Model, HasInputCol, HasOutputCol, ParamsReadWrite, Cor - return transform_fn - - def _get_core_model_filename(self) -> str: -- return self.__class__.__name__ + ".sklearn.pkl" -+ return self.__class__.__name__ + ".arrow.parquet" - - def _save_core_model(self, path: str) -> None: -- from sklearn.preprocessing import StandardScaler as sk_StandardScaler -- -- sk_model = sk_StandardScaler(with_mean=True, with_std=True) -- sk_model.scale_ = self.scale_values -- sk_model.var_ = self.std_values * self.std_values # type: ignore[operator] -- sk_model.mean_ = self.mean_values -- sk_model.n_features_in_ = len(self.std_values) # type: ignore[arg-type] -- sk_model.n_samples_seen_ = self.n_samples_seen -- -- with open(path, "wb") as fp: -- pickle.dump(sk_model, fp) -+ import pyarrow.parquet as pq -+ -+ table = pa.Table.from_arrays( -+ [ -+ pa.array([self.scale_values], pa.list_(pa.float64())), -+ pa.array([self.mean_values], pa.list_(pa.float64())), -+ pa.array([self.std_values], pa.list_(pa.float64())), -+ pa.array([self.n_samples_seen], pa.int64()), 
-+ ], -+ names=["scale", "mean", "std", "n_samples"], -+ ) -+ pq.write_table(table, path) - - def _load_core_model(self, path: str) -> None: -- with open(path, "rb") as fp: -- sk_model = pickle.load(fp) -+ import pyarrow.parquet as pq -+ -+ table = pq.read_table(path) - -- self.std_values = np.sqrt(sk_model.var_) -- self.scale_values = sk_model.scale_ -- self.mean_values = sk_model.mean_ -- self.n_samples_seen = sk_model.n_samples_seen_ -+ self.scale_values = np.array(table.column("scale")[0].as_py()) -+ self.mean_values = np.array(table.column("mean")[0].as_py()) -+ self.std_values = np.array(table.column("std")[0].as_py()) -+ self.n_samples_seen = table.column("n_samples")[0].as_py() - - - class ArrayAssembler( -diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py -index 6812db77845..96f153b7b1b 100644 ---- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py -+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py -@@ -17,7 +17,6 @@ - # - - import os --import pickle - import tempfile - import unittest - -@@ -85,12 +84,6 @@ class FeatureTestsMixin: - np.testing.assert_allclose(model.max_abs_values, loaded_model.max_abs_values) - assert model.n_samples_seen == loaded_model.n_samples_seen - -- # Test loading core model as scikit-learn model -- with open(os.path.join(model_path, "MaxAbsScalerModel.sklearn.pkl"), "rb") as f: -- sk_model = pickle.load(f) -- sk_result = sk_model.transform(np.stack(list(local_df1.features))) -- np.testing.assert_allclose(sk_result, expected_result) -- - def test_standard_scaler(self): - df1 = self.spark.createDataFrame( - [ -@@ -141,12 +134,6 @@ class FeatureTestsMixin: - np.testing.assert_allclose(model.scale_values, loaded_model.scale_values) - assert model.n_samples_seen == loaded_model.n_samples_seen - -- # Test loading core model as scikit-learn model -- with open(os.path.join(model_path, "StandardScalerModel.sklearn.pkl"), "rb") as f: -- sk_model = pickle.load(f) -- sk_result = sk_model.transform(np.stack(list(local_df1.features))) -- np.testing.assert_allclose(sk_result, expected_result) -- - def test_array_assembler(self): - spark_df = self.spark.createDataFrame( - [ -diff --git a/python/pyspark/pandas/tests/io/test_feather.py b/python/pyspark/pandas/tests/io/test_feather.py -index 74fa6bc7d7b..10638d915c0 100644 ---- a/python/pyspark/pandas/tests/io/test_feather.py -+++ b/python/pyspark/pandas/tests/io/test_feather.py -@@ -17,8 +17,10 @@ - import unittest - - import pandas as pd -+import sys - - from pyspark import pandas as ps -+from pyspark.loose_version import LooseVersion - from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils - - -@@ -34,6 +36,16 @@ class FeatherMixin: - def psdf(self): - return ps.from_pandas(self.pdf) - -+ has_arrow_21_or_below = False -+ try: -+ import pyarrow as pa -+ -+ if LooseVersion(pa.__version__) < LooseVersion("22.0.0"): -+ has_arrow_21_or_below = True -+ except ImportError: -+ pass -+ -+ @unittest.skipIf(not has_arrow_21_or_below, "SPARK-54068") - def test_to_feather(self): - with self.temp_dir() as dirpath: - path1 = f"{dirpath}/file1.feather" -diff --git a/python/pyspark/pandas/tests/io/test_stata.py b/python/pyspark/pandas/tests/io/test_stata.py -index 6fe7cf13513..3cdf2cdb150 100644 ---- a/python/pyspark/pandas/tests/io/test_stata.py -+++ b/python/pyspark/pandas/tests/io/test_stata.py -@@ -14,6 +14,7 @@ - # See the License for the specific language governing permissions and - # limitations under 
the License. - # -+import os - import unittest - - import pandas as pd -@@ -33,6 +34,9 @@ class StataMixin: - def psdf(self): - return ps.from_pandas(self.pdf) - -+ @unittest.skipIf( -+ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54486: To be reenabled" -+ ) - def test_to_feather(self): - with self.temp_dir() as dirpath: - path1 = f"{dirpath}/file1.dta" -diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py -index cac9aaf193a..afed59660d7 100644 ---- a/python/pyspark/pandas/tests/test_typedef.py -+++ b/python/pyspark/pandas/tests/test_typedef.py -@@ -15,6 +15,7 @@ - # limitations under the License. - # - -+import os - import sys - import unittest - import datetime -@@ -313,7 +314,6 @@ class TypeHintTestsMixin: - def test_as_spark_type_pandas_on_spark_dtype(self): - type_mapper = { - # binary -- np.character: (np.character, BinaryType()), - np.bytes_: (np.bytes_, BinaryType()), - bytes: (np.bytes_, BinaryType()), - # integer -@@ -348,6 +348,10 @@ class TypeHintTestsMixin: - ), - } - -+ if LooseVersion(np.__version__) < LooseVersion("2.3"): -+ # binary -+ type_mapper.update({np.character: (np.character, BinaryType())}) -+ - for numpy_or_python_type, (dtype, spark_type) in type_mapper.items(): - self.assertEqual(as_spark_type(numpy_or_python_type), spark_type) - self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type)) -diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py -index 48545d124b2..a4ed9f996fe 100644 ---- a/python/pyspark/pandas/typedef/typehints.py -+++ b/python/pyspark/pandas/typedef/typehints.py -@@ -342,7 +342,7 @@ def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.Dat - try: - dtype = pandas_dtype(tpe) - spark_type = as_spark_type(dtype) -- except TypeError: -+ except (TypeError, ValueError): - spark_type = as_spark_type(tpe) - dtype = spark_type_to_pandas_dtype(spark_type) - return dtype, spark_type -diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py -index bf6d60df635..952258e8db4 100644 ---- a/python/pyspark/sql/connect/window.py -+++ b/python/pyspark/sql/connect/window.py -@@ -18,7 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies - - check_dependencies(__name__) - --from typing import TYPE_CHECKING, Union, Sequence, List, Optional, Tuple, cast, Iterable -+from typing import TYPE_CHECKING, Any, Union, Sequence, List, Optional, Tuple, cast, Iterable - - from pyspark.sql.column import Column - from pyspark.sql.window import ( -@@ -69,6 +69,9 @@ class WindowSpec(ParentWindowSpec): - self.__init__(partitionSpec, orderSpec, frame) # type: ignore[misc] - return self - -+ def __getnewargs__(self) -> Tuple[Any, ...]: -+ return (self._partitionSpec, self._orderSpec, self._frame) -+ - def __init__( - self, - partitionSpec: Sequence[Expression], -diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py -index cd06b3fa3ee..a3f8bc7a0f0 100644 ---- a/python/pyspark/sql/dataframe.py -+++ b/python/pyspark/sql/dataframe.py -@@ -852,7 +852,6 @@ class DataFrame: - - Notes - ----- -- - Unlike `count()`, this method does not trigger any computation. - - An empty DataFrame has no rows. It may have columns, but no data. 
- - Examples -diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py -index d2f9f0957e0..45ca818d7ae 100644 ---- a/python/pyspark/sql/streaming/query.py -+++ b/python/pyspark/sql/streaming/query.py -@@ -283,7 +283,10 @@ class StreamingQuery: - - >>> sq.stop() - """ -- return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()] -+ return [ -+ StreamingQueryProgress.fromJson(json.loads(p.json())) -+ for p in self._jsq.recentProgress() -+ ] - - @property - def lastProgress(self) -> Optional[StreamingQueryProgress]: -@@ -314,7 +317,7 @@ class StreamingQuery: - """ - lastProgress = self._jsq.lastProgress() - if lastProgress: -- return StreamingQueryProgress.fromJObject(lastProgress) -+ return StreamingQueryProgress.fromJson(json.loads(lastProgress.json())) - else: - return None - -diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py -index f0637056ab8..bf51c0839f6 100755 ---- a/python/pyspark/sql/tests/connect/test_connect_basic.py -+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py -@@ -145,6 +145,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): - cdf2 = loads(data) - self.assertEqual(cdf.collect(), cdf2.collect()) - -+ def test_window_spec_serialization(self): -+ from pyspark.sql.connect.window import Window -+ from pyspark.serializers import CPickleSerializer -+ -+ pickle_ser = CPickleSerializer() -+ w = Window.partitionBy("some_string").orderBy("value") -+ b = pickle_ser.dumps(w) -+ w2 = pickle_ser.loads(b) -+ self.assertEqual(str(w), str(w2)) -+ - def test_df_getattr_behavior(self): - cdf = self.connect.range(10) - sdf = self.spark.range(10) -diff --git a/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py b/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py -index c6ef9810c68..c3b50341bbd 100644 ---- a/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py -+++ b/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py -@@ -19,7 +19,10 @@ import os - import unittest - - from pyspark.tests.test_memory_profiler import MemoryProfiler2TestsMixin, _do_computation --from pyspark.testing.connectutils import ReusedConnectTestCase -+from pyspark.testing.connectutils import ( -+ ReusedConnectTestCase, -+ skip_if_server_version_is_greater_than_or_equal_to, -+) - - - class MemoryProfilerParityTests(MemoryProfiler2TestsMixin, ReusedConnectTestCase): -@@ -27,6 +30,14 @@ class MemoryProfilerParityTests(MemoryProfiler2TestsMixin, ReusedConnectTestCase - super().setUp() - self.spark._profiler_collector._value = None - -+ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") -+ def test_memory_profiler_pandas_udf_iterator_not_supported(self): -+ super().test_memory_profiler_pandas_udf_iterator_not_supported() -+ -+ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") -+ def test_memory_profiler_map_in_pandas_not_supported(self): -+ super().test_memory_profiler_map_in_pandas_not_supported() -+ - - class MemoryProfilerWithoutPlanCacheParityTests(MemoryProfilerParityTests): - @classmethod -diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py -index 5c46130c5b5..11bc4ef8384 100644 ---- a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py -+++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py -@@ -22,7 +22,10 @@ from pyspark.sql.tests.test_udf_profiler import ( - UDFProfiler2TestsMixin, - 
_do_computation, - ) --from pyspark.testing.connectutils import ReusedConnectTestCase -+from pyspark.testing.connectutils import ( -+ ReusedConnectTestCase, -+ skip_if_server_version_is_greater_than_or_equal_to, -+) - from pyspark.testing.utils import have_flameprof - - -@@ -31,6 +34,14 @@ class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase): - super().setUp() - self.spark._profiler_collector._value = None - -+ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") -+ def test_perf_profiler_pandas_udf_iterator_not_supported(self): -+ super().test_perf_profiler_pandas_udf_iterator_not_supported() -+ -+ @skip_if_server_version_is_greater_than_or_equal_to("4.1.0") -+ def test_perf_profiler_map_in_pandas_not_supported(self): -+ super().test_perf_profiler_map_in_pandas_not_supported() -+ - - class UDFProfilerWithoutPlanCacheParityTests(UDFProfilerParityTests): - @classmethod -diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py -index 1f953235267..3a6ab9c98eb 100644 ---- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py -+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py -@@ -262,7 +262,7 @@ class CogroupedApplyInPandasTestsMixin: - "`spark.sql.execution.pandas.convertToArrowArraySafely`." - ) - self._test_merge_error( -- fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}), -+ fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["test_string"]}), - output_schema="id long, k double", - errorClass=PythonException, - error_message_regex=expected, -diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py -index 4ef334549ef..d60e31d8879 100644 ---- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py -+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py -@@ -17,6 +17,7 @@ - - import datetime - import unittest -+import os - - from collections import OrderedDict - from decimal import Decimal -@@ -288,28 +289,20 @@ class GroupedApplyInPandasTestsMixin: - ): - self._test_apply_in_pandas(lambda key, pdf: key) - -- @staticmethod -- def stats_with_column_names(key, pdf): -- # order of column can be different to applyInPandas schema when column names are given -- return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) -- -- @staticmethod -- def stats_with_no_column_names(key, pdf): -- # columns must be in order of applyInPandas schema when no columns given -- return pd.DataFrame([key + (pdf.v.mean(),)]) -- - def test_apply_in_pandas_returning_column_names(self): -- self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_column_names) -+ self._test_apply_in_pandas( -+ lambda key, pdf: pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) -+ ) - - def test_apply_in_pandas_returning_no_column_names(self): -- self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_no_column_names) -+ self._test_apply_in_pandas(lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)])) - - def test_apply_in_pandas_returning_column_names_sometimes(self): - def stats(key, pdf): - if key[0] % 2: -- return GroupedApplyInPandasTestsMixin.stats_with_column_names(key, pdf) -+ return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"]) - else: -- return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf) -+ return pd.DataFrame([key + (pdf.v.mean(),)]) - - self._test_apply_in_pandas(stats) - -@@ -343,9 +336,15 @@ class 
GroupedApplyInPandasTestsMixin: - lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())]) - ) - -+ @unittest.skipIf( -+ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled" -+ ) - def test_apply_in_pandas_returning_empty_dataframe(self): - self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame()) - -+ @unittest.skipIf( -+ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled" -+ ) - def test_apply_in_pandas_returning_incompatible_type(self): - with self.quiet(): - self.check_apply_in_pandas_returning_incompatible_type() -@@ -846,7 +845,7 @@ class GroupedApplyInPandasTestsMixin: - - def stats(key, pdf): - if key[0] % 2 == 0: -- return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf) -+ return pd.DataFrame([key + (pdf.v.mean(),)]) - return empty_df - - result = ( -diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py -index 692f9705411..e5d0b56be69 100644 ---- a/python/pyspark/sql/tests/pandas/test_pandas_map.py -+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py -@@ -251,16 +251,17 @@ class MapInPandasTestsMixin: - self.check_dataframes_with_incompatible_types() - - def check_dataframes_with_incompatible_types(self): -- def func(iterator): -- for pdf in iterator: -- yield pdf.assign(id=pdf["id"].apply(str)) -- - for safely in [True, False]: - with self.subTest(convertToArrowArraySafely=safely), self.sql_conf( - {"spark.sql.execution.pandas.convertToArrowArraySafely": safely} - ): - # sometimes we see ValueErrors - with self.subTest(convert="string to double"): -+ -+ def func(iterator): -+ for pdf in iterator: -+ yield pdf.assign(id="test_string") -+ - expected = ( - r"ValueError: Exception thrown when converting pandas.Series " - r"\(object\) with name 'id' to Arrow Array \(double\)." -@@ -279,18 +280,31 @@ class MapInPandasTestsMixin: - .collect() - ) - -- # sometimes we see TypeErrors -- with self.subTest(convert="double to string"): -- with self.assertRaisesRegex( -- PythonException, -- r"TypeError: Exception thrown when converting pandas.Series " -- r"\(float64\) with name 'id' to Arrow Array \(string\).\n", -- ): -- ( -- self.spark.range(10, numPartitions=3) -- .select(col("id").cast("double")) -- .mapInPandas(self.identity_dataframes_iter("id"), "id string") -- .collect() -+ with self.subTest(convert="float to int precision loss"): -+ -+ def func(iterator): -+ for pdf in iterator: -+ yield pdf.assign(id=pdf["id"] + 0.1) -+ -+ df = ( -+ self.spark.range(10, numPartitions=3) -+ .select(col("id").cast("double")) -+ .mapInPandas(func, "id int") -+ ) -+ if safely: -+ expected = ( -+ r"ValueError: Exception thrown when converting pandas.Series " -+ r"\(float64\) with name 'id' to Arrow Array \(int32\)." -+ " It can be caused by overflows or other " -+ "unsafe conversions warned by Arrow. Arrow safe type check " -+ "can be disabled by using SQL config " -+ "`spark.sql.execution.pandas.convertToArrowArraySafely`." 
-+ ) -+ with self.assertRaisesRegex(PythonException, expected + "\n"): -+ df.collect() -+ else: -+ self.assertEqual( -+ df.collect(), self.spark.range(10, numPartitions=3).collect() - ) - - def test_empty_iterator(self): -diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py -index fe027875880..ae62124153c 100644 ---- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py -+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py -@@ -1601,6 +1601,49 @@ class TransformWithStateInPandasTestsMixin: - check_exception=check_exception, - ) - -+ def test_transform_with_state_in_pandas_large_values(self): -+ """Test large state values (512KB) to validate readFully fix for SPARK-53870""" -+ -+ def check_results(batch_df, batch_id): -+ batch_df.collect() -+ target_size_bytes = 512 * 1024 -+ large_string = "a" * target_size_bytes -+ expected_list_elements = ",".join( -+ [large_string, large_string + "b", large_string + "c"] -+ ) -+ expected_map_result = f"large_string_key:{large_string}" -+ -+ assert set(batch_df.sort("id").collect()) == { -+ Row( -+ id="0", -+ valueStateResult=large_string, -+ listStateResult=expected_list_elements, -+ mapStateResult=expected_map_result, -+ ), -+ Row( -+ id="1", -+ valueStateResult=large_string, -+ listStateResult=expected_list_elements, -+ mapStateResult=expected_map_result, -+ ), -+ } -+ -+ output_schema = StructType( -+ [ -+ StructField("id", StringType(), True), -+ StructField("valueStateResult", StringType(), True), -+ StructField("listStateResult", StringType(), True), -+ StructField("mapStateResult", StringType(), True), -+ ] -+ ) -+ -+ self._test_transform_with_state_in_pandas_basic( -+ PandasLargeValueStatefulProcessor(), -+ check_results, -+ single_batch=True, -+ output_schema=output_schema, -+ ) -+ - - class SimpleStatefulProcessorWithInitialState(StatefulProcessor): - # this dict is the same as input initial state dataframe -@@ -2374,6 +2417,46 @@ class PandasStatefulProcessorCompositeType(StatefulProcessor): - pass - - -+class PandasLargeValueStatefulProcessor(StatefulProcessor): -+ """Test processor for large state values (512KB) to validate readFully fix""" -+ -+ def init(self, handle: StatefulProcessorHandle): -+ value_state_schema = StructType([StructField("value", StringType(), True)]) -+ self.value_state = handle.getValueState("valueState", value_state_schema) -+ -+ list_state_schema = StructType([StructField("value", StringType(), True)]) -+ self.list_state = handle.getListState("listState", list_state_schema) -+ -+ self.map_state = handle.getMapState("mapState", "key string", "value string") -+ -+ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: -+ target_size_bytes = 512 * 1024 -+ large_string = "a" * target_size_bytes -+ -+ self.value_state.update((large_string,)) -+ value_retrieved = self.value_state.get()[0] -+ -+ self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)]) -+ list_retrieved = list(self.list_state.get()) -+ list_elements = ",".join([elem[0] for elem in list_retrieved]) -+ -+ map_key = ("large_string_key",) -+ self.map_state.updateValue(map_key, (large_string,)) -+ map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}" -+ -+ yield pd.DataFrame( -+ { -+ "id": key, -+ "valueStateResult": [value_retrieved], -+ "listStateResult": [list_elements], -+ "mapStateResult": [map_retrieved], -+ } -+ ) -+ -+ def close(self) -> None: -+ pass -+ -+ 
- class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): - pass - -diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py -index 423a717e8ab..b4573d5fb5c 100644 ---- a/python/pyspark/testing/connectutils.py -+++ b/python/pyspark/testing/connectutils.py -@@ -16,12 +16,12 @@ - # - import shutil - import tempfile --import typing - import os - import functools - import unittest - import uuid - import contextlib -+from typing import Callable, Optional - - from pyspark.testing import ( - grpc_requirement_message, -@@ -36,6 +36,7 @@ from pyspark.testing import ( - should_test_connect, - ) - from pyspark import Row, SparkConf -+from pyspark.loose_version import LooseVersion - from pyspark.util import is_remote_only - from pyspark.testing.utils import PySparkErrorTestUtils - from pyspark.testing.sqlutils import ( -@@ -197,3 +198,28 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti - return QuietTest(self._legacy_sc) - else: - return contextlib.nullcontext() -+ -+ -+def skip_if_server_version_is( -+ cond: Callable[[LooseVersion], bool], reason: Optional[str] = None -+) -> Callable: -+ def decorator(f: Callable) -> Callable: -+ @functools.wraps(f) -+ def wrapper(self, *args, **kwargs): -+ version = self.spark.version -+ if cond(LooseVersion(version)): -+ raise unittest.SkipTest( -+ f"Skipping test {f.__name__} because server version is {version}" -+ + (f" ({reason})" if reason else "") -+ ) -+ return f(self, *args, **kwargs) -+ -+ return wrapper -+ -+ return decorator -+ -+ -+def skip_if_server_version_is_greater_than_or_equal_to( -+ version: str, reason: Optional[str] = None -+) -> Callable: -+ return skip_if_server_version_is(lambda v: v >= LooseVersion(version), reason) -diff --git a/python/pyspark/version.py b/python/pyspark/version.py -index bfcc501ff93..41148c646f7 100644 ---- a/python/pyspark/version.py -+++ b/python/pyspark/version.py -@@ -16,4 +16,4 @@ - # See the License for the specific language governing permissions and - # limitations under the License. 
- --__version__: str = "4.0.1" -+__version__: str = "4.0.3.dev0" -diff --git a/repl/pom.xml b/repl/pom.xml -index 02ed999e9b9..8f962239689 100644 ---- a/repl/pom.xml -+++ b/repl/pom.xml -@@ -21,7 +21,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../pom.xml - - -diff --git a/repl/src/test/resources/IntSumUdf.class b/repl/src/test/resources/IntSumUdf.class -new file mode 100644 -index 00000000000..75a41446cfc -Binary files /dev/null and b/repl/src/test/resources/IntSumUdf.class differ -diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml -index f3bace4ec6a..19f19273f6b 100644 ---- a/resource-managers/kubernetes/core/pom.xml -+++ b/resource-managers/kubernetes/core/pom.xml -@@ -20,7 +20,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../../pom.xml - - -diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml -index 5c31a10641b..ce77018ff85 100644 ---- a/resource-managers/kubernetes/integration-tests/pom.xml -+++ b/resource-managers/kubernetes/integration-tests/pom.xml -@@ -20,7 +20,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../../pom.xml - - -diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml -index 8a9437a04f6..581762e4bef 100644 ---- a/resource-managers/yarn/pom.xml -+++ b/resource-managers/yarn/pom.xml -@@ -20,7 +20,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../pom.xml - - -diff --git a/sql/api/pom.xml b/sql/api/pom.xml -index 09d458bdc5a..db17f3a5f5d 100644 ---- a/sql/api/pom.xml -+++ b/sql/api/pom.xml -@@ -22,7 +22,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../pom.xml - - -diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala -index 0f219725523..b90d9f8013d 100644 ---- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala -+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala -@@ -55,7 +55,7 @@ object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) { - * server (driver & executors) very tricky. As a workaround a user can define their own Codec - * which internalizes the Kryo configuration. 
- */ --object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) { -+object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) with Serializable { - private lazy val kryoCodecConstructor: MethodHandle = { - val cls = SparkClassUtils.classForName( - "org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl") -diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala -index dd8ca26c524..044100c9226 100644 ---- a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala -+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala -@@ -93,7 +93,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa - case _ => false - } - -- override def catalogString: String = sqlType.simpleString -+ override def catalogString: String = sqlType.catalogString - } - - private[spark] object UserDefinedType { -diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml -index 3b3e2a07b0c..bfc482e581c 100644 ---- a/sql/catalyst/pom.xml -+++ b/sql/catalyst/pom.xml -@@ -22,7 +22,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../pom.xml - - -diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java -index 47662dc97cc..268fa577b29 100644 ---- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java -+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SupportsTriggerAvailableNow.java -@@ -36,6 +36,13 @@ public interface SupportsTriggerAvailableNow extends SupportsAdmissionControl { - * the query). The source will behave as if there is no new data coming in after the target - * offset, i.e., the source will not return an offset higher than the target offset when - * {@link #latestOffset(Offset, ReadLimit) latestOffset} is called. -+ *
-+ * Note that there is an exception on the first uncommitted batch after a restart, where the end -+ * offset is not derived from the current latest offset. Sources need to take special -+ * considerations if wanting to assert such relation. One possible way is to have an internal -+ * flag in the source to indicate whether it is Trigger.AvailableNow, set the flag in this method, -+ * and record the target offset in the first call of -+ * {@link #latestOffset(Offset, ReadLimit) latestOffset}. - */ - void prepareForTriggerAvailableNow(); - } -diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java -index ac05981da5a..b14cd3429e4 100644 ---- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java -+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java -@@ -164,6 +164,7 @@ public final class ColumnarRow extends InternalRow { - - @Override - public Object get(int ordinal, DataType dataType) { -+ if (isNullAt(ordinal)) return null; - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala -index 492ea741236..9dcaba8c2bc 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala -@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.{expressions => exprs} - import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} - import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec} - import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder} --import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder} -+import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder} - import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} - import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} - import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils} -@@ -241,19 +241,12 @@ object DeserializerBuildHelper { - val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName) - // Assumes we are deserializing the first column of a row. 
- val input = GetColumnByOrdinal(0, enc.dataType) -- enc match { -- case AgnosticEncoders.RowEncoder(fields) => -- val children = fields.zipWithIndex.map { case (f, i) => -- createDeserializer(f.enc, GetStructField(input, i), walkedTypePath) -- } -- CreateExternalRow(children, enc.schema) -- case _ => -- val deserializer = createDeserializer( -- enc, -- upCastToExpectedType(input, enc.dataType, walkedTypePath), -- walkedTypePath) -- expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) -- } -+ val deserializer = createDeserializer( -+ enc, -+ upCastToExpectedType(input, enc.dataType, walkedTypePath), -+ walkedTypePath, -+ isTopLevel = true) -+ expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) - } - - /** -@@ -265,11 +258,13 @@ object DeserializerBuildHelper { - * external representation. - * @param path The expression which can be used to extract serialized value. - * @param walkedTypePath The paths from top to bottom to access current field when deserializing. -+ * @param isTopLevel true if we are creating a deserializer for the top level value. - */ - private def createDeserializer( - enc: AgnosticEncoder[_], - path: Expression, -- walkedTypePath: WalkedTypePath): Expression = enc match { -+ walkedTypePath: WalkedTypePath, -+ isTopLevel: Boolean = false): Expression = enc match { - case ae: AgnosticExpressionPathEncoder[_] => - ae.fromCatalyst(path) - case _ if isNativeEncoder(enc) => -@@ -408,13 +403,12 @@ object DeserializerBuildHelper { - NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter)) - - case AgnosticEncoders.RowEncoder(fields) => -- val isExternalRow = !path.dataType.isInstanceOf[StructType] - val convertedFields = fields.zipWithIndex.map { case (f, i) => - val newTypePath = walkedTypePath.recordField( - f.enc.clsTag.runtimeClass.getName, - f.name) - val deserializer = createDeserializer(f.enc, GetStructField(path, i), newTypePath) -- if (isExternalRow) { -+ if (!isTopLevel) { - exprs.If( - Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil), - exprs.Literal.create(null, externalDataTypeFor(f.enc)), -@@ -459,8 +453,8 @@ object DeserializerBuildHelper { - Invoke( - Literal.create(provider(), ObjectType(classOf[Codec[_, _]])), - "decode", -- ObjectType(tag.runtimeClass), -- createDeserializer(encoder, path, walkedTypePath) :: Nil) -+ dataTypeForClass(tag.runtimeClass), -+ createDeserializer(encoder, path, walkedTypePath, isTopLevel) :: Nil) - } - - private def deserializeArray( -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala -index 5c4e9d4bddc..b568722c38a 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala -@@ -756,7 +756,7 @@ object CatalogTable { - props.get(key).orElse { - if (props.exists { case (mapKey, _) => mapKey.startsWith(key) }) { - props.get(s"$key.numParts") match { -- case None => throw QueryCompilationErrors.insufficientTablePropertyError(key) -+ case None => None - case Some(numParts) => - val parts = (0 until numParts.toInt).map { index => - val keyPart = s"$key.part.$index" -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala -index 8f717795605..16d5adb064d 100644 ---- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala -@@ -152,6 +152,12 @@ object EncoderUtils { - VariantType -> classOf[VariantVal] - ) - -+ def dataTypeForClass(c: Class[_]): DataType = -+ javaClassToPrimitiveType.get(c).getOrElse(ObjectType(c)) -+ -+ private val javaClassToPrimitiveType: Map[Class[_], DataType] = -+ typeJavaMapping.iterator.filter(_._2.isPrimitive).map(_.swap).toMap -+ - val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( - BooleanType -> classOf[java.lang.Boolean], - ByteType -> classOf[java.lang.Byte], -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala -index 784bea899c4..e3ff7c5f05f 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala -@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch - import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType} - import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} - import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -+import org.apache.spark.sql.catalyst.optimizer.ScalarSubqueryReference - import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE - import org.apache.spark.sql.types._ - import org.apache.spark.util.sketch.BloomFilter -@@ -58,6 +59,7 @@ case class BloomFilterMightContain( - case GetStructField(subquery: PlanExpression[_], _, _) - if !subquery.containsPattern(OUTER_REFERENCE) => - TypeCheckResult.TypeCheckSuccess -+ case _: ScalarSubqueryReference => TypeCheckResult.TypeCheckSuccess - case _ => - DataTypeMismatch( - errorSubClass = "BLOOM_FILTER_BINARY_OP_WRONG_TYPE", -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala -index cbc8a8f273e..d3165e3a3e6 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala -@@ -328,7 +328,8 @@ case class HllUnionAgg( - union.update(sketch) - Some(union) - } catch { -- case _: SketchesArgumentException | _: java.lang.Error => -+ case _: SketchesArgumentException | _: java.lang.Error -+ | _: ArrayIndexOutOfBoundsException => - throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) - } - case _ => -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala -index a4ac0bdbb11..1880d71e7d5 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala -@@ -56,7 +56,8 @@ case class HllSketchEstimate(child: Expression) - try { - 
Math.round(HllSketch.heapify(Memory.wrap(buffer)).getEstimate) - } catch { -- case _: SketchesArgumentException | _: java.lang.Error => -+ case _: SketchesArgumentException | _: java.lang.Error -+ | _: ArrayIndexOutOfBoundsException => - throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) - } - } -@@ -108,13 +109,15 @@ case class HllUnion(first: Expression, second: Expression, third: Expression) - val sketch1 = try { - HllSketch.heapify(Memory.wrap(value1.asInstanceOf[Array[Byte]])) - } catch { -- case _: SketchesArgumentException | _: java.lang.Error => -+ case _: SketchesArgumentException | _: java.lang.Error -+ | _: ArrayIndexOutOfBoundsException => - throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) - } - val sketch2 = try { - HllSketch.heapify(Memory.wrap(value2.asInstanceOf[Array[Byte]])) - } catch { -- case _: SketchesArgumentException | _: java.lang.Error => -+ case _: SketchesArgumentException | _: java.lang.Error -+ | _: ArrayIndexOutOfBoundsException => - throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName) - } - val allowDifferentLgConfigK = value3.asInstanceOf[Boolean] -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala -index 9db2ac7f9b0..0f74389a9a5 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala -@@ -1562,7 +1562,7 @@ abstract class RoundBase(child: Expression, scale: Expression, - val decimal = input1.asInstanceOf[Decimal] - if (_scale >= 0) { - // Overflow cannot happen, so no need to control nullOnOverflow -- decimal.toPrecision(decimal.precision, s, mode) -+ decimal.toPrecision(p, s, mode) - } else { - Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s) - } -@@ -1634,10 +1634,9 @@ abstract class RoundBase(child: Expression, scale: Expression, - case DecimalType.Fixed(p, s) => - if (_scale >= 0) { - s""" -- ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, -- Decimal.$modeStr(), true, null); -+ ${ev.value} = ${ce.value}.toPrecision($p, $s, Decimal.$modeStr(), true, null); - ${ev.isNull} = ${ev.value} == null;""" -- } else { -+ } else { - s""" - ${ev.value} = new Decimal().set(${ce.value}.toBigDecimal() - .setScale(${_scale}, Decimal.$modeStr()), $p, $s); -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala -index 46815969e7e..d36a71b0439 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala -@@ -26,12 +26,29 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, WINDOW} - * Inserts a `WindowGroupLimit` below `Window` if the `Window` has rank-like functions - * and the function results are further filtered by limit-like predicates. 
Example query: - * {{{ -- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn = 5 -- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 = rn -- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn < 5 -- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 > rn -- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn <= 5 -- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 >= rn -+ * SELECT * FROM ( -+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 -+ * ) WHERE rn = 5; -+ * -+ * SELECT * FROM ( -+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 -+ * ) WHERE 5 = rn; -+ * -+ * SELECT * FROM ( -+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 -+ * ) WHERE rn < 5; -+ * -+ * SELECT * FROM ( -+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 -+ * ) WHERE 5 > rn; -+ * -+ * SELECT * FROM ( -+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 -+ * ) WHERE rn <= 5; -+ * -+ * SELECT * FROM ( -+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 -+ * ) WHERE 5 >= rn; - * }}} - */ - object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper { -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala -index aa972c81559..7a8deb10f1a 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala -@@ -357,6 +357,15 @@ abstract class Optimizer(catalogManager: CatalogManager) - case other => other - } - } -+ -+ private def optimizeSubquery(s: SubqueryExpression): SubqueryExpression = { -+ val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s)) -+ // At this point we have an optimized subquery plan that we are going to attach -+ // to this subquery expression. Here we can safely remove any top level sort -+ // in the plan as tuples produced by a subquery are un-ordered. -+ s.withNewPlan(removeTopLevelSort(newPlan)) -+ } -+ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( - _.containsPattern(PLAN_EXPRESSION), ruleId) { - // Do not optimize DPP subquery, as it was created from optimized plan and we should not -@@ -411,12 +420,23 @@ abstract class Optimizer(catalogManager: CatalogManager) - s.withNewPlan( - if (needTopLevelProject) newPlan else newPlan.child - ) -+ case s: Exists => -+ // For an EXISTS join, the subquery might be written as "SELECT * FROM ...". -+ // If we optimize the subquery directly, column pruning may not be applied -+ // effectively. To address this, we add an extra Project node that selects -+ // only the columns referenced in the EXISTS join condition. -+ // This ensures that column pruning can be performed correctly -+ // during subquery optimization. 
-+ val selectedRefrences = -+ s.plan.output.filter(s.joinCond.flatMap(_.references).contains) -+ val newPlan = if (selectedRefrences.nonEmpty) { -+ s.withNewPlan(Project(selectedRefrences, s.plan)) -+ } else { -+ s -+ } -+ optimizeSubquery(newPlan) - case s: SubqueryExpression => -- val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s)) -- // At this point we have an optimized subquery plan that we are going to attach -- // to this subquery expression. Here we can safely remove any top level sort -- // in the plan as tuples produced by a subquery are un-ordered. -- s.withNewPlan(removeTopLevelSort(newPlan)) -+ optimizeSubquery(s) - } - } - -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala -index f8c1b2a9014..94d69fa2179 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala -@@ -250,7 +250,7 @@ case class ReplaceData( - write: Option[Write] = None) extends RowLevelWrite { - - override val isByName: Boolean = false -- override val stringArgs: Iterator[Any] = Iterator(table, query, write) -+ override def stringArgs: Iterator[Any] = Iterator(table, query, write) - - override lazy val references: AttributeSet = query.outputSet - -@@ -332,7 +332,7 @@ case class WriteDelta( - write: Option[DeltaWrite] = None) extends RowLevelWrite { - - override val isByName: Boolean = false -- override val stringArgs: Iterator[Any] = Iterator(table, query, write) -+ override def stringArgs: Iterator[Any] = Iterator(table, query, write) - - override lazy val references: AttributeSet = query.outputSet - -@@ -1654,12 +1654,19 @@ case class Call( - } - - override def simpleString(maxFields: Int): String = { -- val name = procedure match { -+ procedure match { - case ResolvedProcedure(catalog, ident, _) => -- s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" -+ val name = s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" -+ simpleString(name, maxFields) - case UnresolvedProcedure(nameParts) => -- nameParts.quoted -+ val name = nameParts.quoted -+ simpleString(name, maxFields) -+ case _ => -+ super.simpleString(maxFields) - } -+ } -+ -+ private def simpleString(name: String, maxFields: Int): String = { - val argsString = truncatedString(args, ", ", maxFields) - s"Call $name($argsString)" - } -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala -index 038105f9bfd..dc66b6f30e5 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala -@@ -899,10 +899,13 @@ case class KeyGroupedShuffleSpec( - } - - override def createPartitioning(clustering: Seq[Expression]): Partitioning = { -- val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map { -- case (c, e: TransformExpression) => TransformExpression( -- e.function, Seq(c), e.numBucketsOpt) -- case (c, _) => c -+ assert(clustering.size == distribution.clustering.size, -+ "Required distributions of join legs should be the same size.") -+ -+ val newExpressions = partitioning.expressions.zip(keyPositions).map { -+ case (te: TransformExpression, positionSet) => 
-+ te.copy(children = te.children.map(_ => clustering(positionSet.head))) -+ case (_, positionSet) => clustering(positionSet.head) - } - KeyGroupedPartitioning(newExpressions, - partitioning.numPartitions, -diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala -index b24ad30e071..72a8c8539bd 100644 ---- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala -+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala -@@ -18,6 +18,7 @@ - package org.apache.spark.sql.catalyst.util - - import scala.collection.mutable.ArrayBuffer -+import scala.util.{Failure, Success, Try} - - import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} - import org.apache.spark.internal.{Logging, MDC} -@@ -368,27 +369,33 @@ object ResolveDefaultColumns extends QueryErrorsBase - val defaultSQL = field.metadata.getString(EXISTS_DEFAULT_COLUMN_METADATA_KEY) - - // Parse the expression. -- val expr = Literal.fromSQL(defaultSQL) match { -- // EXISTS_DEFAULT will have a cast from analyze() due to coerceDefaultValue -- // hence we need to add timezone to the cast if necessary -- case c: Cast if c.child.resolved && c.needsTimeZone => -- c.withTimeZone(SQLConf.get.sessionLocalTimeZone) -- case e: Expression => e -- } -+ val resolvedExpr = Try(Literal.fromSQL(defaultSQL)) match { -+ case Success(literal) => -+ val expr = literal match { -+ // EXISTS_DEFAULT will have a cast from analyze() due to coerceDefaultValue -+ // hence we need to add timezone to the cast if necessary -+ case c: Cast if c.child.resolved && c.needsTimeZone => -+ c.withTimeZone(SQLConf.get.sessionLocalTimeZone) -+ case e: Expression => e -+ } - -- // Check invariants -- if (expr.containsPattern(PLAN_EXPRESSION)) { -- throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions( -- "", field.name, defaultSQL) -- } -+ // Check invariants -+ if (expr.containsPattern(PLAN_EXPRESSION)) { -+ throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions( -+ "", field.name, defaultSQL) -+ } -+ -+ expr match { -+ case _: ExprLiteral => expr -+ case c: Cast if c.resolved => expr -+ case _ => -+ fallbackResolveExistenceDefaultValue(field) -+ } - -- val resolvedExpr = expr match { -- case _: ExprLiteral => expr -- case c: Cast if c.resolved => expr -- case _ => -+ case Failure(_) => -+ // If Literal.fromSQL fails, use fallback resolution - fallbackResolveExistenceDefaultValue(field) - } -- - coerceDefaultValue(resolvedExpr, field.dataType, "", field.name, defaultSQL) - } - -diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala -index 616c6d65636..0d26b390643 100644 ---- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala -+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala -@@ -612,6 +612,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes - provider, - nullable = true)) - .resolveAndBind() -+ assert(encoder.isInstanceOf[Serializable]) - assert(encoder.schema == new StructType().add("value", BinaryType)) - val toRow = encoder.createSerializer() - val fromRow = encoder.createDeserializer() -@@ 
-659,6 +660,22 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes - assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x"))) - } - -+ test("SPARK-52614: transforming encoder row encoder in product encoder") { -+ val schema = new StructType().add("a", LongType).add("b", StringType) -+ val wrapperEncoder = TransformingEncoder( -+ classTag[Wrapper[Row]], -+ RowEncoder.encoderFor(schema), -+ new WrapperCodecProvider[Row]) -+ val encoder = ExpressionEncoder(ProductEncoder( -+ classTag[V[Wrapper[Row]]], -+ Seq(EncoderField("v", wrapperEncoder, nullable = false, Metadata.empty)), -+ None)) -+ .resolveAndBind() -+ val toRow = encoder.createSerializer() -+ val fromRow = encoder.createDeserializer() -+ assert(fromRow(toRow(V(new Wrapper(Row(9L, "x"))))) == V(new Wrapper(Row(9L, "x")))) -+ } -+ - // below tests are related to SPARK-49960 and TransformingEncoder usage - test("""Encoder with OptionEncoder of transformation""".stripMargin) { - type T = Option[V[V[Int]]] -@@ -749,6 +766,24 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes - testDataTransformingEnc(enc, data) - } - -+ test("SPARK-52601 TransformingEncoder from primitive to timestamp") { -+ val enc: AgnosticEncoder[Long] = -+ TransformingEncoder[Long, java.sql.Timestamp]( -+ classTag, -+ TimestampEncoder(true), -+ () => -+ new Codec[Long, java.sql.Timestamp] with Serializable { -+ override def encode(in: Long): Timestamp = Timestamp.from(microsToInstant(in)) -+ override def decode(out: Timestamp): Long = instantToMicros(out.toInstant) -+ } -+ ) -+ val data: Seq[Long] = Seq(0L, 1L, 2L) -+ -+ assert(enc.dataType === TimestampType) -+ -+ testDataTransformingEnc(enc, data) -+ } -+ - val longEncForTimestamp: AgnosticEncoder[V[Long]] = - TransformingEncoder[V[Long], java.sql.Timestamp]( - classTag, -diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala -index 0841702cc51..0f7f5ca54be 100644 ---- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala -+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala -@@ -108,4 +108,49 @@ class DatasketchesHllSketchSuite extends SparkFunSuite { - - assert(HllSketch.heapify(Memory.wrap(binary3.asInstanceOf[Array[Byte]])).getLgConfigK == 12) - } -+ -+ test("HllUnionAgg throws proper error for invalid binary input causing ArrayIndexOutOfBounds") { -+ val aggFunc = new HllUnionAgg(BoundReference(0, BinaryType, nullable = true), true) -+ val union = aggFunc.createAggregationBuffer() -+ -+ // Craft a byte array that passes initial size checks but has an invalid CurMode ordinal. -+ // HLL preamble layout: -+ // Byte 0: preInts (preamble size in ints) -+ // Byte 1: serVer (must be 1) -+ // Byte 2: famId (must be 7 for HLL) -+ // Byte 3: lgK (4-21) -+ // Byte 5: flags -+ // Byte 7: modeByte - bits 0-1 contain curMode ordinal (0=LIST, 1=SET, 2=HLL) -+ // -+ // Setting bits 0-1 of byte 7 to 0b11 (=3) causes CurMode.fromOrdinal(3) to throw -+ // ArrayIndexOutOfBoundsException since CurMode only has ordinals 0, 1, 2. -+ // This happens in PreambleUtil.extractCurMode() before other validations run. 
-+ val invalidBinary = Array[Byte]( -+ 2, // byte 0: preInts = 2 (LIST_PREINTS, passes check) -+ 1, // byte 1: serVer = 1 (valid) -+ 7, // byte 2: famId = 7 (HLL family) -+ 12, // byte 3: lgK = 12 (valid range 4-21) -+ 0, // byte 4: unused -+ 0, // byte 5: flags = 0 -+ 0, // byte 6: unused -+ 3 // byte 7: modeByte with bits 0-1 = 0b11 = 3 (INVALID curMode ordinal!) -+ ) -+ -+ val exception = intercept[Exception] { -+ aggFunc.update(union, InternalRow(invalidBinary)) -+ } -+ -+ // Verify that ArrayIndexOutOfBoundsException is properly caught and converted -+ // to the user-friendly HLL_INVALID_INPUT_SKETCH_BUFFER error -+ assert( -+ !exception.isInstanceOf[ArrayIndexOutOfBoundsException], -+ s"ArrayIndexOutOfBoundsException should be caught and converted to " + -+ s"HLL_INVALID_INPUT_SKETCH_BUFFER error, but got: ${exception.getClass.getName}" -+ ) -+ assert( -+ exception.getMessage.contains("HLL_INVALID_INPUT_SKETCH_BUFFER"), -+ s"Expected HLL_INVALID_INPUT_SKETCH_BUFFER error, " + -+ s"but got: ${exception.getClass.getName}: ${exception.getMessage}" -+ ) -+ } - } -diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala -index 5dd45d3d449..42579f6cc6e 100644 ---- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala -+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala -@@ -856,6 +856,13 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { - "CAST(CURRENT_TIMESTAMP AS BIGINT)") - .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, - "CAST(CURRENT_TIMESTAMP AS BIGINT)") -+ .build()), -+ StructField("c3", StringType, true, -+ new MetadataBuilder() -+ .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, -+ "CONCAT(YEAR(CURRENT_DATE), LPAD(WEEKOFYEAR(CURRENT_DATE), 2, '0'))") -+ .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, -+ "CONCAT(YEAR(CURRENT_DATE), LPAD(WEEKOFYEAR(CURRENT_DATE), 2, '0'))") - .build()))) - val res = ResolveDefaultColumns.existenceDefaultValues(source) - assert(res(0) == null) -@@ -864,5 +871,9 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { - val res2Wrapper = new LongWrapper - assert(res(2).asInstanceOf[UTF8String].toLong(res2Wrapper)) - assert(res2Wrapper.value > 0) -+ -+ val res3Wrapper = new LongWrapper -+ assert(res(3).asInstanceOf[UTF8String].toLong(res3Wrapper)) -+ assert(res3Wrapper.value > 0) - } - } -diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala -index 04b090d7001..2f58e722c05 100644 ---- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala -+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala -@@ -17,6 +17,7 @@ - - package org.apache.spark.sql.types - -+import org.apache.spark.sql.Row - import org.apache.spark.sql.catalyst.InternalRow - import org.apache.spark.sql.catalyst.expressions.GenericInternalRow - import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} -@@ -132,3 +133,22 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] - - override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] - } -+ -+ -+class ExampleIntRowUDT(cols: Int) extends UserDefinedType[Row] { -+ override def sqlType: DataType = { -+ StructType((0 until cols).map(i => -+ StructField(s"col$i", IntegerType, nullable = false))) -+ } -+ 
-+ override def serialize(obj: Row): InternalRow = { -+ InternalRow.fromSeq(obj.toSeq) -+ } -+ -+ override def deserialize(datum: Any): Row = { -+ val internalRow = datum.asInstanceOf[InternalRow] -+ Row.fromSeq(internalRow.toSeq(sqlType.asInstanceOf[StructType])) -+ } -+ -+ override def userClass: Class[Row] = classOf[Row] -+} -diff --git a/sql/connect/client/jvm/pom.xml b/sql/connect/client/jvm/pom.xml -index 3de1cf368f8..bd586e86adc 100644 ---- a/sql/connect/client/jvm/pom.xml -+++ b/sql/connect/client/jvm/pom.xml -@@ -22,7 +22,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../../../pom.xml - - -diff --git a/sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar b/sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar -new file mode 100644 -index 00000000000..6dee8fcd9c9 -Binary files /dev/null and b/sql/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar differ -diff --git a/sql/connect/client/jvm/src/test/resources/udf2.13.jar b/sql/connect/client/jvm/src/test/resources/udf2.13.jar -new file mode 100644 -index 00000000000..c89830f127c -Binary files /dev/null and b/sql/connect/client/jvm/src/test/resources/udf2.13.jar differ -diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala -index a548ec7007d..e19f1eacfd8 100644 ---- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala -+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala -@@ -3390,12 +3390,24 @@ class PlanGenerationTestSuite - fn.typedLit(java.time.Duration.ofSeconds(200L)), - fn.typedLit(java.time.Period.ofDays(100)), - fn.typedLit(new CalendarInterval(2, 20, 100L)), -+ fn.typedLit( -+ ( -+ java.time.LocalDate.of(2020, 10, 10), -+ java.time.Instant.ofEpochMilli(1677155519808L), -+ new java.sql.Timestamp(12345L), -+ java.time.LocalDateTime.of(2023, 2, 23, 20, 36), -+ java.sql.Date.valueOf("2023-02-23"), -+ java.time.Duration.ofSeconds(200L), -+ java.time.Period.ofDays(100), -+ new CalendarInterval(2, 20, 100L))), - - // Handle parameterized scala types e.g.: List, Seq and Map. 
- fn.typedLit(Some(1)), - fn.typedLit(Array(1, 2, 3)), -+ fn.typedLit[Array[Integer]](Array(null, null)), - fn.typedLit(Seq(1, 2, 3)), -- fn.typedLit(Map("a" -> 1, "b" -> 2)), -+ fn.typedLit(mutable.LinkedHashMap("a" -> 1, "b" -> 2)), -+ fn.typedLit(mutable.LinkedHashMap[String, Integer]("a" -> null, "b" -> null)), - fn.typedLit(("a", 2, 1.0)), - fn.typedLit[Option[Int]](None), - fn.typedLit[Array[Option[Int]]](Array(Some(1))), -diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala -index 3b6dd090caf..afc2b1db023 100644 ---- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala -+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala -@@ -1681,6 +1681,13 @@ class ClientE2ETestSuite - assert(df.count() == 100) - } - } -+ -+ test("SPARK-53553: null value handling in literals") { -+ val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col")) -+ val result = df.collect() -+ assert(result.length === 1) -+ assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null)) -+ } - } - - private[sql] case class ClassData(a: String, b: Int) -diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala -index 1d022489b70..4c0073cad56 100644 ---- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala -+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala -@@ -16,7 +16,7 @@ - */ - package org.apache.spark.sql.connect - --import java.util.concurrent.ForkJoinPool -+import java.util.concurrent.Executors - - import scala.collection.mutable - import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} -@@ -146,7 +146,7 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession { - // global ExecutionContext has only 2 threads in Apache Spark CI - // create own thread pool for four Futures used in this test - val numThreads = 4 -- val fpool = new ForkJoinPool(numThreads) -+ val fpool = Executors.newFixedThreadPool(numThreads) - val executionContext = ExecutionContext.fromExecutorService(fpool) - - val q1 = Future { -diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala -index cbaa4f5ea07..8afa28b1f38 100644 ---- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala -+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala -@@ -234,6 +234,8 @@ object CheckConnectJvmClientCompatibility { - "org.apache.spark.sql.artifact.ArtifactManager$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"), -+ ProblemFilters.exclude[MissingClassProblem]( -+ "org.apache.spark.sql.artifact.ArtifactManager$StateCleanupRunner"), - - // ColumnNode conversions - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession"), -diff --git a/sql/connect/common/pom.xml b/sql/connect/common/pom.xml -index 
1966bf4b303..58441cde7b3 100644 ---- a/sql/connect/common/pom.xml -+++ b/sql/connect/common/pom.xml -@@ -22,7 +22,7 @@ - - org.apache.spark - spark-parent_2.13 -- 4.0.1 -+ 4.0.3-SNAPSHOT - ../../../pom.xml - - -diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala -index 1f3496fa898..d64f5d7cdf2 100644 ---- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala -+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala -@@ -163,6 +163,14 @@ object LiteralValueProtoConverter { - } - - (literal, dataType) match { -+ case (v: Option[_], _: DataType) => -+ if (v.isDefined) { -+ toLiteralProtoBuilder(v.get) -+ } else { -+ builder.setNull(toConnectProtoType(dataType)) -+ } -+ case (null, _) => -+ builder.setNull(toConnectProtoType(dataType)) - case (v: mutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilder(v.array, dataType) - case (v: immutable.ArraySeq[_], ArrayType(_, _)) => -@@ -175,12 +183,6 @@ object LiteralValueProtoConverter { - builder.setMap(mapBuilder(v, keyType, valueType)) - case (v, structType: StructType) => - builder.setStruct(structBuilder(v, structType)) -- case (v: Option[_], _: DataType) => -- if (v.isDefined) { -- toLiteralProtoBuilder(v.get) -- } else { -- builder.setNull(toConnectProtoType(dataType)) -- } - case _ => toLiteralProtoBuilder(literal) - } - } -@@ -296,8 +298,8 @@ object LiteralValueProtoConverter { - } - } - -- private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { -- if (dataType.hasShort) { v => -+ private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { -+ val converter: proto.Expression.Literal => Any = if (dataType.hasShort) { v => - v.getShort.toShort - } else if (dataType.hasInteger) { v => - v.getInteger -@@ -316,15 +318,15 @@ object LiteralValueProtoConverter { - } else if (dataType.hasBinary) { v => - v.getBinary.toByteArray - } else if (dataType.hasDate) { v => -- v.getDate -+ SparkDateTimeUtils.toJavaDate(v.getDate) - } else if (dataType.hasTimestamp) { v => -- v.getTimestamp -+ SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) - } else if (dataType.hasTimestampNtz) { v => -- v.getTimestampNtz -+ SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) - } else if (dataType.hasDayTimeInterval) { v => -- v.getDayTimeInterval -+ SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) - } else if (dataType.hasYearMonthInterval) { v => -- v.getYearMonthInterval -+ SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) - } else if (dataType.hasDecimal) { v => - Decimal(v.getDecimal.getValue) - } else if (dataType.hasCalendarInterval) { v => -@@ -339,6 +341,7 @@ object LiteralValueProtoConverter { - } else { - throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)") - } -+ v => if (v.hasNull) null else converter(v) - } - - def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = { -@@ -354,7 +357,7 @@ object LiteralValueProtoConverter { - builder.result() - } - -- makeArrayData(getConverter(array.getElementType)) -+ makeArrayData(getScalaConverter(array.getElementType)) - } - - def toCatalystMap(map: proto.Expression.Literal.Map): mutable.Map[_, _] = { -@@ -373,7 +376,7 @@ object LiteralValueProtoConverter { - builder - } - -- makeMapData(getConverter(map.getKeyType), 
getConverter(map.getValueType)) -+ makeMapData(getScalaConverter(map.getKeyType), getScalaConverter(map.getValueType)) - } - - def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = { -@@ -392,7 +395,7 @@ object LiteralValueProtoConverter { - val structData = elements - .zip(dataTypes) - .map { case (element, dataType) => -- getConverter(dataType)(element) -+ getScalaConverter(dataType)(element) - } - .asInstanceOf[scala.collection.Seq[Object]] - .toSeq -diff --git a/sql/connect/common/src/test/resources/artifact-tests/Hello.class b/sql/connect/common/src/test/resources/artifact-tests/Hello.class -new file mode 100644 -index 00000000000..56725764de2 -Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/Hello.class differ -diff --git a/sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar b/sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar -new file mode 100755 -index 00000000000..6da55d8b852 -Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/junitLargeJar.jar differ -diff --git a/sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class b/sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class -new file mode 100755 -index 00000000000..e796030e471 -Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/smallClassFile.class differ -diff --git a/sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class b/sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class -new file mode 100755 -index 00000000000..e796030e471 -Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/smallClassFileDup.class differ -diff --git a/sql/connect/common/src/test/resources/artifact-tests/smallJar.jar b/sql/connect/common/src/test/resources/artifact-tests/smallJar.jar -new file mode 100755 -index 00000000000..3c4930e8e95 -Binary files /dev/null and b/sql/connect/common/src/test/resources/artifact-tests/smallJar.jar differ -diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain -index 6d854da250f..a566430136f 100644 ---- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain -+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain -@@ -1,2 +1,2 @@ --Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 
18 more fields]
-+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 21 more fields]
- +- LocalRelation <empty>, [id#0L, a#0, b#0]
-diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
-index e56b6e1f3ee..456033244a9 100644
---- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
-+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
-@@ -77,7 +77,8 @@
-     }, {
-       "literal": {
-         "null": {
--          "null": {
-+          "string": {
-+            "collation": "UTF8_BINARY"
-           }
-         }
-       },
-@@ -652,6 +653,114 @@
-           }
-         }
-       }
-+    }, {
-+      "literal": {
-+        "struct": {
-+          "structType": {
-+            "struct": {
-+              "fields": [{
-+                "name": "_1",
-+                "dataType": {
-+                  "date": {
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_2",
-+                "dataType": {
-+                  "timestamp": {
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_3",
-+                "dataType": {
-+                  "timestamp": {
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_4",
-+                "dataType": {
-+                  "timestampNtz": {
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_5",
-+                "dataType": {
-+                  "date": {
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_6",
-+                "dataType": {
-+                  "dayTimeInterval": {
-+                    "startField": 0,
-+                    "endField": 3
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_7",
-+                "dataType": {
-+                  "yearMonthInterval": {
-+                    "startField": 0,
-+                    "endField": 1
-+                  }
-+                },
-+                "nullable": true
-+              }, {
-+                "name": "_8",
-+                "dataType": {
-+                  "calendarInterval": {
-+                  }
-+                },
-+                "nullable": true
-+              }]
-+            }
-+          },
-+          "elements": [{
-+            "date": 18545
-+          }, {
-+            "timestamp": "1677155519808000"
-+          }, {
-+            "timestamp": "12345000"
-+          }, {
-+            "timestampNtz": "1677184560000000"
-+          }, {
-+            "date": 19411
-+          }, {
-+            "dayTimeInterval": "200000000"
-+          }, {
-+            "yearMonthInterval": 0
-+          }, {
-+            "calendarInterval": {
-+              "months": 2,
-+              "days": 20,
-+              "microseconds": "100"
-+            }
-+          }]
-+        }
-+      },
-+      "common": {
-+        "origin": {
-+          "jvmOrigin": {
-+            "stackTrace": [{
-+              "classLoaderName": "app",
-+              "declaringClass": "org.apache.spark.sql.functions$",
-+              "methodName": "typedLit",
-+              "fileName": "functions.scala"
-+            }, {
-+              "classLoaderName": "app",
-+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
-+              "methodName": "~~trimmed~anonfun~~",
-+              "fileName": "PlanGenerationTestSuite.scala"
-+            }]
-+          }
-+        }
-+      }
-     }, {
-       "literal": {
-         "integer": 1
-@@ -706,6 +815,43 @@
-           }
-         }
-       }
-+    }, {
-+      "literal": {
-+        "array": {
-+          "elementType": {
-+            "integer": {
-+            }
-+          },
-+          "elements": [{
-+            "null": {
-+              "integer": {
-+              }
-+            }
-+          }, {
-+            "null": {
-+              "integer": {
-+              }
-+            }
-+          }]
-+        }
-+      },
-+      "common": {
-+        "origin": {
-+          "jvmOrigin": {
-+            "stackTrace": [{
-+              "classLoaderName": "app",
-+              "declaringClass": "org.apache.spark.sql.functions$",
-+              "methodName": "typedLit",
-+              "fileName": "functions.scala"
-+            }, {
-+              "classLoaderName": "app",
-+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
-+              "methodName": "~~trimmed~anonfun~~",
-+              "fileName": "PlanGenerationTestSuite.scala"
-+            }]
-+          }
-+        }
-+      }
-     }, {
-       "literal": {
-         "array": {
-@@ -780,6 +926,53 @@
-           }
-         }
-       }
-+    }, {
-+      "literal": {
-+        "map": {
-+          "keyType": {
-+            "string": {
-+              "collation": "UTF8_BINARY"
-+            }
-+          },
-+          "valueType": {
-+            "integer": {
-+            }
-+          },
-+          "keys": [{
-+            "string": "a"
-+          }, {
-+            "string": "b"
-+          }],
-+          "values": [{
-+            "null": {
-+              "integer": {
-+              }
-+            }
-+          }, {
-+            "null": {
-+              "integer": {
-+              }
-+            }
-+          }]
-+        }
-+      },
-+      "common": {
-+        "origin": {
-+          "jvmOrigin": {
-+            "stackTrace": [{
-+              "classLoaderName": "app",
-+              "declaringClass": "org.apache.spark.sql.functions$",
-+              "methodName": "typedLit",
-+              "fileName": "functions.scala"
-+            }, {
-+              "classLoaderName": "app",
-+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
-+              "methodName": "~~trimmed~anonfun~~",
-+              "fileName": "PlanGenerationTestSuite.scala"
-+            }]
-+          }
-+        }
-+      }
-     }, {
-       "literal": {
-         "struct": {
-diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
-index 38a6ce63005..749da55007d 100644
-Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ
-diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml
-index d4b98aaf26d..ab9470eeeef 100644
---- a/sql/connect/server/pom.xml
-+++ b/sql/connect/server/pom.xml
-@@ -22,7 +22,7 @@
-   <parent>
-     <groupId>org.apache.spark</groupId>
-     <artifactId>spark-parent_2.13</artifactId>
--    <version>4.0.1</version>
-+    <version>4.0.3-SNAPSHOT</version>
-     <relativePath>../../../pom.xml</relativePath>
-   </parent>
- 
-diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
-index 3a707495ff3..785b254d7af 100644
---- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
-+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
-@@ -263,7 +263,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
-           timeoutNs = Math.min(progressTimeout * NANOS_PER_MILLIS, timeoutNs)
-         }
-         logTrace(s"Wait for response to become available with timeout=$timeoutNs ns.")
--        executionObserver.responseLock.wait(timeoutNs / NANOS_PER_MILLIS)
-+        executionObserver.responseLock.wait(Math.max(1, timeoutNs / NANOS_PER_MILLIS))
-         enqueueProgressMessage(force = true)
-         logTrace(s"Reacquired executionObserver lock after waiting.")
-         sleepEnd = System.nanoTime()
-@@ -384,7 +384,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
-         val timeoutNs = Math.max(1, deadlineTimeNs - System.nanoTime())
-         var sleepStart = System.nanoTime()
-         logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeoutNs ns.")
--        grpcCallObserverReadySignal.wait(timeoutNs / NANOS_PER_MILLIS)
-+        grpcCallObserverReadySignal.wait(Math.max(1, timeoutNs / NANOS_PER_MILLIS))
-         logTrace(s"Reacquired grpcCallObserverReadySignal lock after waiting.")
-         sleepEnd = System.nanoTime()
-       }
-diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
-index bf1b6e7e00e..d5b81223707 100644
---- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
-+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
-@@ -32,7 +32,7 @@ import io.grpc.{Context, Status, StatusRuntimeException}
- import io.grpc.stub.StreamObserver
- import org.apache.commons.lang3.exception.ExceptionUtils
- 
--import org.apache.spark.{SparkEnv, TaskContext}
-+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
- import org.apache.spark.annotation.{DeveloperApi, Since}
- import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
- import org.apache.spark.connect.proto
-@@ -44,7 +44,7 @@ import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
- import org.apache.spark.internal.{Logging, LogKeys, MDC}
- import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
- import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
--import org.apache.spark.sql.{Column, Encoders, ForeachWriter, Observation, Row}
-+import org.apache.spark.sql.{AnalysisException, Column, Encoders, ForeachWriter, Observation, Row}
- import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
- import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
- import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
-@@ -1091,9 +1091,20 @@ class SparkConnectPlanner(
-       // for backward compatibility
-       rel.getRenameColumnsMapMap.asScala.toSeq.unzip
-     }
--    Project(
--      Seq(UnresolvedStarWithColumnsRenames(existingNames = colNames, newNames = newColNames)),
--      transformRelation(rel.getInput))
-+
-+    val child = transformRelation(rel.getInput)
-+    try {
-+      // Try the eager analysis first.
-+      Dataset
-+        .ofRows(session, child)
-+        .withColumnsRenamed(colNames, newColNames)
-+        .logicalPlan
-+    } catch {
-+      case _: AnalysisException | _: SparkException =>
-+        Project(
-+          Seq(UnresolvedStarWithColumnsRenames(existingNames = colNames, newNames = newColNames)),
-+          child)
-+    }
-   }
- 
-   private def transformWithColumns(rel: proto.WithColumns): LogicalPlan = {
-@@ -1113,13 +1124,23 @@ class SparkConnectPlanner(
-       (alias.getName(0), transformExpression(alias.getExpr), metadata)
-     }.unzip3
- 
--    Project(
--      Seq(
--        UnresolvedStarWithColumns(
--          colNames = colNames,
--          exprs = exprs,
--          explicitMetadata = Some(metadata))),
--      transformRelation(rel.getInput))
-+    val child = transformRelation(rel.getInput)
-+    try {
-+      // Try the eager analysis first.
-+      Dataset
-+        .ofRows(session, child)
-+        .withColumns(colNames, exprs.map(expr => Column(expr)), metadata)
-+        .logicalPlan
-+    } catch {
-+      case _: AnalysisException | _: SparkException =>
-+        Project(
-+          Seq(
-+            UnresolvedStarWithColumns(
-+              colNames = colNames,
-+              exprs = exprs,
-+              explicitMetadata = Some(metadata))),
-+          child)
-+    }
-   }
- 
-   private def transformWithWatermark(rel: proto.WithWatermark): LogicalPlan = {
-diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
-index 5e887256916..c6daa92e973 100644
---- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
-+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
-@@ -193,10 +193,11 @@ class SparkConnectServiceSuite
-         }
- 
-         override def onCompleted(): Unit = {
-+          verifyEvents.onCompleted(Some(100))
-           done = true
-         }
-       })
--      verifyEvents.onCompleted(Some(100))
-+      verifyEvents.assertClosed()
-       // The current implementation is expected to be blocking. This is here to make sure it is.
-       assert(done)
- 
-@@ -294,10 +295,11 @@ class SparkConnectServiceSuite
-         }
- 
-         override def onCompleted(): Unit = {
-+          verifyEvents.onCompleted(Some(6))
-           done = true
-         }
-       })
--      verifyEvents.onCompleted(Some(6))
-+      verifyEvents.assertClosed()
-       // The current implementation is expected to be blocking. This is here to make sure it is.
-       assert(done)
- 
-@@ -530,10 +532,11 @@ class SparkConnectServiceSuite
-         }
- 
-         override def onCompleted(): Unit = {
-+          verifyEvents.onCompleted(producedNumRows)
-           done = true
-         }
-       })
--      verifyEvents.onCompleted(producedNumRows)
-+      verifyEvents.assertClosed()
-       // The current implementation is expected to be blocking.
-       // This is here to make sure it is.
-       assert(done)
-@@ -621,7 +624,7 @@ class SparkConnectServiceSuite
-         }
-       })
-       thread.join()
--      verifyEvents.onCompleted()
-+      verifyEvents.assertClosed()
-     }
-   }
- 
-@@ -684,7 +687,7 @@ class SparkConnectServiceSuite
-         }
-       })
-       assert(failures.isEmpty, s"this should have no failures but got $failures")
--      verifyEvents.onCompleted()
-+      verifyEvents.assertClosed()
-     }
-   }
- 
-@@ -883,9 +886,6 @@ class SparkConnectServiceSuite
-       }
-     }
-     def onNext(v: proto.ExecutePlanResponse): Unit = {
--      if (v.hasSchema) {
--        assert(executeHolder.eventsManager.status == ExecuteStatus.Analyzed)
--      }
-       if (v.hasMetrics) {
-         assert(executeHolder.eventsManager.status == ExecuteStatus.Finished)
-       }
-@@ -896,6 +896,8 @@ class SparkConnectServiceSuite
-     }
-     def onCompleted(producedRowCount: Option[Long] = None): Unit = {
-       assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
-+    }
-+    def assertClosed(): Unit = {
-       // The eventsManager is closed asynchronously
-       Eventually.eventually(EVENT_WAIT_TIMEOUT) {
-         assert(
-diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml
-index 236d1624bfa..ad4d88bf293 100644
---- a/sql/connect/shims/pom.xml
-+++ b/sql/connect/shims/pom.xml
-@@ -22,7 +22,7 @@
-   <parent>
-     <groupId>org.apache.spark</groupId>
-     <artifactId>spark-parent_2.13</artifactId>
--    <version>4.0.1</version>
-+    <version>4.0.3-SNAPSHOT</version>
-     <relativePath>../../../pom.xml</relativePath>
-   </parent>
- 
 diff --git a/sql/core/pom.xml b/sql/core/pom.xml
-index dcf6223a98b..642d9b444e5 100644
+index dcf6223a98b..0458a5bb640 100644
 --- a/sql/core/pom.xml
 +++ b/sql/core/pom.xml
-@@ -22,7 +22,7 @@
-   <parent>
-     <groupId>org.apache.spark</groupId>
-     <artifactId>spark-parent_2.13</artifactId>
--    <version>4.0.1</version>
-+    <version>4.0.3-SNAPSHOT</version>
-     <relativePath>../../pom.xml</relativePath>
-   </parent>
- 
 @@ -90,6 +90,10 @@
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-tags_${scala.binary.version}</artifactId>