diff --git a/native/core/src/parquet/cast_column.rs b/native/core/src/parquet/cast_column.rs index a44166a70b..8c62735562 100644 --- a/native/core/src/parquet/cast_column.rs +++ b/native/core/src/parquet/cast_column.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. use arrow::{ - array::{make_array, ArrayRef, TimestampMicrosecondArray, TimestampMillisecondArray}, + array::{ + make_array, Array, ArrayRef, LargeListArray, ListArray, MapArray, StructArray, + TimestampMicrosecondArray, TimestampMillisecondArray, + }, compute::CastOptions, datatypes::{DataType, FieldRef, Schema, TimeUnit}, record_batch::RecordBatch, @@ -78,13 +81,66 @@ fn relabel_array(array: ArrayRef, target_type: &DataType) -> ArrayRef { if array.data_type() == target_type { return array; } - let data = array.to_data(); - let new_data = data - .into_builder() - .data_type(target_type.clone()) - .build() - .expect("relabel_array: data layout must be compatible"); - make_array(new_data) + match target_type { + DataType::List(target_field) => { + let list = array.as_any().downcast_ref::().unwrap(); + let values = relabel_array(Arc::clone(list.values()), target_field.data_type()); + Arc::new(ListArray::new( + Arc::clone(target_field), + list.offsets().clone(), + values, + list.nulls().cloned(), + )) + } + DataType::LargeList(target_field) => { + let list = array.as_any().downcast_ref::().unwrap(); + let values = relabel_array(Arc::clone(list.values()), target_field.data_type()); + Arc::new(LargeListArray::new( + Arc::clone(target_field), + list.offsets().clone(), + values, + list.nulls().cloned(), + )) + } + DataType::Map(target_entries_field, sorted) => { + let map = array.as_any().downcast_ref::().unwrap(); + let entries = relabel_array( + Arc::new(map.entries().clone()), + target_entries_field.data_type(), + ); + let entries_struct = entries.as_any().downcast_ref::().unwrap(); + Arc::new(MapArray::new( + Arc::clone(target_entries_field), + map.offsets().clone(), + entries_struct.clone(), + map.nulls().cloned(), + *sorted, + )) + } + DataType::Struct(target_fields) => { + let struct_arr = array.as_any().downcast_ref::().unwrap(); + let columns: Vec = target_fields + .iter() + .zip(struct_arr.columns()) + .map(|(tf, col)| relabel_array(Arc::clone(col), tf.data_type())) + .collect(); + Arc::new(StructArray::new( + target_fields.clone(), + columns, + struct_arr.nulls().cloned(), + )) + } + // Primitive types - shallow swap is safe + _ => { + let data = array.to_data(); + let new_data = data + .into_builder() + .data_type(target_type.clone()) + .build() + .expect("relabel_array: data layout must be compatible"); + make_array(new_data) + } + } } /// Casts a Timestamp(Microsecond) array to Timestamp(Millisecond) by dividing values by 1000. @@ -300,8 +356,8 @@ impl PhysicalExpr for CometCastColumnExpr { #[cfg(test)] mod tests { use super::*; - use arrow::array::Array; - use arrow::datatypes::Field; + use arrow::array::{Array, Int32Array, StringArray}; + use arrow::datatypes::{Field, Fields}; use datafusion::physical_expr::expressions::Column; #[test] @@ -455,4 +511,129 @@ mod tests { _ => panic!("Expected Scalar result"), } } + + #[test] + fn test_relabel_list_field_name() { + // Physical: List(Field("item", Int32)) + // Logical: List(Field("element", Int32)) + let physical_field = Arc::new(Field::new("item", DataType::Int32, true)); + let logical_field = Arc::new(Field::new("element", DataType::Int32, true)); + + let values = Int32Array::from(vec![1, 2, 3]); + let list = ListArray::new( + physical_field, + arrow::buffer::OffsetBuffer::new(vec![0, 2, 3].into()), + Arc::new(values), + None, + ); + let array: ArrayRef = Arc::new(list); + + let target_type = DataType::List(logical_field.clone()); + let result = relabel_array(array, &target_type); + assert_eq!(result.data_type(), &target_type); + } + + #[test] + fn test_relabel_map_entries_field_name() { + // Physical: Map(Field("key_value", Struct{key, value})) + // Logical: Map(Field("entries", Struct{key, value})) + let key_field = Arc::new(Field::new("key", DataType::Utf8, false)); + let value_field = Arc::new(Field::new("value", DataType::Int32, true)); + let struct_fields = Fields::from(vec![key_field.clone(), value_field.clone()]); + + let physical_entries_field = Arc::new(Field::new( + "key_value", + DataType::Struct(struct_fields.clone()), + false, + )); + let logical_entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(struct_fields.clone()), + false, + )); + + let keys = StringArray::from(vec!["a", "b"]); + let values = Int32Array::from(vec![1, 2]); + let entries = StructArray::new(struct_fields, vec![Arc::new(keys), Arc::new(values)], None); + let map = MapArray::new( + physical_entries_field, + arrow::buffer::OffsetBuffer::new(vec![0, 2].into()), + entries, + None, + false, + ); + let array: ArrayRef = Arc::new(map); + + let target_type = DataType::Map(logical_entries_field, false); + let result = relabel_array(array, &target_type); + assert_eq!(result.data_type(), &target_type); + } + + #[test] + fn test_relabel_struct_metadata() { + // Physical: Struct { Field("a", Int32, metadata={"PARQUET:field_id": "1"}) } + // Logical: Struct { Field("a", Int32, metadata={}) } + let mut metadata = std::collections::HashMap::new(); + metadata.insert("PARQUET:field_id".to_string(), "1".to_string()); + let physical_field = + Arc::new(Field::new("a", DataType::Int32, true).with_metadata(metadata)); + let logical_field = Arc::new(Field::new("a", DataType::Int32, true)); + + let col = Int32Array::from(vec![10, 20]); + let physical_fields = Fields::from(vec![physical_field]); + let logical_fields = Fields::from(vec![logical_field]); + + let struct_arr = StructArray::new(physical_fields, vec![Arc::new(col)], None); + let array: ArrayRef = Arc::new(struct_arr); + + let target_type = DataType::Struct(logical_fields); + let result = relabel_array(array, &target_type); + assert_eq!(result.data_type(), &target_type); + } + + #[test] + fn test_relabel_nested_struct_containing_list() { + // Physical: Struct { Field("col", List(Field("item", Int32))) } + // Logical: Struct { Field("col", List(Field("element", Int32))) } + let physical_list_field = Arc::new(Field::new("item", DataType::Int32, true)); + let logical_list_field = Arc::new(Field::new("element", DataType::Int32, true)); + + let physical_struct_field = Arc::new(Field::new( + "col", + DataType::List(physical_list_field.clone()), + true, + )); + let logical_struct_field = Arc::new(Field::new( + "col", + DataType::List(logical_list_field.clone()), + true, + )); + + let values = Int32Array::from(vec![1, 2, 3]); + let list = ListArray::new( + physical_list_field, + arrow::buffer::OffsetBuffer::new(vec![0, 2, 3].into()), + Arc::new(values), + None, + ); + + let physical_fields = Fields::from(vec![physical_struct_field]); + let logical_fields = Fields::from(vec![logical_struct_field]); + + let struct_arr = StructArray::new(physical_fields, vec![Arc::new(list) as ArrayRef], None); + let array: ArrayRef = Arc::new(struct_arr); + + let target_type = DataType::Struct(logical_fields); + let result = relabel_array(array, &target_type); + assert_eq!(result.data_type(), &target_type); + + // Verify we can access the nested data without panics + let result_struct = result.as_any().downcast_ref::().unwrap(); + let result_list = result_struct + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result_list.len(), 2); + } }