
Add AMD/ROCm support for SSD TBE inference#5559

Open
goldcoderZ wants to merge 3 commits into pytorch:main from goldcoderZ:export-D98852460

Conversation

@goldcoderZ
Contributor

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2524

Enable SSD TBE inference operator on AMD GPUs (ROCm/HIP) by fixing
the warp-size / cache-associativity mismatch between NVIDIA (32) and
AMD (64).

Key changes:

  • common.py: ASSOC is now platform-aware (32 on CUDA, 64 on ROCm)
    to match the hardware warp/wavefront width. All Python-side tensor
    allocations (lxu_cache_state, lru_state, lxu_cache_weights, etc.)
    automatically use the correct ASSOC value.

  • bitonic_sort.cuh: Added 6th merge stage (L=32) under USE_ROCM
    to fully sort 64 elements on AMD's 64-wide wavefronts. Without this,
    only 32 of 64 lanes were sorted, causing incorrect cache eviction
    ordering.

  • lxu_cache.cu: Removed assert(false) in lxu_cache_lookup_kernel's
    ROCm path. The existing __ballot() + __ffsll() implementation was
    already correct for 64-wide wavefronts — the assert was overly
    conservative.

  • ssd_split_embeddings_cache_cuda.cu: Added USE_ROCM branch in
    get_masked_index_default_pipeline_sms() with CU-count-based heuristic,
    avoiding NVIDIA-specific compute capability checks on AMD hardware.

  • inference.py / inference_serving.py: Updated docstrings, ROCm
    logging, and replaced hardcoded ASSOC=32 with the platform-aware
    constant from common.py.

  • Tests: Added SSDInferenceAMDSupportTest class with 17 tests:
    ASSOC constant validation, tensor shape verification, cache
    invalidation with platform ASSOC, forward correctness, multi-table
    shapes, and simulated ASSOC=64 tests that run on NVIDIA hardware
    to verify the 64-wide cache invalidation logic.
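
The platform-aware constant described above can be sketched as follows. This is a minimal illustration, not the actual common.py code: the function names are hypothetical, and the real module detects ROCm from the runtime rather than taking a flag.

```python
# Sketch of a platform-aware cache associativity constant. Names and the
# flag-based detection are illustrative; the real common.py determines
# the platform from the installed GPU runtime.

def platform_assoc(is_rocm: bool) -> int:
    """Cache associativity matching the hardware warp/wavefront width."""
    return 64 if is_rocm else 32  # AMD wavefront = 64, NVIDIA warp = 32

def lxu_cache_state_shape(cache_sets: int, is_rocm: bool) -> tuple:
    """Shape of the set-associative cache state tensor: one tag per lane."""
    return (cache_sets, platform_assoc(is_rocm))
```

For example, `lxu_cache_state_shape(1024, is_rocm=True)` gives `(1024, 64)`, so every Python-side allocation keyed off this constant picks up the 64-wide layout on ROCm automatically.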

Differential Revision: D98852460

Summary:
Add streaming delta update support to the SSD TBE inference operator
(SSDIntNBitTableBatchedEmbeddingBags), closing the gap with EmbeddingDB's
streaming delta table feature.

Two new public methods:

1. streaming_update(indices, weights) — writes updated embedding rows to
   RocksDB and invalidates corresponding HBM cache entries so subsequent
   prefetch() calls reload from SSD. Uses vectorized set-associative
   cache invalidation.

2. load_snapshot(ssd_storage_directory, ...) — flushes the current
   RocksDB, opens a new instance at a fresh directory, and fully
   invalidates the HBM cache. Enables zero-downtime snapshot transitions.
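
The invalidation step behind streaming_update can be modeled in a few lines of pure Python. This is a simplified sketch of the set-associative scheme, with an assumed sentinel value and set-selection hash; the real operator does this with vectorized tensor ops on the GPU.

```python
# Pure-Python model of the set-associative HBM cache invalidation that
# streaming_update performs after writing rows to RocksDB. The sentinel
# and the modulo set selection are illustrative assumptions.

SENTINEL = -1  # marks an empty / invalidated cache slot

def invalidate(lxu_cache_state, updated_indices, assoc=32):
    """Clear the cache slot (if any) holding each updated embedding row,
    so the next prefetch() reloads the fresh row from SSD."""
    num_sets = len(lxu_cache_state)
    for idx in updated_indices:
        cache_set = idx % num_sets          # set selection (illustrative)
        slots = lxu_cache_state[cache_set]
        for lane in range(assoc):           # scan all ways in the set
            if slots[lane] == idx:
                slots[lane] = SENTINEL      # evict: force SSD reload
```

Rows that are not currently cached need no action; their next lookup misses and reads the freshly written RocksDB value anyway.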

Also adds AMD/ROCm awareness:
- IS_ROCM detection flag
- Constructor warns when running on ROCm (streaming API works but
  prefetch/forward C++ kernels are NVIDIA-only due to ASSOC=32 vs
  kWarpSize=64 mismatch)
- Docstring documents AMD support status per method

These are the minimal primitives needed for online training models with
streaming embedding updates (e.g., 45-min publish intervals).

Differential Revision: D98827795

Summary:
Add a serving-friendly wrapper module that enables Video Retrieval HSTU
models (VDD, New2, IFU) to use FBGEMM SSD TBE with TurboSSD v2 features
instead of EmbeddingDB.

TurboSSDInferenceModule provides:

1. Single-call forward: auto-prefetch + lookup (vs. EmbeddingDB's
   synchronous SSD reads). Uses GPU HBM cache with LRU eviction.

2. streaming_update() + load_snapshot(): streaming delta updates and
   zero-downtime snapshot transitions, matching EmbeddingDB's streaming
   delta table feature.

3. Factory method from_embedding_specs(): auto-sizes HBM cache based on
   target hit rate and optional HBM budget. Useful for capacity planning
   on H100 (96 GB) and MI350X (288 GB).

4. estimate_hbm_gb(): static method for HBM capacity planning without
   instantiating the module.
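
A back-of-envelope version of this capacity estimate is sketched below. The sizing rule (cache a fraction of rows equal to the target hit rate) is a simplifying assumption for illustration, not the actual estimate_hbm_gb() logic.

```python
# Hypothetical sketch of HBM capacity planning in the spirit of
# estimate_hbm_gb(). Assumes int8-style rows and that reaching a target
# hit rate requires caching that same fraction of rows (a pessimistic
# simplification; skewed access patterns need far less).

def estimate_hbm_gb(num_rows, dim, bytes_per_elem=1, target_hit_rate=0.9):
    """Rough HBM (in GiB) needed to cache enough rows for the hit rate."""
    row_bytes = dim * bytes_per_elem
    cached_rows = int(num_rows * target_hit_rate)
    return cached_rows * row_bytes / 2**30
```

Under these assumptions, 1B rows of dim-128 int8 embeddings at a 90% target hit rate come to roughly 107 GiB, which exceeds an H100's 96 GB but fits comfortably in an MI350X's 288 GB.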

This is the integration layer between SSD TBE and the TGIF serving
framework. The module can replace SSDEmbeddingDBSplitTableBatchedEmbeddingBagsCodegen
in the model graph via DIShardingPass.

Differential Revision: D98843434

meta-codesync bot commented Mar 31, 2026

@goldcoderZ has exported this pull request. If you are a Meta employee, you can view the originating Diff in D98852460.

