|
| 1 | +use std::any::Any; |
| 2 | +use std::sync::Arc; |
| 3 | + |
| 4 | +use datafusion::arrow::array::{Array, ArrayRef, AsArray, UnionArray}; |
| 5 | +use datafusion::arrow::datatypes::{ |
| 6 | + DataType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, |
| 7 | + UInt8Type, |
| 8 | +}; |
| 9 | +use datafusion::common::{exec_datafusion_err, exec_err, plan_err, Result as DataFusionResult, ScalarValue}; |
| 10 | +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; |
| 11 | + |
| 12 | +use crate::common_macros::make_udf_function; |
| 13 | +use crate::common_union::{JsonUnion, JsonUnionField}; |
| 14 | + |
| 15 | +make_udf_function!( |
| 16 | + JsonFromScalar, |
| 17 | + json_from_scalar, |
| 18 | + value, |
| 19 | + r"Convert a scalar value (null, bool, integer, float, or string) to a JSON union type" |
| 20 | +); |
| 21 | + |
| 22 | +#[derive(Debug, PartialEq, Eq, Hash)] |
| 23 | +pub(super) struct JsonFromScalar { |
| 24 | + signature: Signature, |
| 25 | + aliases: [String; 2], |
| 26 | +} |
| 27 | + |
| 28 | +impl Default for JsonFromScalar { |
| 29 | + fn default() -> Self { |
| 30 | + Self { |
| 31 | + signature: Signature::any(1, Volatility::Immutable), |
| 32 | + aliases: ["json_from_scalar".to_string(), "scalar_to_json".to_string()], |
| 33 | + } |
| 34 | + } |
| 35 | +} |
| 36 | + |
| 37 | +impl ScalarUDFImpl for JsonFromScalar { |
| 38 | + fn as_any(&self) -> &dyn Any { |
| 39 | + self |
| 40 | + } |
| 41 | + |
| 42 | + fn name(&self) -> &str { |
| 43 | + self.aliases[0].as_str() |
| 44 | + } |
| 45 | + |
| 46 | + fn signature(&self) -> &Signature { |
| 47 | + &self.signature |
| 48 | + } |
| 49 | + |
| 50 | + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> { |
| 51 | + // Check that the input type is a scalar type that we can convert to JSON |
| 52 | + // The signature check ensures we only get one argument, index access is safe |
| 53 | + match arg_types[0] { |
| 54 | + DataType::Null |
| 55 | + | DataType::Boolean |
| 56 | + | DataType::Int8 |
| 57 | + | DataType::Int16 |
| 58 | + | DataType::Int32 |
| 59 | + | DataType::Int64 |
| 60 | + | DataType::UInt8 |
| 61 | + | DataType::UInt16 |
| 62 | + | DataType::UInt32 |
| 63 | + | DataType::UInt64 |
| 64 | + | DataType::Float32 |
| 65 | + | DataType::Float64 |
| 66 | + | DataType::Utf8 |
| 67 | + | DataType::LargeUtf8 |
| 68 | + | DataType::Utf8View => {} |
| 69 | + _ => { |
| 70 | + return plan_err!("Unsupported type for json_from_scalar: {:?}", arg_types[0]); |
| 71 | + } |
| 72 | + } |
| 73 | + Ok(JsonUnion::data_type()) |
| 74 | + } |
| 75 | + |
| 76 | + fn invoke_with_args(&self, mut args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> { |
| 77 | + // The signature check ensures we only get one argument |
| 78 | + match args.args.pop().expect("Expected exactly one argument") { |
| 79 | + ColumnarValue::Scalar(scalar) => { |
| 80 | + let field = scalar_to_json_union_field(scalar)?; |
| 81 | + Ok(ColumnarValue::Scalar(JsonUnionField::scalar_value(Some(field)))) |
| 82 | + } |
| 83 | + ColumnarValue::Array(array) => { |
| 84 | + let union = array_to_json_union(&array)?; |
| 85 | + let union_array: UnionArray = union.try_into()?; |
| 86 | + Ok(ColumnarValue::Array(Arc::new(union_array) as ArrayRef)) |
| 87 | + } |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + fn aliases(&self) -> &[String] { |
| 92 | + &self.aliases |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +fn scalar_to_json_union_field(scalar: ScalarValue) -> DataFusionResult<JsonUnionField> { |
| 97 | + match scalar { |
| 98 | + // Null type / values |
| 99 | + ScalarValue::Null |
| 100 | + | ScalarValue::Boolean(None) |
| 101 | + | ScalarValue::Int8(None) |
| 102 | + | ScalarValue::Int16(None) |
| 103 | + | ScalarValue::Int32(None) |
| 104 | + | ScalarValue::Int64(None) |
| 105 | + | ScalarValue::UInt8(None) |
| 106 | + | ScalarValue::UInt16(None) |
| 107 | + | ScalarValue::UInt32(None) |
| 108 | + | ScalarValue::UInt64(None) |
| 109 | + | ScalarValue::Float32(None) |
| 110 | + | ScalarValue::Float64(None) |
| 111 | + | ScalarValue::Utf8(None) |
| 112 | + | ScalarValue::LargeUtf8(None) |
| 113 | + | ScalarValue::Utf8View(None) => Ok(JsonUnionField::JsonNull), |
| 114 | + // Boolean type |
| 115 | + ScalarValue::Boolean(Some(b)) => Ok(JsonUnionField::Bool(b)), |
| 116 | + // Integer types - coerce to i64 |
| 117 | + ScalarValue::Int8(Some(v)) => Ok(JsonUnionField::Int(i64::from(v))), |
| 118 | + ScalarValue::Int16(Some(v)) => Ok(JsonUnionField::Int(i64::from(v))), |
| 119 | + ScalarValue::Int32(Some(v)) => Ok(JsonUnionField::Int(i64::from(v))), |
| 120 | + ScalarValue::Int64(Some(v)) => Ok(JsonUnionField::Int(v)), |
| 121 | + ScalarValue::UInt8(Some(v)) => Ok(JsonUnionField::Int(i64::from(v))), |
| 122 | + ScalarValue::UInt16(Some(v)) => Ok(JsonUnionField::Int(i64::from(v))), |
| 123 | + ScalarValue::UInt32(Some(v)) => Ok(JsonUnionField::Int(i64::from(v))), |
| 124 | + ScalarValue::UInt64(Some(v)) => { |
| 125 | + Ok(JsonUnionField::Int(i64::try_from(v).map_err(|_| { |
| 126 | + exec_datafusion_err!("UInt64 value {} is out of range for i64", v) |
| 127 | + })?)) |
| 128 | + } |
| 129 | + // Float types - coerce to f64 |
| 130 | + ScalarValue::Float32(Some(v)) => Ok(JsonUnionField::Float(f64::from(v))), |
| 131 | + ScalarValue::Float64(Some(v)) => Ok(JsonUnionField::Float(v)), |
| 132 | + // String types |
| 133 | + ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { |
| 134 | + Ok(JsonUnionField::Str(s)) |
| 135 | + } |
| 136 | + _ => exec_err!("Unsupported type for json_from_scalar: {:?}", scalar.data_type()), |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +fn array_to_json_union(array: &ArrayRef) -> DataFusionResult<JsonUnion> { |
| 141 | + Ok(match array.data_type() { |
| 142 | + DataType::Null => (0..array.len()).map(|_| Some(JsonUnionField::JsonNull)).collect(), |
| 143 | + DataType::Boolean => array.as_boolean().iter().map(|v| v.map(JsonUnionField::Bool)).collect(), |
| 144 | + // Integer types - coerce to i64 |
| 145 | + DataType::Int8 => array |
| 146 | + .as_primitive::<Int8Type>() |
| 147 | + .iter() |
| 148 | + .map(|v| v.map(|x| JsonUnionField::Int(i64::from(x)))) |
| 149 | + .collect(), |
| 150 | + DataType::Int16 => array |
| 151 | + .as_primitive::<Int16Type>() |
| 152 | + .iter() |
| 153 | + .map(|v| v.map(|x| JsonUnionField::Int(i64::from(x)))) |
| 154 | + .collect(), |
| 155 | + DataType::Int32 => array |
| 156 | + .as_primitive::<Int32Type>() |
| 157 | + .iter() |
| 158 | + .map(|v| v.map(|x| JsonUnionField::Int(i64::from(x)))) |
| 159 | + .collect(), |
| 160 | + DataType::Int64 => array |
| 161 | + .as_primitive::<Int64Type>() |
| 162 | + .iter() |
| 163 | + .map(|v| v.map(JsonUnionField::Int)) |
| 164 | + .collect(), |
| 165 | + DataType::UInt8 => array |
| 166 | + .as_primitive::<UInt8Type>() |
| 167 | + .iter() |
| 168 | + .map(|v| v.map(|x| JsonUnionField::Int(i64::from(x)))) |
| 169 | + .collect(), |
| 170 | + DataType::UInt16 => array |
| 171 | + .as_primitive::<UInt16Type>() |
| 172 | + .iter() |
| 173 | + .map(|v| v.map(|x| JsonUnionField::Int(i64::from(x)))) |
| 174 | + .collect(), |
| 175 | + DataType::UInt32 => array |
| 176 | + .as_primitive::<UInt32Type>() |
| 177 | + .iter() |
| 178 | + .map(|v| v.map(|x| JsonUnionField::Int(i64::from(x)))) |
| 179 | + .collect(), |
| 180 | + DataType::UInt64 => { |
| 181 | + // UInt64 requires explicit loop for fallible conversion |
| 182 | + let arr = array.as_primitive::<UInt64Type>(); |
| 183 | + let mut union = JsonUnion::new(arr.len()); |
| 184 | + for i in 0..arr.len() { |
| 185 | + if arr.is_null(i) { |
| 186 | + union.push_none(); |
| 187 | + } else { |
| 188 | + union.push(JsonUnionField::Int(i64::try_from(arr.value(i)).map_err(|_| { |
| 189 | + exec_datafusion_err!("UInt64 value {} is out of range for i64", arr.value(i)) |
| 190 | + })?)); |
| 191 | + } |
| 192 | + } |
| 193 | + return Ok(union); |
| 194 | + } |
| 195 | + // Float types - coerce to f64 |
| 196 | + DataType::Float32 => array |
| 197 | + .as_primitive::<Float32Type>() |
| 198 | + .iter() |
| 199 | + .map(|v| v.map(|x| JsonUnionField::Float(f64::from(x)))) |
| 200 | + .collect(), |
| 201 | + DataType::Float64 => array |
| 202 | + .as_primitive::<Float64Type>() |
| 203 | + .iter() |
| 204 | + .map(|v| v.map(JsonUnionField::Float)) |
| 205 | + .collect(), |
| 206 | + // String types |
| 207 | + DataType::Utf8 => array |
| 208 | + .as_string::<i32>() |
| 209 | + .iter() |
| 210 | + .map(|v| v.map(|s| JsonUnionField::Str(s.to_string()))) |
| 211 | + .collect(), |
| 212 | + DataType::LargeUtf8 => array |
| 213 | + .as_string::<i64>() |
| 214 | + .iter() |
| 215 | + .map(|v| v.map(|s| JsonUnionField::Str(s.to_string()))) |
| 216 | + .collect(), |
| 217 | + DataType::Utf8View => array |
| 218 | + .as_string_view() |
| 219 | + .iter() |
| 220 | + .map(|v| v.map(|s| JsonUnionField::Str(s.to_string()))) |
| 221 | + .collect(), |
| 222 | + dt => { |
| 223 | + return exec_err!("Unsupported array type for json_from_scalar: {:?}", dt); |
| 224 | + } |
| 225 | + }) |
| 226 | +} |
0 commit comments