Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
746 changes: 393 additions & 353 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ documentation = "https://docs.rs/crate/datafusion-postgres/"

[workspace.dependencies]
arrow = "57"
arrow-schema = "57"
bytes = "1.11.0"
chrono = { version = "0.4", features = ["std"] }
datafusion = { version = "51", default-features = false }
Expand Down
5 changes: 4 additions & 1 deletion arrow-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ readme = "../README.md"
rust-version.workspace = true

[features]
default = ["arrow"]
default = ["arrow", "geo"]
arrow = ["dep:arrow"]
datafusion = ["dep:datafusion"]
geo = ["postgres-types/with-geo-types-0_7", "dep:geoarrow-schema"]

[dependencies]
arrow = { workspace = true, optional = true }
arrow-schema = { workspace = true}
geoarrow-schema = { version = "0.7", optional = true }
bytes.workspace = true
chrono.workspace = true
datafusion = { workspace = true, optional = true }
Expand Down
133 changes: 76 additions & 57 deletions arrow-pg/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

#[cfg(not(feature = "datafusion"))]
use arrow::{datatypes::*, record_batch::RecordBatch};
use arrow_schema::extension::ExtensionType;
#[cfg(feature = "datafusion")]
use datafusion::arrow::{datatypes::*, record_batch::RecordBatch};

Expand All @@ -18,34 +19,45 @@ use crate::row_encoder::RowEncoder;
#[cfg(feature = "datafusion")]
pub mod df;

pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
Ok(match arrow_type {
DataType::Null => Type::UNKNOWN,
DataType::Boolean => Type::BOOL,
DataType::Int8 | DataType::UInt8 => Type::CHAR,
DataType::Int16 | DataType::UInt16 => Type::INT2,
DataType::Int32 | DataType::UInt32 => Type::INT4,
DataType::Int64 | DataType::UInt64 => Type::INT8,
DataType::Timestamp(_, tz) => {
if tz.is_some() {
Type::TIMESTAMPTZ
} else {
Type::TIMESTAMP
pub fn into_pg_type(field: &Arc<Field>) -> PgWireResult<Type> {
let arrow_type = field.data_type();

match field.extension_type_name() {
// As of arrow 56, there are additional extension logical types that are
// defined using field metadata, for instance, json or geo.
#[cfg(feature = "geo")]
Some(geoarrow_schema::PointType::NAME) => Ok(Type::POINT),

_ => Ok(match arrow_type {
DataType::Null => Type::UNKNOWN,
DataType::Boolean => Type::BOOL,
DataType::Int8 | DataType::UInt8 => Type::CHAR,
DataType::Int16 | DataType::UInt16 => Type::INT2,
DataType::Int32 | DataType::UInt32 => Type::INT4,
DataType::Int64 | DataType::UInt64 => Type::INT8,
DataType::Timestamp(_, tz) => {
if tz.is_some() {
Type::TIMESTAMPTZ
} else {
Type::TIMESTAMP
}
}
}
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
DataType::Date32 | DataType::Date64 => Type::DATE,
DataType::Interval(_) => Type::INTERVAL,
DataType::Binary
| DataType::FixedSizeBinary(_)
| DataType::LargeBinary
| DataType::BinaryView => Type::BYTEA,
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
DataType::Float64 => Type::FLOAT8,
DataType::Decimal128(_, _) => Type::NUMERIC,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
match field.data_type() {
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
DataType::Date32 | DataType::Date64 => Type::DATE,
DataType::Interval(_) => Type::INTERVAL,
DataType::Binary
| DataType::FixedSizeBinary(_)
| DataType::LargeBinary
| DataType::BinaryView => Type::BYTEA,
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
DataType::Float64 => Type::FLOAT8,
DataType::Decimal128(_, _) => Type::NUMERIC,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
DataType::List(field)
| DataType::FixedSizeList(field, _)
| DataType::LargeList(field)
| DataType::ListView(field)
| DataType::LargeListView(field) => match field.data_type() {
DataType::Boolean => Type::BOOL_ARRAY,
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
Expand All @@ -68,10 +80,10 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
DataType::Float64 => Type::FLOAT8_ARRAY,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
struct_type @ DataType::Struct(_) => Type::new(
DataType::Struct(_) => Type::new(
Type::RECORD_ARRAY.name().into(),
Type::RECORD_ARRAY.oid(),
Kind::Array(into_pg_type(struct_type)?),
Kind::Array(into_pg_type(field)?),
Type::RECORD_ARRAY.schema().into(),
),
list_type => {
Expand All @@ -81,35 +93,42 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
format!("Unsupported List Datatype {list_type}"),
))));
}
},
DataType::Dictionary(_, value_type) => {
let field = Arc::new(Field::new(
Field::LIST_FIELD_DEFAULT_NAME,
*value_type.clone(),
true,
));
into_pg_type(&field)?
}
}
DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
DataType::Struct(fields) => {
let name: String = fields
.iter()
.map(|x| x.name().clone())
.reduce(|a, b| a + ", " + &b)
.map(|x| format!("({x})"))
.unwrap_or("()".to_string());
let kind = Kind::Composite(
fields
DataType::Struct(fields) => {
let name: String = fields
.iter()
.map(|x| {
into_pg_type(x.data_type())
.map(|_type| postgres_types::Field::new(x.name().clone(), _type))
})
.collect::<Result<Vec<_>, PgWireError>>()?,
);
Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
}
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
format!("Unsupported Datatype {arrow_type}"),
))));
}
})
.map(|x| x.name().clone())
.reduce(|a, b| a + ", " + &b)
.map(|x| format!("({x})"))
.unwrap_or("()".to_string());
let kind = Kind::Composite(
fields
.iter()
.map(|x| {
into_pg_type(x)
.map(|_type| postgres_types::Field::new(x.name().clone(), _type))
})
.collect::<Result<Vec<_>, PgWireError>>()?,
);
Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
}
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
format!("Unsupported Datatype {arrow_type}"),
))));
}
}),
}
}

pub fn arrow_schema_to_pg_fields(
Expand All @@ -123,7 +142,7 @@ pub fn arrow_schema_to_pg_fields(
.iter()
.enumerate()
.map(|(idx, f)| {
let pg_type = into_pg_type(f.data_type())?;
let pg_type = into_pg_type(f)?;
let mut field_info =
FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx));
if let Some(data_format_options) = &data_format_options {
Expand Down
4 changes: 2 additions & 2 deletions arrow-pg/src/datatypes/df.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::iter;
use std::sync::Arc;

use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
use datafusion::arrow::datatypes::{DataType, Date32Type, TimeUnit};
use datafusion::arrow::datatypes::{DataType, Date32Type, Field, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::ParamValues;
use datafusion::prelude::*;
Expand Down Expand Up @@ -70,7 +70,7 @@ where
if let Some(ty) = pg_type_hint {
Ok(ty.clone())
} else if let Some(infer_type) = inferenced_type {
into_pg_type(infer_type)
into_pg_type(&Arc::new(Field::new("item", infer_type.clone(), true)))
} else {
Ok(Type::UNKNOWN)
}
Expand Down
44 changes: 30 additions & 14 deletions arrow-pg/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::sync::Arc;

#[cfg(not(feature = "datafusion"))]
use arrow::{array::*, datatypes::*};
#[cfg(feature = "geo")]
use arrow_schema::extension::ExtensionType;
use bytes::BufMut;
use bytes::BytesMut;
use chrono::NaiveTime;
Expand Down Expand Up @@ -285,11 +287,30 @@ pub fn encode_value<T: Encoder>(
encoder: &mut T,
arr: &Arc<dyn Array>,
idx: usize,
arrow_field: &Field,
pg_field: &FieldInfo,
) -> PgWireResult<()> {
let type_ = pg_field.datatype();
let arrow_type = arrow_field.data_type();

match arrow_field.extension_type_name() {
#[cfg(feature = "geo")]
Some(geoarrow::datatype::PointType::NAME) => {
// downcast array as geoarrow
// convert to point
// encode as point
let geoarrow_array: Arc<dyn geoarrow::array::GeoArrowArray> =
geoarrow::array::from_arrow_array(array, field).unwrap();
match geoarrow_array.data_type() {
GeoArrowType::Point(_) => {
let array: &PointArray = geoarrow_array.as_point();
}
_ => todo!("handle other geometry types"),
}
}
_ => {}
}

match arr.data_type() {
match arrow_type {
DataType::Null => encoder.encode_field(&None::<i8>, pg_field)?,
DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?,
DataType::Int8 => encoder.encode_field(&get_i8_value(arr, idx), pg_field)?,
Expand Down Expand Up @@ -423,16 +444,8 @@ pub fn encode_value<T: Encoder>(
let value = encode_list(array, pg_field)?;
encoder.encode_field(&value, pg_field)?
}
DataType::Struct(_) => {
let fields = match type_.kind() {
postgres_types::Kind::Composite(fields) => fields,
_ => {
return Err(PgWireError::ApiError(ToSqlError::from(format!(
"Failed to unwrap a composite type from type {type_}"
))));
}
};
let value = encode_struct(arr, idx, fields, pg_field)?;
DataType::Struct(arrow_fields) => {
let value = encode_struct(arr, idx, arrow_fields, pg_field)?;
encoder.encode_field(&value, pg_field)?
}
DataType::Dictionary(_, value_type) => {
Expand Down Expand Up @@ -463,7 +476,9 @@ pub fn encode_value<T: Encoder>(
))
})?;

encode_value(encoder, values, idx, pg_field)?
let inner_arrow_field = Field::new(pg_field.name(), *value_type.clone(), true);

encode_value(encoder, values, idx, &inner_arrow_field, pg_field)?
}
_ => {
return Err(PgWireError::ApiError(ToSqlError::from(format!(
Expand Down Expand Up @@ -512,8 +527,9 @@ mod tests {

let mut encoder = MockEncoder::default();

let arrow_field = Field::new("x", DataType::Utf8, true);
let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text);
let result = encode_value(&mut encoder, &dict_arr, 2, &pg_field);
let result = encode_value(&mut encoder, &dict_arr, 2, &arrow_field, &pg_field);

assert!(result.is_ok());

Expand Down
22 changes: 2 additions & 20 deletions arrow-pg/src/list_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,27 +386,9 @@ pub(crate) fn encode_list(arr: Arc<dyn Array>, pg_field: &FieldInfo) -> PgWireRe
}
}
},
DataType::Struct(_) => {
let fields = match type_.kind() {
postgres_types::Kind::Array(struct_type_) => Ok(struct_type_),
_ => Err(format!(
"Expected list type found type {} of kind {:?}",
type_,
type_.kind()
)),
}
.and_then(|struct_type| match struct_type.kind() {
postgres_types::Kind::Composite(fields) => Ok(fields),
_ => Err(format!(
"Failed to unwrap a composite type inside from type {} kind {:?}",
type_,
type_.kind()
)),
})
.map_err(ToSqlError::from)?;

DataType::Struct(arrow_fields) => {
let values: PgWireResult<Vec<_>> = (0..arr.len())
.map(|row| encode_struct(&arr, row, fields, pg_field))
.map(|row| encode_struct(&arr, row, arrow_fields, pg_field))
.map(|x| {
if matches!(format, FieldFormat::Text) {
x.map(|opt| {
Expand Down
11 changes: 9 additions & 2 deletions arrow-pg/src/row_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ impl RowEncoder {
if self.curr_idx == self.rb.num_rows() {
return None;
}

let arrow_schema = self.rb.schema_ref();
let mut encoder = DataRowEncoder::new(self.fields.clone());
for col in 0..self.rb.num_columns() {
let array = self.rb.column(col);
let field = &self.fields[col];
let arrow_field = arrow_schema.field(col);
let pg_field = &self.fields[col];

encode_value(&mut self.row_encoder, array, self.curr_idx, field).unwrap();
if let Err(e) = encode_value(&mut encoder, array, self.curr_idx, arrow_field, pg_field)
{
return Some(Err(e));
};
}
self.curr_idx += 1;
Some(Ok(self.row_encoder.take_row()))
Expand Down
Loading
Loading