201 changes: 191 additions & 10 deletions native/core/src/parquet/cast_column.rs
@@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.
use arrow::{
- array::{make_array, ArrayRef, TimestampMicrosecondArray, TimestampMillisecondArray},
array::{
make_array, Array, ArrayRef, LargeListArray, ListArray, MapArray, StructArray,
TimestampMicrosecondArray, TimestampMillisecondArray,
},
compute::CastOptions,
datatypes::{DataType, FieldRef, Schema, TimeUnit},
record_batch::RecordBatch,
@@ -78,13 +81,66 @@ fn relabel_array(array: ArrayRef, target_type: &DataType) -> ArrayRef {
if array.data_type() == target_type {
return array;
}
- let data = array.to_data();
- let new_data = data
-     .into_builder()
-     .data_type(target_type.clone())
-     .build()
-     .expect("relabel_array: data layout must be compatible");
- make_array(new_data)
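// Nested types carry child `Field`s (names, nullability, metadata) inside the
// DataType itself, so swapping only the top-level DataType would leave the
// children labeled with the physical (Parquet) names. Rebuild each nesting
// level and recurse into the children instead, e.g. so a physical
// List<item: Int32> can be relabeled to the logical List<element: Int32>.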
match target_type {
DataType::List(target_field) => {
let list = array.as_any().downcast_ref::<ListArray>().unwrap();
let values = relabel_array(Arc::clone(list.values()), target_field.data_type());
Arc::new(ListArray::new(
Arc::clone(target_field),
list.offsets().clone(),
values,
list.nulls().cloned(),
))
}
DataType::LargeList(target_field) => {
let list = array.as_any().downcast_ref::<LargeListArray>().unwrap();
let values = relabel_array(Arc::clone(list.values()), target_field.data_type());
Arc::new(LargeListArray::new(
Arc::clone(target_field),
list.offsets().clone(),
values,
list.nulls().cloned(),
))
}
DataType::Map(target_entries_field, sorted) => {
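// A MapArray stores its entries as a StructArray of (key, value) pairs;
// relabel the entries struct first so the entries field name (e.g. the
// physical "key_value" vs the logical "entries") and its children line up.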
let map = array.as_any().downcast_ref::<MapArray>().unwrap();
let entries = relabel_array(
Arc::new(map.entries().clone()),
target_entries_field.data_type(),
);
let entries_struct = entries.as_any().downcast_ref::<StructArray>().unwrap();
Arc::new(MapArray::new(
Arc::clone(target_entries_field),
map.offsets().clone(),
entries_struct.clone(),
map.nulls().cloned(),
*sorted,
))
}
DataType::Struct(target_fields) => {
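// Target fields are paired with the existing columns positionally; the
// per-column recursion also relabels any nested lists or maps inside.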
let struct_arr = array.as_any().downcast_ref::<StructArray>().unwrap();
let columns: Vec<ArrayRef> = target_fields
.iter()
.zip(struct_arr.columns())
.map(|(tf, col)| relabel_array(Arc::clone(col), tf.data_type()))
.collect();
Arc::new(StructArray::new(
target_fields.clone(),
columns,
struct_arr.nulls().cloned(),
))
}
// Non-nested types: a shallow swap of the DataType is safe
_ => {
let data = array.to_data();
let new_data = data
.into_builder()
.data_type(target_type.clone())
.build()
.expect("relabel_array: data layout must be compatible");
make_array(new_data)
}
}
}
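
// Usage sketch (mirrors the tests below): relabel a physical List<item: Int32>
// read from Parquet to the logical List<element: Int32> without copying the
// underlying buffers:
//
//     let logical = DataType::List(Arc::new(Field::new("element", DataType::Int32, true)));
//     let relabeled = relabel_array(array, &logical);
//     assert_eq!(relabeled.data_type(), &logical);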

/// Casts a Timestamp(Microsecond) array to Timestamp(Millisecond) by dividing values by 1000.
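/// A minimal sketch of that conversion (the function body is elided in this
/// diff), assuming arrow's `unary` kernel; timezone propagation is omitted:
///
/// ```ignore
/// let millis: TimestampMillisecondArray = arrow::compute::unary(&micros, |v| v / 1000);
/// ```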
@@ -300,8 +356,8 @@ impl PhysicalExpr for CometCastColumnExpr {
#[cfg(test)]
mod tests {
use super::*;
- use arrow::array::Array;
- use arrow::datatypes::Field;
use arrow::array::{Array, Int32Array, StringArray};
use arrow::datatypes::{Field, Fields};
use datafusion::physical_expr::expressions::Column;

#[test]
@@ -455,4 +511,129 @@ mod tests {
_ => panic!("Expected Scalar result"),
}
}

#[test]
fn test_relabel_list_field_name() {
// Physical: List(Field("item", Int32))
// Logical: List(Field("element", Int32))
let physical_field = Arc::new(Field::new("item", DataType::Int32, true));
let logical_field = Arc::new(Field::new("element", DataType::Int32, true));

let values = Int32Array::from(vec![1, 2, 3]);
let list = ListArray::new(
physical_field,
arrow::buffer::OffsetBuffer::new(vec![0, 2, 3].into()),
Arc::new(values),
None,
);
let array: ArrayRef = Arc::new(list);

let target_type = DataType::List(logical_field.clone());
let result = relabel_array(array, &target_type);
assert_eq!(result.data_type(), &target_type);
}

#[test]
fn test_relabel_map_entries_field_name() {
// Physical: Map(Field("key_value", Struct{key, value}))
// Logical: Map(Field("entries", Struct{key, value}))
let key_field = Arc::new(Field::new("key", DataType::Utf8, false));
let value_field = Arc::new(Field::new("value", DataType::Int32, true));
let struct_fields = Fields::from(vec![key_field.clone(), value_field.clone()]);

let physical_entries_field = Arc::new(Field::new(
"key_value",
DataType::Struct(struct_fields.clone()),
false,
));
let logical_entries_field = Arc::new(Field::new(
"entries",
DataType::Struct(struct_fields.clone()),
false,
));

let keys = StringArray::from(vec!["a", "b"]);
let values = Int32Array::from(vec![1, 2]);
let entries = StructArray::new(struct_fields, vec![Arc::new(keys), Arc::new(values)], None);
let map = MapArray::new(
physical_entries_field,
arrow::buffer::OffsetBuffer::new(vec![0, 2].into()),
entries,
None,
false,
);
let array: ArrayRef = Arc::new(map);

let target_type = DataType::Map(logical_entries_field, false);
let result = relabel_array(array, &target_type);
assert_eq!(result.data_type(), &target_type);
}

#[test]
fn test_relabel_struct_metadata() {
// Physical: Struct { Field("a", Int32, metadata={"PARQUET:field_id": "1"}) }
// Logical: Struct { Field("a", Int32, metadata={}) }
let mut metadata = std::collections::HashMap::new();
metadata.insert("PARQUET:field_id".to_string(), "1".to_string());
let physical_field =
Arc::new(Field::new("a", DataType::Int32, true).with_metadata(metadata));
let logical_field = Arc::new(Field::new("a", DataType::Int32, true));

let col = Int32Array::from(vec![10, 20]);
let physical_fields = Fields::from(vec![physical_field]);
let logical_fields = Fields::from(vec![logical_field]);

let struct_arr = StructArray::new(physical_fields, vec![Arc::new(col)], None);
let array: ArrayRef = Arc::new(struct_arr);

let target_type = DataType::Struct(logical_fields);
let result = relabel_array(array, &target_type);
assert_eq!(result.data_type(), &target_type);
}

#[test]
fn test_relabel_nested_struct_containing_list() {
// Physical: Struct { Field("col", List(Field("item", Int32))) }
// Logical: Struct { Field("col", List(Field("element", Int32))) }
let physical_list_field = Arc::new(Field::new("item", DataType::Int32, true));
let logical_list_field = Arc::new(Field::new("element", DataType::Int32, true));

let physical_struct_field = Arc::new(Field::new(
"col",
DataType::List(physical_list_field.clone()),
true,
));
let logical_struct_field = Arc::new(Field::new(
"col",
DataType::List(logical_list_field.clone()),
true,
));

let values = Int32Array::from(vec![1, 2, 3]);
let list = ListArray::new(
physical_list_field,
arrow::buffer::OffsetBuffer::new(vec![0, 2, 3].into()),
Arc::new(values),
None,
);

let physical_fields = Fields::from(vec![physical_struct_field]);
let logical_fields = Fields::from(vec![logical_struct_field]);

let struct_arr = StructArray::new(physical_fields, vec![Arc::new(list) as ArrayRef], None);
let array: ArrayRef = Arc::new(struct_arr);

let target_type = DataType::Struct(logical_fields);
let result = relabel_array(array, &target_type);
assert_eq!(result.data_type(), &target_type);

// Verify we can access the nested data without panics
let result_struct = result.as_any().downcast_ref::<StructArray>().unwrap();
let result_list = result_struct
.column(0)
.as_any()
.downcast_ref::<ListArray>()
.unwrap();
assert_eq!(result_list.len(), 2);
}
}