Add AMD/ROCm support for SSD TBE inference #5559
Open
goldcoderZ wants to merge 3 commits into pytorch:main from
Conversation
Summary: Add streaming delta update support to the SSD TBE inference operator (`SSDIntNBitTableBatchedEmbeddingBags`), closing the gap with EmbeddingDB's streaming delta table feature.

Two new public methods:
1. `streaming_update(indices, weights)` — writes updated embedding rows to RocksDB and invalidates the corresponding HBM cache entries so that subsequent `prefetch()` calls reload from SSD. Uses vectorized set-associative cache invalidation.
2. `load_snapshot(ssd_storage_directory, ...)` — flushes the current RocksDB, opens a new instance at a fresh directory, and fully invalidates the HBM cache. Enables zero-downtime snapshot transitions.

Also adds AMD/ROCm awareness:
- `IS_ROCM` detection flag
- Constructor warns when running on ROCm (the streaming API works, but the prefetch/forward C++ kernels are NVIDIA-only due to the ASSOC=32 vs kWarpSize=64 mismatch)
- Docstrings document AMD support status per method

These are the minimal primitives needed for online-training models with streaming embedding updates (e.g., 45-minute publish intervals).

Differential Revision: D98827795
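To make the set-associative invalidation concrete, here is a simplified Python model of what `streaming_update` does on the cache side. This is an illustrative sketch, not the FBGEMM implementation: the helper names `cache_set` and `invalidate_rows`, and the modulo placement, are assumptions. The idea it demonstrates is real, though: each updated embedding index maps to one cache set of `ASSOC` ways, and any way holding that index is cleared so the next `prefetch()` misses and reloads the fresh row from SSD.

```python
# Hedged sketch of set-associative cache invalidation, as performed by
# streaming_update(). Names and set-mapping are hypothetical.
ASSOC = 32  # NVIDIA warp width; the companion diff makes this 64 on ROCm


def cache_set(index: int, num_sets: int) -> int:
    """Map an embedding index to its cache set (simple modulo placement;
    the real cache may use a hash)."""
    return index % num_sets


def invalidate_rows(lxu_cache_state, updated_indices, num_sets):
    """Clear any cache slot holding an updated index so the next prefetch
    reloads the fresh row from SSD.

    lxu_cache_state[s][w] holds the embedding index cached in set s,
    way w, with -1 meaning empty.
    """
    for idx in set(updated_indices):
        s = cache_set(idx, num_sets)
        for w in range(ASSOC):
            if lxu_cache_state[s][w] == idx:
                lxu_cache_state[s][w] = -1  # evict: forces SSD reload
    return lxu_cache_state
```

The C++ kernel does this scan with one warp per set (one lane per way), which is why the associativity must match the warp width.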
Summary: Add a serving-friendly wrapper module that enables Video Retrieval HSTU models (VDD, New2, IFU) to use FBGEMM SSD TBE with TurboSSD v2 features instead of EmbeddingDB.

`TurboSSDInferenceModule` provides:
1. Single-call forward: auto-prefetch + lookup (vs. EmbeddingDB's synchronous SSD reads). Uses a GPU HBM cache with LRU eviction.
2. `streaming_update()` + `load_snapshot()`: streaming delta updates and zero-downtime snapshot transitions, matching EmbeddingDB's streaming delta table feature.
3. Factory method `from_embedding_specs()`: auto-sizes the HBM cache based on a target hit rate and an optional HBM budget. Useful for capacity planning on H100 (96 GB) and MI350X (288 GB).
4. `estimate_hbm_gb()`: static method for HBM capacity planning without instantiating the module.

This is the integration layer between SSD TBE and the TGIF serving framework. The module can replace `SSDEmbeddingDBSplitTableBatchedEmbeddingBagsCodegen` in the model graph via DIShardingPass.

Differential Revision: D98843434
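As a rough illustration of the capacity planning that `estimate_hbm_gb()` enables, the sketch below computes HBM cache size from row count, row width, and a cache load factor. The signature and accounting are assumptions: the actual method may parameterize on target hit rate and include per-row cache metadata.

```python
# Hypothetical HBM-cache sizing helper, in the spirit of estimate_hbm_gb().
# The real method's signature and accounting may differ.
def estimate_hbm_gb(
    num_rows: int,
    embedding_dim: int,
    row_dtype_bytes: int,
    cache_load_factor: float,
) -> float:
    """Estimate the HBM needed to keep cache_load_factor of all embedding
    rows resident. Real sizing would also account for cache-state metadata
    and map a target hit rate onto a load factor via the access skew."""
    cached_rows = int(num_rows * cache_load_factor)
    row_bytes = embedding_dim * row_dtype_bytes
    return cached_rows * row_bytes / (1024 ** 3)
```

For example, caching 10% of 1B int8 rows of dim 128 needs about 11.9 GB, comfortably within an H100's 96 GB.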
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2524

Enable the SSD TBE inference operator on AMD GPUs (ROCm/HIP) by fixing the warp-size / cache-associativity mismatch between NVIDIA (32) and AMD (64).

Key changes:
- **common.py**: `ASSOC` is now platform-aware (32 on CUDA, 64 on ROCm) to match the hardware warp/wavefront width. All Python-side tensor allocations (`lxu_cache_state`, `lru_state`, `lxu_cache_weights`, etc.) automatically use the correct `ASSOC` value.
- **bitonic_sort.cuh**: Added a 6th merge stage (L=32) under `USE_ROCM` to fully sort 64 elements on AMD's 64-wide wavefronts. Without it, only 32 of the 64 lanes were sorted, causing incorrect cache eviction ordering.
- **lxu_cache.cu**: Removed the `assert(false)` in `lxu_cache_lookup_kernel`'s ROCm path. The existing `__ballot()` + `__ffsll()` implementation was already correct for 64-wide wavefronts; the assert was overly conservative.
- **ssd_split_embeddings_cache_cuda.cu**: Added a `USE_ROCM` branch in `get_masked_index_default_pipeline_sms()` with a CU-count-based heuristic, avoiding NVIDIA-specific compute-capability checks on AMD hardware.
- **inference.py / inference_serving.py**: Updated docstrings, added ROCm logging, and replaced the hardcoded ASSOC=32 with the platform-aware constant from common.py.
- **Tests**: Added an `SSDInferenceAMDSupportTest` class with 17 tests: ASSOC constant validation, tensor shape verification, cache invalidation with the platform ASSOC, forward correctness, multi-table shapes, and simulated ASSOC=64 tests that run on NVIDIA hardware to verify the 64-wide cache invalidation logic.

Differential Revision: D98852460
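A minimal sketch of the platform-aware constant described above. The helper names are illustrative (the diff's own code lives in common.py); one plausible way to derive the `IS_ROCM` flag is that `torch.version.hip` is non-None in a ROCm build of PyTorch.

```python
# Hedged sketch: platform-aware cache associativity. Helper names are
# hypothetical; the shipped constant lives in FBGEMM's common.py.
def platform_assoc(is_rocm: bool) -> int:
    """Cache associativity must match the hardware SIMD width so that one
    warp/wavefront can scan an entire cache set per ballot: 32 lanes on
    NVIDIA (warp), 64 on AMD (wavefront)."""
    return 64 if is_rocm else 32


def lxu_cache_state_shape(cache_sets: int, is_rocm: bool) -> tuple:
    """Illustrative: the per-set cache-state tensor widens from 32 to 64
    ways on ROCm, which is what the PR's tensor-shape tests verify."""
    return (cache_sets, platform_assoc(is_rocm))
```

Because every Python-side allocation derives its width from this one constant, the C++ kernels see a set size equal to the wavefront width on both platforms, which is what makes the `__ballot()`-based lookup correct without per-platform tensor code.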
Contributor
@goldcoderZ has exported this pull request. If you are a Meta employee, you can view the originating Diff in D98852460.