@@ -1,5 +1,7 @@
package com.parquet.parquetdataformat.bridge;

import org.opensearch.index.engine.exec.merge.RowIdMapping;

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
@@ -39,5 +41,5 @@ public class RustBridge {


// Native method declarations - these will be implemented in the JNI library
public static native void mergeParquetFilesInRust(List<Path> inputFiles, String outputFile);
public static native RowIdMapping mergeParquetFilesInRust(List<Path> inputFiles, String outputFile);
}
@@ -16,14 +16,12 @@
import org.opensearch.index.engine.exec.DataFormat;
import org.opensearch.index.engine.exec.WriterFileSet;
import org.opensearch.index.engine.exec.merge.MergeResult;
import org.opensearch.index.engine.exec.merge.RowId;
import org.opensearch.index.engine.exec.merge.RowIdMapping;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@@ -53,11 +51,8 @@ public MergeResult mergeParquetFiles(List<WriterFileSet> files, long writerGener
String mergedFileName = getMergedFileName(writerGeneration);

try {
// Merge files in Rust
mergeParquetFilesInRust(filePaths, mergedFilePath);

// Build row ID mapping
Map<RowId, Long> rowIdMapping = new HashMap<>();
// Merge files in Rust and get row ID mappings
RowIdMapping rowIdMapping = mergeParquetFilesInRust(filePaths, mergedFilePath);

WriterFileSet mergedWriterFileSet =
WriterFileSet.builder().directory(Path.of(outputDirectory)).addFile(mergedFileName).writerGeneration(writerGeneration).build();
@@ -67,7 +62,7 @@ public MergeResult mergeParquetFiles(List<WriterFileSet> files, long writerGener
mergedWriterFileSet
);

return new MergeResult(new RowIdMapping(rowIdMapping, mergedFileName), mergedWriterFileSetMap);
return new MergeResult(rowIdMapping, mergedWriterFileSetMap);

} catch (Exception exception) {
logger.error(
140 changes: 114 additions & 26 deletions modules/parquet-data-format/src/main/rust/src/parquet_merge.rs
@@ -1,6 +1,5 @@
use jni::JNIEnv;
use jni::objects::{JClass, JObject, JString};
use jni::sys::jint;
use jni::objects::{JClass, JObject, JString, JValue};
use std::fs::File;
use std::error::Error;
use std::any::Any;
@@ -53,14 +52,21 @@ struct ProcessingStats {
total_batches: usize,
}

// JNI Entry Point
// Row ID mapping for cross-format merge. Declared pub because it appears
// in the signature of the public process_parquet_files function below.
pub struct RowIdMappingData {
old_file_id: String,
old_row_id: i64,
new_row_id: i64,
}

// JNI Entry Point - returns RowIdMapping to Java
#[unsafe(no_mangle)]
pub extern "system" fn Java_com_parquet_parquetdataformat_bridge_RustBridge_mergeParquetFilesInRust(
mut env: JNIEnv,
_class: JClass,
input_files: JObject,
output_file: JString,
) -> jint {
pub extern "system" fn Java_com_parquet_parquetdataformat_bridge_RustBridge_mergeParquetFilesInRust<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
input_files: JObject<'local>,
output_file: JString<'local>,
) -> JObject<'local> {
let result = catch_unwind(|| {
let input_files_vec = convert_java_list_to_vec(&mut env, input_files)
.map_err(|e| format!("Failed to convert Java list: {}", e))?;
@@ -72,31 +78,41 @@ pub extern "system" fn Java_com_parquet_parquetdataformat_bridge_RustBridge_merg

log_info!("Starting merge of {} files to {}", input_files_vec.len(), output_path);

process_parquet_files(&input_files_vec, &output_path)?;
let (mappings, output_file_id) = process_parquet_files(&input_files_vec, &output_path)?;

log_info!("Merge completed successfully");
Ok(())
Ok((mappings, output_file_id))
});

match result {
Ok(Ok(_)) => 0,
Ok(Ok((mappings, output_file_id))) => {
match create_row_id_mapping_object(&mut env, mappings, &output_file_id) {
Ok(obj) => obj,
Err(e) => {
let error_msg = format!("Failed to create RowIdMapping: {}", e);
log_error!("{}", error_msg);
let _ = env.throw_new("java/lang/RuntimeException", &error_msg);
JObject::null()
}
}
}
Ok(Err(e)) => {
let error_msg = format!("Error processing Parquet files: {}", e);
log_error!("{}", error_msg);
let _ = env.throw_new("java/lang/RuntimeException", &error_msg);
-1
JObject::null()
}
Err(e) => {
let error_msg = format!("Rust panic occurred: {:?}", e);
log_error!("{}", error_msg);
let _ = env.throw_new("java/lang/RuntimeException", &error_msg);
-1
JObject::null()
}
}
}

// Main processing function
pub fn process_parquet_files(input_files: &[String], output_path: &str) -> Result<(), Box<dyn Error>> {
// Main processing function - returns row ID mappings
pub fn process_parquet_files(input_files: &[String], output_path: &str) -> Result<(Vec<RowIdMappingData>, String), Box<dyn Error>> {
// Validate input
validate_input(input_files)?;

@@ -107,19 +123,25 @@ pub fn process_parquet_files(input_files: &[String], output_path: &str) -> Resul
// Create writer
let mut writer = create_writer(output_path, schema.clone())?;

// Process files
let stats = process_files(input_files, &schema, &mut writer)?;
// Process files and collect mappings
let (stats, mappings) = process_files(input_files, &schema, &mut writer)?;

// Close writer
writer.close()
.map_err(|e| ParquetMergeError::WriterCreationError(format!("Failed to close writer: {}", e)))?;

log_info!(
"Processing complete: {} files, {} rows, {} batches",
stats.files_processed, stats.total_rows, stats.total_batches
"Processing complete: {} files, {} rows, {} batches, {} mappings",
stats.files_processed, stats.total_rows, stats.total_batches, mappings.len()
);

Ok(())
let output_file_id = std::path::Path::new(output_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(output_path)
.to_string();

Ok((mappings, output_file_id))
}

// Validation functions
@@ -165,22 +187,29 @@ fn create_writer(output_path: &str, schema: SchemaRef) -> Result<ArrowWriter<Rat
.map_err(|e| ParquetMergeError::WriterCreationError(format!("Failed to create writer: {}", e)).into())
}

// File processing
// File processing - collects row ID mappings
fn process_files(
input_files: &[String],
schema: &SchemaRef,
writer: &mut ArrowWriter<RateLimitedWriter<File>>,
) -> Result<ProcessingStats, Box<dyn Error>> {
) -> Result<(ProcessingStats, Vec<RowIdMappingData>), Box<dyn Error>> {
let mut current_row_id: i64 = 0;
let mut stats = ProcessingStats {
files_processed: 0,
total_rows: 0,
total_batches: 0,
};
let mut mappings = Vec::new();

for path in input_files {
log_info!("Processing file: {}", path);

let old_file_id = std::path::Path::new(path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(path)
.to_string();

let file = File::open(path)
.map_err(|e| ParquetMergeError::InvalidFile(format!("{}: {}", path, e)))?;

@@ -192,6 +221,7 @@ fn process_files(

let mut file_rows = 0;
let mut file_batches = 0;
let file_start_row_id = current_row_id;

for batch_result in reader {
let original_batch = batch_result
Expand All @@ -209,14 +239,23 @@ fn process_files(
file_batches += 1;
}

// Create mappings for this file
for old_row_id in 0..file_rows as i64 {
[Review comment - Contributor] Here we create a RowIdMappingData struct for every row: for a merge of files totaling N rows, that is 3N Java heap allocations triggered from Rust through JNI, plus 4N JNI calls from create_row_id_mapping_object. Since the mapping is just new_row_id = file_start_offset + old_row_id, we could instead pass only per-file segment metadata (file ID, start offset, row count) across JNI and materialize the full mapping in Java, as in the sketch below.

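A minimal sketch of that suggestion in Java, attached to the review thread. FileSegment, RowIdMappingMaterializer, and materialize are hypothetical names used only for illustration; the RowId(long, String) constructor shape is inferred from the (JLjava/lang/String;)V JNI signature used later in this diff.

import org.opensearch.index.engine.exec.merge.RowId;
import org.opensearch.index.engine.exec.merge.RowIdMapping;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical carrier for per-file segment metadata: O(#files) of these
// cross the JNI boundary instead of O(#rows) RowId/Long objects.
record FileSegment(String fileId, long startOffset, long rowCount) {}

final class RowIdMappingMaterializer {

    // Expands the compact segment metadata into the full Map<RowId, Long>
    // entirely on the Java side, with no per-row JNI calls.
    static RowIdMapping materialize(List<FileSegment> segments, String mergedFileId) {
        Map<RowId, Long> mapping = new HashMap<>();
        for (FileSegment segment : segments) {
            for (long oldRowId = 0; oldRowId < segment.rowCount(); oldRowId++) {
                // new_row_id = file_start_offset + old_row_id
                mapping.put(new RowId(oldRowId, segment.fileId()),
                        segment.startOffset() + oldRowId);
            }
        }
        return new RowIdMapping(mapping, mergedFileId);
    }
}
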
mappings.push(RowIdMappingData {
old_file_id: old_file_id.clone(),
old_row_id,
new_row_id: file_start_row_id + old_row_id,
});
}

stats.files_processed += 1;
stats.total_rows += file_rows;
stats.total_batches += file_batches;

log_info!("File processed: {} rows, {} batches", file_rows, file_batches);
}

Ok(stats)
Ok((stats, mappings))
}

// Row ID update logic
Expand Down Expand Up @@ -267,12 +306,61 @@ fn convert_java_list_to_vec(env: &mut JNIEnv, list: JObject) -> Result<Vec<Strin
Ok(result)
}

fn catch_unwind<F: FnOnce() -> Result<(), Box<dyn Error>>>(
fn catch_unwind<F: FnOnce() -> Result<(Vec<RowIdMappingData>, String), Box<dyn Error>>>(
f: F
) -> Result<Result<(), Box<dyn Error>>, Box<dyn Any + Send>> {
) -> Result<Result<(Vec<RowIdMappingData>, String), Box<dyn Error>>, Box<dyn Any + Send>> {
std::panic::catch_unwind(AssertUnwindSafe(f))
}

// Create Java RowIdMapping object
fn create_row_id_mapping_object<'local>(
env: &mut JNIEnv<'local>,
mappings: Vec<RowIdMappingData>,
output_file_id: &str,
) -> Result<JObject<'local>, Box<dyn Error>> {
// Create HashMap<RowId, Long>
let hash_map = env.new_object("java/util/HashMap", "()V", &[])?;

for mapping in mappings {
// Create the java.lang.String for the old file ID up front: jni 0.21
// methods take &mut self, so nesting new_string inside the new_object
// argument list would borrow the JNIEnv mutably twice
let old_file_id_jstr: JObject = env.new_string(&mapping.old_file_id)?.into();

// Create RowId object
let row_id_obj = env.new_object(
"org/opensearch/index/engine/exec/merge/RowId",
"(JLjava/lang/String;)V",
&[
JValue::Long(mapping.old_row_id),
JValue::Object(&old_file_id_jstr),
],
)?;

// Create Long object for new row ID
let new_row_id_obj = env.new_object(
"java/lang/Long",
"(J)V",
&[JValue::Long(mapping.new_row_id)],
)?;

// Put into HashMap
env.call_method(
&hash_map,
"put",
"(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
&[JValue::Object(&row_id_obj), JValue::Object(&new_row_id_obj)],
)?;
}

// Create RowIdMapping object (again building the String first to avoid
// a double mutable borrow of the JNIEnv)
let file_id_jstr: JObject = env.new_string(output_file_id)?.into();
let row_id_mapping = env.new_object(
"org/opensearch/index/engine/exec/merge/RowIdMapping",
"(Ljava/util/Map;Ljava/lang/String;)V",
&[
JValue::Object(&hash_map),
JValue::Object(&file_id_jstr),
],
)?;

Ok(row_id_mapping)
}


// Close function
// #[no_mangle]
@@ -8,20 +8,22 @@

package org.opensearch.index.engine.exec.merge;

import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public class RowIdMapping {
public final class RowIdMapping {

Map<RowId, Long> mapping;
private final Map<RowId, Long> mapping;
private final String fileId;

public RowIdMapping(Map<RowId, Long> mapping, String fileId) {
this.mapping = mapping;
this.fileId = fileId;
this.mapping = Collections.unmodifiableMap(Objects.requireNonNull(mapping, "mapping cannot be null"));
this.fileId = Objects.requireNonNull(fileId, "fileId cannot be null");
}

public long getNewRowId(RowId oldRowId) {
return mapping.get(oldRowId);
return mapping.getOrDefault(oldRowId, -1L);
}
public String getFileId() {
return fileId;
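For reference, a short usage sketch of the hardened RowIdMapping above. File names and row IDs are illustrative only; the -1L sentinel for unmapped rows is the behavior introduced in this diff, and the example assumes RowId implements equals/hashCode, which the HashMap usage in this PR already requires.

import org.opensearch.index.engine.exec.merge.RowId;
import org.opensearch.index.engine.exec.merge.RowIdMapping;

import java.util.Map;

final class RowIdMappingUsageExample {
    public static void main(String[] args) {
        // One surviving row from an old segment, remapped into the merged file.
        RowIdMapping mapping = new RowIdMapping(
                Map.of(new RowId(0L, "old_segment.parquet"), 0L),
                "merged_gen1.parquet");

        // Present key: returns the remapped row ID (0 here).
        long hit = mapping.getNewRowId(new RowId(0L, "old_segment.parquet"));

        // Absent key: returns the -1L sentinel instead of unboxing null,
        // which previously threw a NullPointerException.
        long miss = mapping.getNewRowId(new RowId(42L, "old_segment.parquet"));

        System.out.println("hit=" + hit + " miss=" + miss);
    }
}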