Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/test_model_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ EOF
fi
;;
qwen3_5_moe)
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 32"
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
;;
voxtral_realtime)
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ jobs:
# Run CUDA backend Python tests
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
# Run quantize roundtrip tests (Qwen 3.5 MoE save/load prequantized)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py -v -o "addopts="
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
export-model-cuda-artifact:
name: export-model-cuda-artifact
Expand Down
4 changes: 4 additions & 0 deletions backends/aoti/common_shims_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ int32_t aoti_torch_dtype_int8() {
return 1; // ScalarType::Char
}

// Reports the integer dtype code used for unsigned 8-bit tensor elements.
int32_t aoti_torch_dtype_uint8() {
  constexpr int32_t kByteCode = 0; // ScalarType::Byte
  return kByteCode;
}

// Reports the integer dtype code used for boolean tensor elements.
int32_t aoti_torch_dtype_bool() {
  constexpr int32_t kBoolCode = 11; // ScalarType::Bool
  return kBoolCode;
}
Expand Down
1 change: 1 addition & 0 deletions backends/aoti/common_shims_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
// Dtype-code accessors: each takes no arguments and returns an int32_t code
// for the scalar type named in its suffix.
// NOTE(review): the codes presumably mirror c10::ScalarType enum values;
// confirm against the definitions in the corresponding .cpp file.
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_uint8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();

// ============================================================
Expand Down
9 changes: 8 additions & 1 deletion backends/aoti/slim/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using BFloat16 = ::executorch::runtime::etensor::BFloat16;
/// Enum representing the scalar type (dtype) of tensor elements.
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
enum class ScalarType : int8_t {
// Byte = 0, // uint8_t - not currently needed
Byte = 0, // uint8_t
Char = 1, // int8_t
Short = 2, // int16_t
Int = 3, // int32_t
Expand All @@ -43,6 +43,7 @@ enum class ScalarType : int8_t {
};

// Type alias constants for convenience
constexpr ScalarType kByte = ScalarType::Byte;
constexpr ScalarType kChar = ScalarType::Char;
constexpr ScalarType kShort = ScalarType::Short;
constexpr ScalarType kInt = ScalarType::Int;
Expand All @@ -56,6 +57,8 @@ constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
/// @return The size in bytes of a single element.
inline size_t elementSize(ScalarType t) {
switch (t) {
case ScalarType::Byte:
return sizeof(uint8_t);
case ScalarType::Char:
return sizeof(int8_t);
case ScalarType::Short:
Expand All @@ -80,6 +83,8 @@ inline size_t elementSize(ScalarType t) {
/// @return The name of the scalar type.
inline const char* toString(ScalarType t) {
switch (t) {
case ScalarType::Byte:
return "Byte";
case ScalarType::Char:
return "Char";
case ScalarType::Short:
Expand Down Expand Up @@ -114,6 +119,7 @@ inline bool isFloatingType(ScalarType t) {
/// @return true if the scalar type is integral, false otherwise.
inline bool isIntegralType(ScalarType t, bool includeBool) {
switch (t) {
case ScalarType::Byte:
case ScalarType::Char:
case ScalarType::Short:
case ScalarType::Int:
Expand All @@ -138,6 +144,7 @@ inline bool isBoolType(ScalarType t) {
/// @return true if the scalar type is valid, false otherwise.
inline bool isValidScalarType(ScalarType t) {
switch (t) {
case ScalarType::Byte:
case ScalarType::Char:
case ScalarType::Short:
case ScalarType::Int:
Expand Down
Loading
Loading