Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions extensions/native/circuit/cuda/include/native/poseidon2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,17 @@ template <typename T> struct SimplePoseidonSpecificCols {
template <typename T> struct MultiObserveCols {
T pc;
T final_timestamp_increment;
T state_ptr_register;
T ctx_register;
T input_ptr_register;
T hint_id_register;
T state_ptr;
T ctx_ptr;
T input_ptr;
T init_pos;
T len;
T input_register_1;
T input_register_2;
T input_register_3;
T output_register;
T hint_id;
T ctx[4];
MemoryReadAuxCols<T> read_ctx;
T chunk_ts_count;
T is_first;
T is_last;
T curr_len;
Expand Down
32 changes: 24 additions & 8 deletions extensions/native/circuit/cuda/src/poseidon2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
T inside_row;
T simple;
T multi_observe_row;
T not_hint_multi_observe;

T end_inside_row;
T end_top_level;
Expand Down Expand Up @@ -355,31 +356,46 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
uint32_t very_start_timestamp =
row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
for (uint32_t i = 0; i < 4; ++i) {
for (uint32_t i = 0; i < 3; ++i) {
mem_fill_base(
mem_helper,
very_start_timestamp + i,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base))
);
}
mem_fill_base(
mem_helper,
very_start_timestamp + 3,
specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base))
);
mem_fill_base(
mem_helper,
very_start_timestamp + 4,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[3].base))
);
} else {
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
uint32_t chunk_start =
specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
uint32_t chunk_end =
specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();
uint32_t is_hint =
specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32();
uint32_t ts_per_element = 2 - is_hint;
for (uint32_t j = chunk_start; j < chunk_end; ++j) {
if (!is_hint) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
);
}
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
start_timestamp + (1 - is_hint),
specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base))
);
start_timestamp += 2;
start_timestamp += ts_per_element;
}
if (chunk_end >= CHUNK) {
mem_fill_base(
Expand Down
20 changes: 15 additions & 5 deletions extensions/native/circuit/src/extension/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,29 @@ impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits);
inventory.add_executor_chip(fri_reduced_opening);

inventory.next_air::<NativePoseidon2Air<BabyBear, 1>>()?;
let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
inventory.add_executor_chip(poseidon2);

let hint_air: &HintSpaceProviderAir = inventory.next_air::<HintSpaceProviderAir>()?;
let cpu_range_checker = range_checker
.cpu_chip
.clone()
.expect("VariableRangeCheckerChipGPU is expected to be hybrid with cpu_chip");
let cpu_chip = Arc::new(HintSpaceProviderChip::new(
hint_air.hint_bus,
range_checker.clone(),
cpu_range_checker,
timestamp_max_bits,
));

let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone());
inventory.add_periphery_chip(provider_gpu);

inventory.next_air::<NativePoseidon2Air<BabyBear, 1>>()?;

let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider(
range_checker.clone(),
timestamp_max_bits,
cpu_chip.clone(),
);
inventory.add_executor_chip(poseidon2);

inventory.next_air::<NativeSumcheckAir>()?;
let sumcheck =
NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip);
Expand Down
36 changes: 19 additions & 17 deletions extensions/native/circuit/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,6 @@ where
);
inventory.add_air(fri_reduced_opening);

let verify_batch = NativePoseidon2Air::<_, 1>::new(
exec_bridge,
memory_bridge,
hint_bridge,
VerifyBatchBus::new(inventory.new_bus_idx()),
Poseidon2Config::default(),
);
inventory.add_air(verify_batch);

let hint_space_provider = HintSpaceProviderAir {
hint_bus: hint_bridge.hint_bus(),
lt_air: IsLtSubAir::new(
Expand All @@ -289,6 +280,15 @@ where
};
inventory.add_air(hint_space_provider);

let verify_batch = NativePoseidon2Air::<_, 1>::new(
exec_bridge,
memory_bridge,
hint_bridge,
VerifyBatchBus::new(inventory.new_bus_idx()),
Poseidon2Config::default(),
);
inventory.add_air(verify_batch);

let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge);
inventory.add_air(tower_evaluate);

Expand Down Expand Up @@ -365,13 +365,6 @@ where
FriReducedOpeningChip::new(FriReducedOpeningFiller::new(), mem_helper.clone());
inventory.add_executor_chip(fri_reduced_opening);

inventory.next_air::<NativePoseidon2Air<Val<SC>, 1>>()?;
let poseidon2 = NativePoseidon2Chip::<_, 1>::new(
NativePoseidon2Filler::new(Poseidon2Config::default()),
mem_helper.clone(),
);
inventory.add_executor_chip(poseidon2);

let hint_bus = inventory.airs().system().hint_bridge.hint_bus();
let hint_space_provider = Arc::new(HintSpaceProviderChip::new(
hint_bus,
Expand All @@ -382,8 +375,17 @@ where
inventory.next_air::<HintSpaceProviderAir>()?;
inventory.add_periphery_chip(hint_space_provider.clone());

inventory.next_air::<NativePoseidon2Air<Val<SC>, 1>>()?;

let poseidon2 = NativePoseidon2Chip::<_, 1>::new(
NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()),
mem_helper.clone(),
);
inventory.add_executor_chip(poseidon2);

inventory.next_air::<NativeSumcheckAir>()?;
let tower_verify = NativeSumcheckChip::new(
NativeSumcheckFiller::new(hint_space_provider),
NativeSumcheckFiller::new(hint_space_provider.clone()),
mem_helper.clone(),
);
inventory.add_executor_chip(tower_verify);
Expand Down
Loading