Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sim/simx/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ SRCS += $(SRC_DIR)/execute.cpp $(SRC_DIR)/func_unit.cpp
SRCS += $(SRC_DIR)/cache_sim.cpp $(SRC_DIR)/mem_sim.cpp $(SRC_DIR)/local_mem.cpp $(SRC_DIR)/mem_coalescer.cpp
SRCS += $(SRC_DIR)/dcrs.cpp $(SRC_DIR)/types.cpp

# sparse unit; add -DEXT_SPARSE_ENABLE flag later
SRCS += $(SRC_DIR)/sparse_unit.cpp

# Add V extension sources
ifneq ($(findstring -DEXT_V_ENABLE, $(CONFIGS)),)
Expand All @@ -42,6 +40,11 @@ endif
ifneq ($(findstring -DEXT_TCU_ENABLE, $(CONFIGS)),)
SRCS += $(SRC_DIR)/tensor_unit.cpp
endif
# Add VEGETA extension sources
# vegeta_lsu.cpp and sparse_unit.cpp are appended to SRCS only when the
# string -DEXT_VEGETA_ENABLE occurs somewhere in $(CONFIGS); findstring
# returns an empty result otherwise, so the ifneq body is skipped.
ifneq ($(findstring -DEXT_VEGETA_ENABLE, $(CONFIGS)),)
SRCS += $(SRC_DIR)/vegeta_lsu.cpp
SRCS += $(SRC_DIR)/sparse_unit.cpp
endif

# Debugging
ifdef DEBUG
Expand Down
1 change: 1 addition & 0 deletions sim/simx/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Core::Core(const SimContext& ctx,
#endif
#ifdef EXT_VEGETA_ENABLE
, sparse_unit_(SparseUnit::Create("spu", arch, this))
, vegeta_lsu_(VegetaLsu::Create("vegeta_lsu", this, 1))
#endif
, emulator_(arch, dcrs, this)
, ibuffers_(arch.num_warps(), IBUF_SIZE)
Expand Down
6 changes: 6 additions & 0 deletions sim/simx/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#endif
#ifdef EXT_VEGETA_ENABLE
#include "sparse_unit.h"
#include "vegeta_lsu.h"
#endif

#include "dispatcher.h"
Expand Down Expand Up @@ -184,6 +185,10 @@ class Core : public SimObject<Core> {
SparseUnit::Ptr& sparse_unit() {
return sparse_unit_;
}

// Accessor for the VegetaLsu handle held in the vegeta_lsu_ member.
// Returns a non-const reference to the stored VegetaLsu::Ptr.
VegetaLsu::Ptr& vegeta_lsu() {
return vegeta_lsu_;
}
#endif

auto& trace_pool() {
Expand Down Expand Up @@ -217,6 +222,7 @@ class Core : public SimObject<Core> {

#ifdef EXT_VEGETA_ENABLE
SparseUnit::Ptr sparse_unit_;
VegetaLsu::Ptr vegeta_lsu_;
#endif

Emulator emulator_;
Expand Down
5 changes: 0 additions & 5 deletions sim/simx/execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_T: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1570,7 +1569,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_U: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1585,7 +1583,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_V: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1599,7 +1596,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
case VegetaTcuType::TILE_GEMM_R: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
Expand All @@ -1614,7 +1610,6 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
auto tpuArgs = std::get<IntrVegetaTcuArgs>(instrArgs);
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Get metadata from integer registers a0-a7 (x10-x17) for sparse fragA
// These contain metadata values loaded by mma_sync into a0-a7
Expand Down
145 changes: 82 additions & 63 deletions sim/simx/sparse_unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "sparse_cfg.h"
#include <rvfloats.h>
#include "core.h"
#include "vegeta_lsu.h"
#include <cstring>

using namespace vortex;
Expand Down Expand Up @@ -772,90 +773,97 @@ class SparseUnit::Impl {
uint32_t tile_reg_idx = vd;
assert(tile_reg_idx < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[tile_reg_idx];
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32
base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads

// Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
// Use VegetaLsu for bulk tile load (1KB)
constexpr uint32_t T_TILE_SIZE = TILE_DIM * TILE_DIM * sizeof(float);
float tile_buffer[TILE_DIM * TILE_DIM];
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::T_TILE,
tile_reg_idx, wid, tid, tile_buffer);

// Copy from linear buffer to 2D tile register
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
uint32_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});

// Interpret as float and store in tile register
float value;
std::memcpy(&value, &mem_data, ELEMENT_SIZE);
tile_reg[row][col] = value;
tile_reg[row][col] = tile_buffer[row * TILE_DIM + col];
}
}

DP(2, "TILE_LOAD_T: wid=" << wid << ", tid=" << tid
// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < TILE_DIM * TILE_DIM; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_LOAD_T (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", tile_reg_idx=" << tile_reg_idx << ", base_addr=0x" << std::hex << base_addr << std::dec);
break;
}
case VegetaLsuType::TILE_LOAD_U: {
// tile_load_u: DestReg contains ureg index, map to tile registers
// ureg 0 -> tile reg 0, 1
// ureg 0 -> tile reg 0, 1 (2KB total = 2 T-tiles)
std::vector<uint32_t> target_tregs = map_ureg_to_treg(vd);
base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype);

uint64_t current_addr = base_addr;
for (uint32_t treg_idx : target_tregs) {
// Use VegetaLsu for bulk U-tile load (2KB)
constexpr uint32_t T_TILE_ELEMENTS = TILE_DIM * TILE_DIM;
float tile_buffer[T_TILE_ELEMENTS * 2]; // 2 T-tiles for U-reg
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::U_TILE,
vd, wid, tid, tile_buffer);

// Copy from linear buffer to 2D tile registers
for (uint32_t t = 0; t < target_tregs.size(); ++t) {
uint32_t treg_idx = target_tregs[t];
assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[treg_idx];

// Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
uint32_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});

float value;
std::memcpy(&value, &mem_data, ELEMENT_SIZE);
tile_reg[row][col] = value;
tile_reg[row][col] = tile_buffer[t * T_TILE_ELEMENTS + row * TILE_DIM + col];
}
}
current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB)
}

DP(2, "TILE_LOAD_U: wid=" << wid << ", tid=" << tid
// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < T_TILE_ELEMENTS * 2; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_LOAD_U (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", ureg_idx=" << vd << ", target_tregs=["
<< target_tregs[0] << ", " << target_tregs[1] << "], base_addr=0x" << std::hex << base_addr << std::dec);
break;
}
case VegetaLsuType::TILE_LOAD_V: {
// tile_load_v: DestReg contains vreg index, map to tile registers
// vreg 0 -> tile reg 0, 1, 2, 3
// vreg 0 -> tile reg 0, 1, 2, 3 (4KB total = 4 T-tiles)
std::vector<uint32_t> target_tregs = map_vreg_to_treg(vd);
base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype);

uint64_t current_addr = base_addr;
for (uint32_t treg_idx : target_tregs) {
// Use VegetaLsu for bulk V-tile load (4KB)
constexpr uint32_t T_TILE_ELEMENTS = TILE_DIM * TILE_DIM;
float tile_buffer[T_TILE_ELEMENTS * 4]; // 4 T-tiles for V-reg
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::V_TILE,
vd, wid, tid, tile_buffer);

// Copy from linear buffer to 2D tile registers
for (uint32_t t = 0; t < target_tregs.size(); ++t) {
uint32_t treg_idx = target_tregs[t];
assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[treg_idx];

// Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
uint32_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});

float value;
std::memcpy(&value, &mem_data, ELEMENT_SIZE);
tile_reg[row][col] = value;
tile_reg[row][col] = tile_buffer[t * T_TILE_ELEMENTS + row * TILE_DIM + col];
}
}
current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB)
}

DP(2, "TILE_LOAD_V: wid=" << wid << ", tid=" << tid
// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < T_TILE_ELEMENTS * 4; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_LOAD_V (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", vreg_idx=" << vd << ", target_tregs=["
<< target_tregs[0] << ", " << target_tregs[1] << ", "
<< target_tregs[2] << ", " << target_tregs[3] << "], base_addr=0x" << std::hex << base_addr << std::dec);
Expand All @@ -867,22 +875,28 @@ class SparseUnit::Impl {
assert(meta_reg_idx < metadata_reg_file_.size() && "Metadata register index out of bounds");
auto &metadata_reg = metadata_reg_file_[meta_reg_idx];

// Load metadata from memory: 16 rows x 16 columns = 256 uint4 elements = 128 bytes
// Use VegetaLsu for bulk M-tile load (128 bytes)
constexpr uint32_t M_TILE_SIZE = 128;
uint8_t meta_buffer[M_TILE_SIZE];
core_->vegeta_lsu()->load_tile(base_addr, VegetaLsu::TileType::M_TILE,
meta_reg_idx, wid, tid, meta_buffer);

// Parse nibbles from linear buffer into metadata register
// Each byte stores two uint4 values: upper nibble for col N, lower nibble for col N+1
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; col += 2) {
uint64_t mem_addr = base_addr + (row * (TILE_DIM / 2) + col / 2);
uint8_t mem_data = 0;
core_->dcache_read(&mem_data, mem_addr, 1);
trace_data->mem_addrs.at(tid).push_back({mem_addr, 1});

// Upper nibble for col N, lower nibble for col N+1
metadata_reg[row][col] = (mem_data >> 4) & 0x0F;
metadata_reg[row][col + 1] = mem_data & 0x0F;
uint8_t byte = meta_buffer[row * (TILE_DIM / 2) + col / 2];
metadata_reg[row][col] = (byte >> 4) & 0x0F;
metadata_reg[row][col + 1] = byte & 0x0F;
}
}

// Record trace for all bytes
for (uint32_t i = 0; i < M_TILE_SIZE; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i, 1});
}

DP(2, "TILE_LOAD_M: wid=" << wid << ", tid=" << tid
DP(2, "TILE_LOAD_M (via VegetaLsu): wid=" << wid << ", tid=" << tid
<< ", metadata_reg_idx=" << meta_reg_idx << ", base_addr=0x" << std::hex << base_addr << std::dec);
break;
}
Expand Down Expand Up @@ -911,21 +925,26 @@ class SparseUnit::Impl {
assert(vs3 < tile_reg_file_.size() && "Tile register index out of bounds");
auto &tile_reg = tile_reg_file_[vs3];
constexpr uint32_t TILE_DIM = 16;
constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32

// Store tile to memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes
// Copy 2D tile register to linear buffer for VegetaLsu
float tile_buffer[TILE_DIM * TILE_DIM];
for (uint32_t row = 0; row < TILE_DIM; ++row) {
for (uint32_t col = 0; col < TILE_DIM; ++col) {
uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE;
float value = tile_reg[row][col];
uint32_t mem_data = 0;
std::memcpy(&mem_data, &value, ELEMENT_SIZE);
core_->dcache_write(&mem_data, mem_addr, ELEMENT_SIZE);
trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE});
tile_buffer[row * TILE_DIM + col] = tile_reg[row][col];
}
}

// Use VegetaLsu for bulk tile store (1KB)
core_->vegeta_lsu()->store_tile(base_addr, VegetaLsu::TileType::T_TILE,
vs3, wid, tid, tile_buffer);

// Record trace for all elements
constexpr uint32_t ELEMENT_SIZE = sizeof(float);
for (uint32_t i = 0; i < TILE_DIM * TILE_DIM; ++i) {
trace_data->mem_addrs.at(tid).push_back({base_addr + i * ELEMENT_SIZE, ELEMENT_SIZE});
}

DP(2, "TILE_STORE: wid=" << wid << ", tid=" << tid << ", vs3=" << vs3
DP(2, "TILE_STORE (via VegetaLsu): wid=" << wid << ", tid=" << tid << ", vs3=" << vs3
<< ", base_addr=0x" << std::hex << base_addr << std::dec);
#else
std::abort(); // EXT_VEGETA_ENABLE required for store operations
Expand Down
Loading