diff --git a/src/main/java/org/duckdb/DuckDBAppender.java b/src/main/java/org/duckdb/DuckDBAppender.java index 96560b562..251622950 100644 --- a/src/main/java/org/duckdb/DuckDBAppender.java +++ b/src/main/java/org/duckdb/DuckDBAppender.java @@ -82,15 +82,12 @@ public class DuckDBAppender implements AutoCloseable { private final Lock appenderRefLock = new ReentrantLock(); private final ByteBuffer chunkRef; - private final Column[] columns; + private final List columns; private long rowIdx = 0; - private int colIdx = 0; - private int structFieldIdx = 0; - private int unionFieldIdx = 0; - private boolean appendingRow = false; - private boolean appendingStruct = false; + private Column currentColumn = null; + private Column prevColumn = null; private boolean writeInlinedStrings = true; @@ -105,12 +102,12 @@ public class DuckDBAppender implements AutoCloseable { ByteBuffer appenderRef = null; ByteBuffer[] colTypes = null; ByteBuffer chunkRef = null; - Column[] vectors = null; + List cols = null; try { appenderRef = createAppender(conn, catalog, schema, table); colTypes = readTableTypes(appenderRef); chunkRef = createChunk(colTypes); - vectors = createVectors(chunkRef, colTypes); + cols = createTopLevelColumns(chunkRef, colTypes); } catch (Exception e) { if (null != chunkRef) { duckdb_destroy_data_chunk(chunkRef); @@ -130,78 +127,92 @@ public class DuckDBAppender implements AutoCloseable { this.appenderRef = appenderRef; this.chunkRef = chunkRef; - this.columns = vectors; + this.columns = cols; } public DuckDBAppender beginRow() throws SQLException { checkOpen(); - checkAppendingRow(false); - checkAppendingStruct(false); - if (0 != colIdx) { - throw new SQLException(createErrMsg("'endRow' must be called before adding next row")); + if (!readyForANewRowInvariant()) { + throw new SQLException(createErrMsg("'endRow' must be called before calling 'beginRow' again")); } - this.appendingRow = true; + if (null == columns || 0 == columns.size()) { + throw new SQLException(createErrMsg("no columns found to append to")); + } + this.currentColumn = columns.get(0); return this; } public DuckDBAppender endRow() throws SQLException { checkOpen(); - checkAppendingRow(true); - checkAppendingStruct(false); - - if (columns.length != colIdx) { - throw new SQLException(createErrMsg("'endRow' can be called only after adding all columns, expected: " + - columns.length + ", actual: " + colIdx)); + if (!rowCompletedInvariant()) { + Column topCol = currentTopLevelColumn(); + if (null != topCol) { + throw new SQLException( + createErrMsg("all columns must be appended to before calling 'endRow', expected columns count: " + + columns.size() + ", actual: " + (topCol.idx + 1))); + } else { + throw new SQLException(createErrMsg( + "calls to 'beginRow' and 'endRow' must be paired and cannot be interleaved with other 'begin*' and 'end*' calls")); + } } rowIdx++; - this.appendingRow = false; + Column prev = prevColumn; + this.prevColumn = null; if (rowIdx >= maxRows) { try { flush(); } catch (SQLException e) { - this.appendingRow = true; + this.prevColumn = prev; rowIdx--; throw e; } } - colIdx = 0; return this; } public DuckDBAppender beginStruct() throws SQLException { checkOpen(); + if (!rowBegunInvariant()) { + throw new SQLException(createErrMsg("'beginRow' must be called before calling 'beginStruct'")); + } checkCurrentColumnType(DUCKDB_TYPE_STRUCT); - checkAppendingStruct(false); - this.appendingStruct = true; + // if (structBegunInvariant()) { + // throw new SQLException(createErrMsg("'endStruct' must be called before calling 'beginStruct' + // again")); + // } + if (0 == currentColumn.children.size()) { + throw new SQLException(createErrMsg("invalid empty struct")); + } + this.currentColumn = currentColumn.children.get(0); return this; } public DuckDBAppender endStruct() throws SQLException { checkOpen(); - checkAppendingStruct(true); - Column structCol = currentTopLevelColumn(); - if (structCol.children.size() != structFieldIdx) { - throw new SQLException( - createErrMsg("'endStruct' can be called only after adding all struct fields, expected: " + - structCol.children.size() + ", actual: " + structFieldIdx)); + if (!structCompletedInvariant()) { + if (structBegunInvariant()) { + throw new SQLException(createErrMsg( + "all struct fields must be appended to before calling 'endStruct', expected fields count: " + + currentColumn.parent.children.size() + ", actual: " + (currentColumn.idx + 1))); + } + throw new SQLException(createErrMsg("all struct fields must be appended to before calling 'endStruct'")); } - this.structFieldIdx = 0; - this.appendingStruct = false; - incrementColOrStructFieldIdx(); + this.prevColumn = this.prevColumn.parent; return this; } public DuckDBAppender beginUnion(String tag) throws SQLException { checkOpen(); + if (!rowBegunInvariant()) { + throw new SQLException(createErrMsg("'beginRow' must be called before calling 'beginUnion'")); + } checkCurrentColumnType(DUCKDB_TYPE_UNION); - checkAppendingUnion(false); - Column structCol = currentTopLevelColumn(); int fieldWithTag = 0; - for (int i = 1; i < structCol.children.size(); i++) { - Column childCol = structCol.children.get(i); + for (int i = 1; i < currentColumn.children.size(); i++) { + Column childCol = currentColumn.children.get(i); if (childCol.structFieldName.equals(tag)) { fieldWithTag = i; } @@ -210,8 +221,9 @@ public DuckDBAppender beginUnion(String tag) throws SQLException { throw new SQLException(createErrMsg("specified union field not found, value: '" + tag + "'")); } - this.appendingStruct = true; // set tag + Column structCol = currentColumn; + this.currentColumn = currentColumn.children.get(0); append((byte) (fieldWithTag - 1)); // set other fields to NULL for (int i = 1; i < structCol.children.size(); i++) { @@ -221,24 +233,24 @@ public DuckDBAppender beginUnion(String tag) throws SQLException { Column childCol = structCol.children.get(i); childCol.setNull(rowIdx); } - this.unionFieldIdx = fieldWithTag; + this.currentColumn = structCol.children.get(fieldWithTag); return this; } public DuckDBAppender endUnion() throws SQLException { checkOpen(); - checkAppendingUnion(true); - this.structFieldIdx = 0; - this.unionFieldIdx = 0; - this.appendingStruct = false; - incrementColOrStructFieldIdx(); + if (!unionCompletedInvariant()) { + throw new SQLException(createErrMsg("union column must be appended to before calling 'endUnion'")); + } + this.prevColumn = this.prevColumn.parent; return this; } public long flush() throws SQLException { checkOpen(); - checkAppendingRow(false); - checkAppendingStruct(false); + if (!readyForANewRowInvariant()) { + throw new SQLException(createErrMsg("'endRow' must be called before calling 'flush'")); + } if (0 == rowIdx) { return rowIdx; @@ -1153,11 +1165,11 @@ public DuckDBAppender appendNull() throws SQLException { } public DuckDBAppender appendDefault() throws SQLException { - currentColumn(); + Column col = currentColumn(); appenderRefLock.lock(); try { checkOpen(); - duckdb_append_default_to_chunk(appenderRef, chunkRef, colIdx, rowIdx); + duckdb_append_default_to_chunk(appenderRef, chunkRef, col.idx, rowIdx); } finally { appenderRefLock.unlock(); } @@ -1189,76 +1201,46 @@ private void checkOpen() throws SQLException { } } - private void checkAppendingRow(boolean expected) throws SQLException { - if (appendingRow != expected) { - throw new SQLException(createErrMsg("'beginRow' and 'endRow' calls must be paired")); - } - } - - private void checkAppendingStruct(boolean expected) throws SQLException { - if (appendingStruct != expected) { - throw new SQLException(createErrMsg( - "'beginStruct' and 'endStruct' calls must be paired and cannot be interleaved with 'beginRow' and 'endRow'")); - } - } - - private void checkAppendingUnion(boolean expected) throws SQLException { - if (appendingStruct != expected) { - throw new SQLException(createErrMsg( - "'beginUnion' and 'endUnion' calls must be paired and cannot be interleaved with 'beginRow' and 'endRow'")); + private Column nextColumn(Column curCol) { + if (null == curCol) { + return null; } - if (appendingStruct && unionFieldIdx == 0) { - throw new SQLException(createErrMsg("invalid zero union field index")); + final List cols; + if (null == curCol.parent) { + cols = columns; + } else { + cols = curCol.parent.children; } - if (!appendingStruct && unionFieldIdx != 0) { - throw new SQLException(createErrMsg("invalid non-zero union field index")); + int nextColIdx = curCol.idx + 1; + if (nextColIdx < cols.size()) { + return cols.get(nextColIdx); + } else { + if (null != curCol.parent) { + curCol = curCol.parent; + // recurse up the tree + return nextColumn(curCol); + } else { + return null; + } } } private void incrementColOrStructFieldIdx() throws SQLException { - if (appendingStruct) { - structFieldIdx++; - return; - } - if (appendingRow) { - colIdx++; - return; - } - throw new SQLException(createErrMsg("'beginRow' must be called before calling `append`")); - } - - private Column currentTopLevelColumn() throws SQLException { - checkOpen(); - - if (colIdx >= columns.length) { - throw new SQLException( - createErrMsg("invalid columns count, expected: " + columns.length + ", actual: " + (colIdx + 1))); + Column col = currentColumn(); + this.prevColumn = currentColumn; + if (unionBegunInvariant()) { + this.currentColumn = nextColumn(col.parent); + } else { + this.currentColumn = nextColumn(col); } - - return columns[colIdx]; } private Column currentColumn() throws SQLException { - Column col = currentTopLevelColumn(); - - if (!appendingStruct || (col.colType != DUCKDB_TYPE_STRUCT && col.colType != DUCKDB_TYPE_UNION)) { - return col; + if (null == currentColumn) { + throw new SQLException(createErrMsg("current column not found, columns count: " + columns.size())); } - if (unionFieldIdx > 0) { - if (unionFieldIdx > col.children.size()) { - throw new SQLException(createErrMsg("invalid union fields count, expected: " + columns.length + - ", actual: " + (structFieldIdx + 1))); - } - return col.children.get(unionFieldIdx); - } - - if (structFieldIdx >= col.children.size()) { - throw new SQLException(createErrMsg("invalid struct fields count, expected: " + columns.length + - ", actual: " + (structFieldIdx + 1))); - } - - return col.children.get(structFieldIdx); + return currentColumn; } private Column currentArrayInnerColumn(CAPIType ctype) throws SQLException { @@ -1407,6 +1389,17 @@ private Column currentColumnWithRowPos(CAPIType[] ctypes) throws SQLException { return col; } + private Column currentTopLevelColumn() { + if (null == currentColumn) { + return null; + } + Column col = currentColumn; + while (null != col.parent) { + col = col.parent; + } + return col; + } + private void checkDecimalPrecision(BigDecimal value, CAPIType decimalInternalType, int maxPrecision) throws SQLException { if (value.precision() > maxPrecision) { @@ -1439,6 +1432,37 @@ private DuckDBAppender appendStringOrBlobInternal(CAPIType ctype, byte[] bytes) return this; } + private boolean rowBegunInvariant() { + return null != currentColumn; + } + + private boolean rowCompletedInvariant() { + return null == currentColumn && null != prevColumn && prevColumn.idx == columns.size() - 1; + } + + private boolean structBegunInvariant() { + return null != currentColumn && null != currentColumn.parent && + currentColumn.parent.colType == DUCKDB_TYPE_STRUCT; + } + + private boolean structCompletedInvariant() { + return null != prevColumn && null != prevColumn.parent && prevColumn.parent.colType == DUCKDB_TYPE_STRUCT && + prevColumn.idx == prevColumn.parent.children.size() - 1; + } + + private boolean unionBegunInvariant() { + return null != currentColumn && null != currentColumn.parent && + currentColumn.parent.colType == DUCKDB_TYPE_UNION; + } + + private boolean unionCompletedInvariant() { + return null != prevColumn && null != prevColumn.parent && prevColumn.parent.colType == DUCKDB_TYPE_UNION; + } + + private boolean readyForANewRowInvariant() { + return null == currentColumn && null == prevColumn; + } + private static byte[] utf8(String str) { if (null == str) { return null; @@ -1512,7 +1536,7 @@ private static void initVecChildren(Column parent) throws SQLException { case DUCKDB_TYPE_LIST: case DUCKDB_TYPE_MAP: { ByteBuffer vec = duckdb_list_vector_get_child(parent.vectorRef); - Column col = new Column(parent, null, vec); + Column col = new Column(parent, 0, null, vec); parent.children.add(col); break; } @@ -1521,42 +1545,44 @@ private static void initVecChildren(Column parent) throws SQLException { long count = duckdb_struct_type_child_count(parent.colTypeRef); for (int i = 0; i < count; i++) { ByteBuffer vec = duckdb_struct_vector_get_child(parent.vectorRef, i); - Column col = new Column(parent, null, vec, i); + Column col = new Column(parent, i, null, vec, i); parent.children.add(col); } break; } case DUCKDB_TYPE_ARRAY: { ByteBuffer vec = duckdb_array_vector_get_child(parent.vectorRef); - Column col = new Column(parent, null, vec); + Column col = new Column(parent, 0, null, vec); parent.children.add(col); break; } } } - private static Column[] createVectors(ByteBuffer chunkRef, ByteBuffer[] colTypes) throws SQLException { - Column[] vectors = new Column[colTypes.length]; + private static List createTopLevelColumns(ByteBuffer chunkRef, ByteBuffer[] colTypes) throws SQLException { + List columns = new ArrayList<>(colTypes.length); try { for (int i = 0; i < colTypes.length; i++) { ByteBuffer vector = duckdb_data_chunk_get_vector(chunkRef, i); - vectors[i] = new Column(null, colTypes[i], vector); + Column col = new Column(null, i, colTypes[i], vector); + columns.add(col); colTypes[i] = null; } } catch (Exception e) { - for (Column col : vectors) { + for (Column col : columns) { if (null != col) { col.destroy(); } } throw e; } - return vectors; + return columns; } private static class Column { private final Column parent; - private ByteBuffer colTypeRef; + private final int idx; + private /* final */ ByteBuffer colTypeRef; private final CAPIType colType; private final CAPIType decimalInternalType; private final int decimalPrecision; @@ -1569,13 +1595,14 @@ private static class Column { private ByteBuffer validity; private final List children = new ArrayList<>(); - private Column(Column parent, ByteBuffer colTypeRef, ByteBuffer vector) throws SQLException { - this(parent, colTypeRef, vector, -1); + private Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector) throws SQLException { + this(parent, idx, colTypeRef, vector, -1); } - private Column(Column parent, ByteBuffer colTypeRef, ByteBuffer vector, int structFieldIdx) + private Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector, int structFieldIdx) throws SQLException { this.parent = parent; + this.idx = idx; if (null == vector) { throw new SQLException("cannot initialize data chunk vector"); diff --git a/src/test/java/org/duckdb/TestAppender.java b/src/test/java/org/duckdb/TestAppender.java index 94a6b1654..5f09104fc 100644 --- a/src/test/java/org/duckdb/TestAppender.java +++ b/src/test/java/org/duckdb/TestAppender.java @@ -776,7 +776,7 @@ public static void test_appender_decimal_wrong_scale() throws Exception { assertThrows(() -> { try (DuckDBAppender appender = conn.createAppender("decimals")) { - appender.append(1).beginRow().append(new BigDecimal("121.14").setScale(2)); + appender.beginRow().append(1).append(new BigDecimal("121.14").setScale(2)); } }, SQLException.class); @@ -1656,6 +1656,45 @@ public static void test_appender_struct_basic() throws Exception { } } + public static void test_appender_struct_nested() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE tab1 (col1 INTEGER, col2 STRUCT(s1 INTEGER, s2 STRUCT(ns1 INTEGER, ns2 VARCHAR)), col3 VARCHAR)"); + + try (DuckDBAppender appender = conn.createAppender("tab1")) { + appender.beginRow() + .append(42) + .beginStruct() + .append(43) + .beginStruct() + .append(44) + .append("foo") + .endStruct() + .endStruct() + .append("bar") + .endRow() + .flush(); + } + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM tab1 WHERE col1 = 42")) { + assertTrue(rs.next()); + + assertEquals(rs.getInt(1), 42); + DuckDBStruct struct = (DuckDBStruct) rs.getObject(2); + Map map = struct.getMap(); + assertEquals(map.get("s1"), 43); + DuckDBStruct nested = (DuckDBStruct) map.get("s2"); + Map nestedMap = nested.getMap(); + assertEquals(nestedMap.get("ns1"), 44); + assertEquals(nestedMap.get("ns2"), "foo"); + assertEquals(rs.getString(3), "bar"); + + assertFalse(rs.next()); + } + } + } + public static void test_appender_struct_with_array() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { @@ -1921,4 +1960,85 @@ public static void test_appender_union_flush() throws Exception { } } } + + public static void test_appender_union_nested() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (" + + "col1 INTEGER, " + + "col2 STRUCT( " + + " s1 INTEGER, " + + " s2 UNION(" + + " u1 INTEGER," + + " u2 STRUCT(" + + " us1 INTEGER, " + + " us2 INTEGER[2]," + + " us3 VARCHAR" + + " )" + + " )" + + "), " + + "col3 VARCHAR)"); + + try (DuckDBAppender appender = conn.createAppender("tab1")) { + appender.beginRow() + .append(42) + .beginStruct() + .append(43) + .beginUnion("u1") + .append(44) + .endUnion() + .endStruct() + .append("foo") + .endRow() + + .beginRow() + .append(45) + .beginStruct() + .append(46) + .beginUnion("u2") + .beginStruct() + .append(47) + .append(new int[] {48, 49}) + .append("bar") + .endStruct() + .endUnion() + .endStruct() + .append("baz") + .endRow(); + } + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM tab1 ORDER BY col1")) { + assertTrue(rs.next()); + + assertEquals(rs.getInt(1), 42); + DuckDBStruct struct = (DuckDBStruct) rs.getObject(2); + Map map = struct.getMap(); + assertEquals(map.size(), 2); + assertEquals(map.get("s1"), 43); + assertEquals(map.get("s2"), 44); + assertEquals(rs.getString(3), "foo"); + + assertTrue(rs.next()); + + assertEquals(rs.getInt(1), 45); + DuckDBStruct struct1 = (DuckDBStruct) rs.getObject(2); + Map map1 = struct1.getMap(); + assertEquals(map1.size(), 2); + assertEquals(map1.get("s1"), 46); + DuckDBStruct struct2 = (DuckDBStruct) map1.get("s2"); + Map map2 = struct2.getMap(); + assertEquals(map2.size(), 3); + assertEquals(map2.get("us1"), 47); + DuckDBArray arrWrapper = (DuckDBArray) map2.get("us2"); + Object[] arr = (Object[]) arrWrapper.getArray(); + assertEquals(arr.length, 2); + assertEquals(arr[0], 48); + assertEquals(arr[1], 49); + assertEquals(map2.get("us3"), "bar"); + assertEquals(rs.getString(3), "baz"); + + assertFalse(rs.next()); + } + } + } }