Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 128 additions & 5 deletions src/fdw/convert.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use arrow::array::cast::AsArray;
use arrow::array::{
Array, BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray,
FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, Int16Array,
Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray,
StructArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
FixedSizeBinaryArray, FixedSizeListArray, Float16Array, Float32Array, Float64Array,
GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray,
LargeStringArray, StringArray, StructArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array,
UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::{DataType, TimeUnit as ArrowTimeUnit};
use base64::{engine::general_purpose::STANDARD, Engine as _};
use chrono::Datelike;
use pgrx::datum::{Date, Timestamp, TimestampWithTimeZone};
use pgrx::pg_sys;
use pgrx::prelude::IntoDatum;
use pgrx::AnyNumeric;
use pgrx::JsonB;
use serde_json::{Map, Number, Value};

Expand Down Expand Up @@ -256,6 +259,41 @@ pub fn arrow_value_to_datum(
Ok((ts.into_datum().ok_or("failed to convert timestamp")?, false))
}
}
DataType::Decimal128(_, _) => {
let v = array
.as_any()
.downcast_ref::<Decimal128Array>()
.ok_or("invalid decimal128 array")?
.value_as_string(row_idx);
let numeric = AnyNumeric::try_from(v.as_str()).map_err(|_| "invalid numeric")?;
Ok((
numeric.into_datum().ok_or("failed to convert numeric")?,
false,
))
}
DataType::Decimal256(_, _) => {
let v = array
.as_any()
.downcast_ref::<Decimal256Array>()
.ok_or("invalid decimal256 array")?
.value_as_string(row_idx);
let numeric = AnyNumeric::try_from(v.as_str()).map_err(|_| "invalid numeric")?;
Ok((
numeric.into_datum().ok_or("failed to convert numeric")?,
false,
))
}
DataType::Dictionary(_, _) => {
let dict = array
.as_any_dictionary_opt()
.ok_or("invalid dictionary array")?;
let value_idx = dictionary_key_to_usize(dict.keys(), row_idx)?;
let values = dict.values().as_ref();
if value_idx >= values.len() {
return Err("dictionary key out of range");
}
arrow_value_to_datum(values, value_idx, target_type_oid)
}
DataType::List(_) | DataType::LargeList(_) => {
let elem_oid = unsafe { pg_sys::get_element_type(target_type_oid) };
if elem_oid == pg_sys::InvalidOid {
Expand Down Expand Up @@ -631,6 +669,33 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result<Value, &'sta
}
Ok(Value::Object(map))
}
DataType::Decimal128(_, _) => {
let v = array
.as_any()
.downcast_ref::<Decimal128Array>()
.ok_or("invalid decimal128 array")?
.value_as_string(row_idx);
Ok(Value::String(v))
}
DataType::Decimal256(_, _) => {
let v = array
.as_any()
.downcast_ref::<Decimal256Array>()
.ok_or("invalid decimal256 array")?
.value_as_string(row_idx);
Ok(Value::String(v))
}
DataType::Dictionary(_, _) => {
let dict = array
.as_any_dictionary_opt()
.ok_or("invalid dictionary array")?;
let value_idx = dictionary_key_to_usize(dict.keys(), row_idx)?;
let values = dict.values().as_ref();
if value_idx >= values.len() {
return Err("dictionary key out of range");
}
arrow_value_to_json(values, value_idx)
}
_ => Ok(Value::String(format!(
"<unsupported_type: {:?}>",
array.data_type()
Expand All @@ -641,3 +706,61 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result<Value, &'sta
fn json_number(v: i64) -> Value {
Value::Number(Number::from(v))
}

fn dictionary_key_to_usize(keys: &dyn Array, row_idx: usize) -> Result<usize, &'static str> {
match keys.data_type() {
DataType::Int8 => {
let v = keys
.as_any()
.downcast_ref::<Int8Array>()
.ok_or("invalid dictionary keys (int8)")?
.value(row_idx) as i64;
usize::try_from(v).map_err(|_| "negative dictionary key")
}
DataType::Int16 => {
let v = keys
.as_any()
.downcast_ref::<Int16Array>()
.ok_or("invalid dictionary keys (int16)")?
.value(row_idx) as i64;
usize::try_from(v).map_err(|_| "negative dictionary key")
}
DataType::Int32 => {
let v = keys
.as_any()
.downcast_ref::<Int32Array>()
.ok_or("invalid dictionary keys (int32)")?
.value(row_idx) as i64;
usize::try_from(v).map_err(|_| "negative dictionary key")
}
DataType::Int64 => {
let v = keys
.as_any()
.downcast_ref::<Int64Array>()
.ok_or("invalid dictionary keys (int64)")?
.value(row_idx);
usize::try_from(v).map_err(|_| "negative dictionary key")
}
DataType::UInt8 => Ok(keys
.as_any()
.downcast_ref::<UInt8Array>()
.ok_or("invalid dictionary keys (uint8)")?
.value(row_idx) as usize),
DataType::UInt16 => Ok(keys
.as_any()
.downcast_ref::<UInt16Array>()
.ok_or("invalid dictionary keys (uint16)")?
.value(row_idx) as usize),
DataType::UInt32 => Ok(keys
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or("invalid dictionary keys (uint32)")?
.value(row_idx) as usize),
DataType::UInt64 => Ok(keys
.as_any()
.downcast_ref::<UInt64Array>()
.ok_or("invalid dictionary keys (uint64)")?
.value(row_idx) as usize),
_ => Err("unsupported dictionary key type"),
}
}
6 changes: 4 additions & 2 deletions src/fdw/type_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ fn pg_type_for_arrow(
match field.data_type() {
DataType::Boolean => Ok("boolean".to_string()),
DataType::Int8 | DataType::UInt8 => Ok("int2".to_string()),
DataType::Int16 | DataType::UInt16 => Ok("int2".to_string()),
DataType::Int32 | DataType::UInt32 => Ok("int4".to_string()),
DataType::Int16 => Ok("int2".to_string()),
DataType::UInt16 => Ok("int4".to_string()),
DataType::Int32 => Ok("int4".to_string()),
DataType::UInt32 => Ok("int8".to_string()),
DataType::Int64 | DataType::UInt64 => Ok("int8".to_string()),
DataType::Float16 | DataType::Float32 => Ok("float4".to_string()),
DataType::Float64 => Ok("float8".to_string()),
Expand Down
61 changes: 56 additions & 5 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use pgrx::prelude::*;
mod tests {
use super::*;
use arrow::array::{
Array, BooleanArray, Float32Array, Int32Array, ListBuilder, StringArray, StructArray,
builder::StringDictionaryBuilder, Array, BooleanArray, Decimal128Array, Float32Array,
Int32Array, ListBuilder, StringArray, StructArray, UInt16Array, UInt32Array,
};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow::record_batch::RecordBatch;
use lance_rs::Dataset;
use sqllogictest::{DBOutput, DefaultColumnType, Runner};
Expand Down Expand Up @@ -145,6 +146,49 @@ mod tests {
})
}

fn create_table_with_decimal_and_dictionary(
&self,
) -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
let table_path = self.temp_dir.path().join("fdw_misc");

let u16_array = UInt16Array::from(vec![1, u16::MAX, 2]);
let u32_array = UInt32Array::from(vec![1, u32::MAX, 42]);

let dec_array = Decimal128Array::from(vec![Some(12345i128), Some(-10i128), None])
.with_precision_and_scale(10, 2)?;

let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
dict_builder.append("foo")?;
dict_builder.append("bar")?;
dict_builder.append_null();
let dict_array = dict_builder.finish();

let schema = Arc::new(Schema::new(vec![
Field::new("u16", DataType::UInt16, false),
Field::new("u32", DataType::UInt32, false),
Field::new("dec", dec_array.data_type().clone(), true),
Field::new("dict", dict_array.data_type().clone(), true),
]));

let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(u16_array),
Arc::new(u32_array),
Arc::new(dec_array),
Arc::new(dict_array),
],
)?;

let reader = arrow::record_batch::RecordBatchIterator::new(vec![Ok(batch)], schema);
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
Dataset::write(reader, table_path.to_str().unwrap(), None).await
})?;

Ok(table_path)
}

fn create_table_with_struct_and_list(
&self,
) -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -254,10 +298,15 @@ mod tests {
Spi::run("SELECT pg_advisory_lock(424242)").expect("advisory lock");

let gen = LanceTestDataGenerator::new().expect("generator");
let path = gen
let struct_list_path = gen
.create_table_with_struct_and_list()
.expect("create table");
let uri = path.to_str().expect("uri").replace('\'', "''");
let struct_list_uri = struct_list_path.to_str().expect("uri").replace('\'', "''");

let misc_path = gen
.create_table_with_decimal_and_dictionary()
.expect("create table");
let misc_uri = misc_path.to_str().expect("uri").replace('\'', "''");

let scripts_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/sql");
let slt_files = list_slt_files(&scripts_dir);
Expand All @@ -276,7 +325,9 @@ mod tests {
let server = format!("{}_srv", schema);

let mut script = fs::read_to_string(file).expect("read .slt file");
script = script.replace("${LANCE_URI}", &uri);
script = script.replace("${LANCE_URI}", &struct_list_uri);
script = script.replace("${LANCE_URI_STRUCT_LIST}", &struct_list_uri);
script = script.replace("${LANCE_URI_MISC}", &misc_uri);
script = script.replace("${SCHEMA}", &schema);
script = script.replace("${SERVER}", &server);

Expand Down
37 changes: 37 additions & 0 deletions tests/sql/01_type_mapping.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Ensure declared type mapping matches runtime value conversion.

statement ok
SELECT lance_import('${SERVER}', '${SCHEMA}', 't_misc', '${LANCE_URI_MISC}', NULL);

query T
SELECT atttypid::regtype::text
FROM pg_attribute
WHERE attrelid = 't_misc'::regclass
AND attname IN ('u16', 'u32', 'dec', 'dict')
ORDER BY attname;
----
numeric
text
integer
bigint

query I
SELECT u32 FROM t_misc WHERE u16 = 65535;
----
4294967295

query T
SELECT dec::text FROM t_misc WHERE u16 = 1;
----
123.45

query T
SELECT dict FROM t_misc WHERE u16 = 1;
----
foo

query T
SELECT (dec IS NULL) AND (dict IS NULL) FROM t_misc WHERE u16 = 2;
----
t

Loading