diff --git a/DATASET_PRESET_TESTING.md b/DATASET_PRESET_TESTING.md new file mode 100644 index 00000000..406d0154 --- /dev/null +++ b/DATASET_PRESET_TESTING.md @@ -0,0 +1,96 @@ +# Dataset Preset Testing + +Unit tests for dataset preset transforms. These tests verify that presets correctly transform dataset columns without requiring end-to-end benchmark runs. + +## Quick Start + +```bash +# Run all preset tests +pytest tests/unit/dataset_manager/test_dataset_presets.py -v + +# Run tests for a specific dataset +pytest tests/unit/dataset_manager/test_dataset_presets.py::TestCNNDailyMailPresets -v + +# Exclude slow tests (Harmonize transform requires transformers) +pytest tests/unit/dataset_manager/test_dataset_presets.py -m "not slow" -v +``` + +## Preset Coverage + +| Dataset | Presets | Tests | +| ------------- | ------------------------------- | ----- | +| CNNDailyMail | `llama3_8b`, `llama3_8b_sglang` | 5 | +| AIME25 | `gptoss` | 3 | +| GPQA | `gptoss` | 3 | +| LiveCodeBench | `gptoss` | 3 | +| OpenOrca | `llama2_70b` | 3 | + +## Adding Tests for New Presets + +When adding a new dataset preset, add a test class to `tests/unit/dataset_manager/test_dataset_presets.py`: + +```python +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.transforms import apply_transforms +from inference_endpoint.dataset_manager.predefined.my_dataset import MyDataset + + +class TestMyDatasetPresets: + @pytest.fixture + def sample_data(self): + """Minimal sample data matching dataset schema.""" + return pd.DataFrame({ + "input_col1": ["value1"], + "input_col2": ["value2"], + }) + + @pytest.fixture + def transformed_data(self, sample_data): + """Apply preset transforms to sample data.""" + transforms = MyDataset.PRESETS.my_preset() + return apply_transforms(sample_data, transforms) + + def test_my_preset_instantiation(self): + """Verify preset can be created.""" + transforms = MyDataset.PRESETS.my_preset() + assert transforms is not None + assert len(transforms) > 0 
+ + def test_my_preset_transforms_apply(self, transformed_data): + """Verify transforms apply without errors.""" + assert transformed_data is not None + assert "prompt" in transformed_data.columns # Expected output column + + def test_my_preset_output_format(self, transformed_data): + """Verify output has expected format.""" + # Validate format-specific expectations + assert len(transformed_data["prompt"][0]) > 0 +``` + +If the preset uses `Harmonize` transform (requires `transformers` library), mark slow tests: + +```python +@pytest.mark.slow +def test_my_preset_transforms_apply(self, transformed_data): + # Test that requires transformers library + pass +``` + +## Test Scope + +✅ **Tests verify:** + +- Preset instantiation +- Transform application without errors +- Required output columns exist +- Data is properly transformed + +❌ **Tests do NOT verify:** + +- Model inference accuracy +- API endpoint compatibility +- Throughput/latency metrics +- Full benchmark runs + +See `src/inference_endpoint/dataset_manager/README.md` for dataset schema and preset creation details. diff --git a/examples/05_Llama3.1-8B_Example/README.md b/examples/05_Llama3.1-8B_Example/README.md index 56c39bf9..0c51089d 100644 --- a/examples/05_Llama3.1-8B_Example/README.md +++ b/examples/05_Llama3.1-8B_Example/README.md @@ -2,9 +2,9 @@ It is recommended to use a config file such as [online_llama3_8b_cnn.yaml](online_llama3_8b_cnn.yaml) to run the benchmark. -## [Optional] Download dataset +## Download dataset (Only needed if quantizing the model) -The Llama3.1-8B benchmark uses the [cnn/dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset (for summarization). If using a config (such as provided) to run the benchmark, the (validation) dataset is downloaded automatically by specifying dataset name as `- name: cnn_dailymail::llama3_8b` under the `dataset` entry. 
+The Llama3.1-8B benchmark uses the [cnn/dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset (for summarization). If using a config (such as provided) to run the benchmark, the (validation) dataset is downloaded automatically by specifying dataset name as `- name: cnn_dailymail::llama3_8b # or cnn_dailymail::llama3_8b_sglang` under the `dataset` entry. For post-training quantization, users can use the [cnn-dailymail-calibration-list](https://github.com/mlcommons/inference/blob/v4.0/calibration/CNNDailyMail/calibration-list.txt) to select samples for the calibration. @@ -15,6 +15,8 @@ python download_cnndm.py --save-dir data --calibration-ids-file calibration-list ## Launch the server +We provide instructions below for using either vLLM or SGLang endpoints. + The following environment variables are used by the commands below to make the scripts easier to run: ``` @@ -31,7 +33,7 @@ hf download $MODEL_NAME The cached models can be verified with `hf cache scan`. -### [vLLM](https://github.com/vllm-project/vllm) +### [vLLM](https://github.com/vllm-project/vllm) (Using NVIDIA GPUs for demo) **Note**: To generate same outputs as the ones produced from submissions with legacy loadgen, we need to apply a custom chat template (this is taken care of automatically by the cnn-dailymail dataset preset). The flag `--trust-request-chat-template` is also required for this behavior. **Security warning:** `--trust-request-chat-template` allows execution of request-provided chat templates and should only be used in trusted environments or when all requests are controlled by the benchmark harness/preset. Do not enable this flag on publicly exposed endpoints receiving untrusted traffic. 
@@ -41,22 +43,65 @@ We can launch the latest docker image for vllm using the command below: docker run --runtime nvidia --gpus all -v ${HF_HOME}:/root/.cache/huggingface --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" -p 8000:8000 --ipc=host vllm/vllm-openai:latest --model ${MODEL_NAME} --trust-request-chat-template ``` -### To run Offline mode +### [SGLang](https://github.com/sgl-project/sglang) + +- First build the container and start the endpoint + +``` +# Clone the SGLang repository +SGLANG_VER=3f9fc8b848365a5797a44856854e3e6f00a60dd0 # Latest tested +git clone https://github.com/sgl-project/sglang.git +cd sglang/docker && git checkout $SGLANG_VER + +# Build the docker image +docker build -t sglang-cpu:latest -f xeon.Dockerfile . + +# Initiate a docker container +docker run -it --privileged --ipc=host --network=host -v /dev/shm:/dev/shm -v ~/.cache/huggingface:/root/.cache/huggingface -e "HF_TOKEN=" --name sglang-cpu-server sglang-cpu:latest /bin/bash + +# Start sglang endpoint +docker exec -u root -w /workspace sglang-cpu-server /bin/bash -lc "python3 -m sglang.launch_server \ + --model-path $MODEL_NAME \ + --served-model-name meta-llama/Llama-3.1-8B-Instruct \ + --dtype bfloat16 \ + --device cpu \ + --max-running-requests 64 \ + --max-total-tokens 131072 \ + --chunked-prefill-size 8192 \ + --max-prefill-tokens 32768 \ + --mem-fraction-static 0.9 \ + --disable-piecewise-cuda-graph \ + --disable-radix-cache \ + --host 127.0.0.1 \ + --port 8080 2>&1 | tee server.log" +``` + +## Start benchmark + +Make sure the [`inference-endpoint`](https://github.com/mlcommons/endpoints/tree/main?tab=readme-ov-file#installation) is installed and activated **Note** Double-check the config file for correct parameters such as the model name in the config -- Launch the benchmark with config yaml +- Launch the benchmark with config yaml (For performance only, remove the accuracy dataset entry in the `online_llama3_8b_cnn.yaml`) + +### vLLM endpoint targets + +- To run Offline mode ``` 
-inference-endpoint benchmark from-config -c offline_llama3_8b_cnn.yaml --timeout 600 +inference-endpoint benchmark from-config -c offline_llama3_8b_cnn.yaml ``` -### To run Online mode +- To run Online mode -**Note** Double-check the config file for correct parameters +``` +inference-endpoint benchmark from-config -c online_llama3_8b_cnn.yaml +``` -- Launch the benchmark with config yaml (For performance only, remove the accuracy dataset entry in the `online_llama3_8b_cnn.yaml`) +### SGLang endpoint targets + +- To run the offline benchmark: ``` -inference-endpoint benchmark from-config -c online_llama3_8b_cnn.yaml --timeout 600 +inference-endpoint benchmark from-config -c offline_llama3_8b_cnn_sglang_api.yaml ``` diff --git a/examples/05_Llama3.1-8B_Example/offline_llama3_8b_cnn_sglang_api.yaml b/examples/05_Llama3.1-8B_Example/offline_llama3_8b_cnn_sglang_api.yaml new file mode 100644 index 00000000..672859f4 --- /dev/null +++ b/examples/05_Llama3.1-8B_Example/offline_llama3_8b_cnn_sglang_api.yaml @@ -0,0 +1,56 @@ +# Offline Throughput Benchmark +name: "offline-llama3.1-8b-cnn-benchmark" +version: "1.0" +type: "offline" + +model_params: + name: "meta-llama/Llama-3.1-8B-Instruct" # Path to the model + temperature: 0.0 + top_p: 1.0 + max_new_tokens: 128 + +datasets: + - name: cnn_dailymail::llama3_8b_sglang + type: accuracy + samples: 13368 + parser: + input: prompt + accuracy_config: + eval_method: "rouge" + ground_truth: "highlights" + extractor: identity_extractor + num_repeats: 1 + - name: cnn_dailymail::llama3_8b_sglang + type: "performance" + samples: 13368 + parser: + input: prompt + +settings: + runtime: + min_duration_ms: 60000 # 1 minute + max_duration_ms: 360000 # 6 minutes (Arbitrary here, and doesn't have counterpart in legacy loadgen) + scheduler_random_seed: 137 # For Poisson/distribution sampling + dataloader_random_seed: 111 # For dataset shuffling (Will be updated after rng seeds are finalized for submission) + n_samples_to_issue: 13368 # Number 
of samples to issue (for offline, this should match the dataset samples) + + load_pattern: + type: "max_throughput" + + client: + workers: 4 # Number of client workers + +metrics: + collect: + - "throughput" + - "latency" + - "ttft" + - "tpot" + +endpoint_config: + endpoints: + - "http://localhost:8080" + api_type: "sglang" + api_key: null + +report_dir: logs/llama3_8b_cnn_sglang_offline # Directory to save the benchmark report diff --git a/src/inference_endpoint/dataset_manager/predefined/cnndailymail/presets.py b/src/inference_endpoint/dataset_manager/predefined/cnndailymail/presets.py index 8faea0fc..4104734c 100644 --- a/src/inference_endpoint/dataset_manager/predefined/cnndailymail/presets.py +++ b/src/inference_endpoint/dataset_manager/predefined/cnndailymail/presets.py @@ -20,6 +20,7 @@ from inference_endpoint.dataset_manager.transforms import ( AddStaticColumns, + Harmonize, Transform, UserPromptFormatter, ) @@ -48,3 +49,28 @@ def llama3_8b( ), AddStaticColumns(chat_template), ] + + +def llama3_8b_sglang( + stream: bool = True, + max_new_tokens: int = 128, + temperature: float = 0.0, + top_p: float = 1.0, + top_k: int = 1, + tokenizer_name: str = "meta-llama/Llama-3.1-8B-Instruct", +) -> list[Transform]: + return [ + # Step 1: Format the prompt from "article" + UserPromptFormatter( + user_prompt_format=f"Summarize the following news article in {max_new_tokens} tokens. Please output the summary only, without any other text.\n\nArticle:\n{{article}}\n\nSummary:", + output_column="prompt", + ), + # Step 2: Tokenize the raw prompt via Harmonize in plain mode. 
+ Harmonize( + tokenizer_name=tokenizer_name, + prompt_column="prompt", + tokenized_column="input_tokens", + harmonized_column=None, + mode="plain", + ), + ] diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index a2e2e3be..379315f1 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -137,6 +137,7 @@ def __init__( prompt_column: str = "prompt", tokenized_column: str = "input_tokens", harmonized_column: str | None = "harmonized_prompt", + mode: str = "harmony", ): """Initialize the Harmonize transform. @@ -149,10 +150,14 @@ def __init__( tokenized_column: The name of the column containing the tokenized prompt. harmonized_column: The name of the column containing the harmonized prompt. If None, the harmonized prompt will not be stored as text. + mode: "harmony" to render a Harmony conversation; "plain" to tokenize the raw prompt. """ self.prompt_column = prompt_column self.tokenized_column = tokenized_column self.harmonized_column = harmonized_column + self.mode = mode + if self.mode not in {"harmony", "plain"}: + raise ValueError(f"Invalid harmonize mode: {self.mode}") self.harmonizer = Harmonizer( tokenizer_name=tokenizer_name, encoding_name=encoding_name, @@ -175,7 +180,19 @@ def process_row(self, row: dict[str, Any]) -> dict[str, Any]: Returns: Row dictionary with the harmonized prompt added """ - row[self.tokenized_column] = self.harmonizer(row[self.prompt_column]) + # Guard pre-tokenized rows: the SGLang adapter adds a default Harmonize + # (GPT-OSS tokenizer + harmony mode). When row processors are fused, the + # dataframe-level skip is bypassed, so without this guard, adapter + # Harmonize would overwrite input tokens. Alternative: remove Harmonize + # from the adapter transforms and require each SGLang preset to add its + # own Harmonize with the desired tokenizer/args. 
+ if self.tokenized_column in row and row[self.tokenized_column] is not None: + return row + if self.mode == "plain": + tokens = self.harmonizer.to_tokens(row[self.prompt_column]) + row[self.tokenized_column] = tokens + else: + row[self.tokenized_column] = self.harmonizer(row[self.prompt_column]) if self.harmonized_column is not None: row[self.harmonized_column] = self.harmonizer.to_text( row[self.tokenized_column] diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 036dd172..3b0fe726 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -112,7 +112,7 @@ class ChatCompletionResponseMessage( role: str content: str | None - refusal: str | None + refusal: str | None = None class ChatCompletionChoice( @@ -149,5 +149,5 @@ class ChatCompletionResponse( created: int model: str choices: list[ChatCompletionChoice] - usage: CompletionUsage | None - system_fingerprint: str | None + usage: CompletionUsage | None = None + system_fingerprint: str | None = None diff --git a/tests/unit/dataset_manager/test_dataset_presets.py b/tests/unit/dataset_manager/test_dataset_presets.py new file mode 100644 index 00000000..3c16e277 --- /dev/null +++ b/tests/unit/dataset_manager/test_dataset_presets.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: 2026 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for preset dataset transforms. 
+ +Tests verify that each preset configuration: +1. Can be instantiated without errors +2. Applies transforms correctly to sample data +3. Produces expected output columns + +These tests do NOT require end-to-end benchmarking or external compute resources. +Instead, they use minimal dummy datasets with the required columns. +""" + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.predefined.aime25 import AIME25 +from inference_endpoint.dataset_manager.predefined.cnndailymail import CNNDailyMail +from inference_endpoint.dataset_manager.predefined.gpqa import GPQA +from inference_endpoint.dataset_manager.predefined.livecodebench import LiveCodeBench +from inference_endpoint.dataset_manager.predefined.open_orca import OpenOrca +from inference_endpoint.dataset_manager.transforms import apply_transforms + + +class TestCNNDailyMailPresets: + """Test CNN/DailyMail dataset presets.""" + + @pytest.fixture + def sample_cnn_data(self): + """Create minimal sample data matching CNN/DailyMail schema.""" + return pd.DataFrame( + { + "article": [ + "CNN reported today that markets are up. Stocks rose 2%.", + "Breaking news: New policy announced. 
Impact expected next quarter.", + ], + "highlights": [ + "Markets up 2%", + "Policy announced", + ], + } + ) + + @pytest.fixture + def llama3_8b_transformed(self, sample_cnn_data): + """Apply llama3_8b preset transforms to sample data.""" + transforms = CNNDailyMail.PRESETS.llama3_8b() + return apply_transforms(sample_cnn_data, transforms) + + @pytest.fixture + def llama3_8b_sglang_transformed(self, sample_cnn_data): + """Apply llama3_8b_sglang preset transforms to sample data.""" + transforms = CNNDailyMail.PRESETS.llama3_8b_sglang() + return apply_transforms(sample_cnn_data, transforms) + + def test_llama3_8b_preset_instantiation(self): + """Test that llama3_8b preset can be instantiated.""" + transforms = CNNDailyMail.PRESETS.llama3_8b() + assert transforms is not None + assert len(transforms) > 0 + + def test_llama3_8b_transforms_apply(self, llama3_8b_transformed): + """Test that llama3_8b transforms apply without errors.""" + assert llama3_8b_transformed is not None + assert "prompt" in llama3_8b_transformed.columns + assert len(llama3_8b_transformed["prompt"][0]) > 0 + + def test_llama3_8b_prompt_format(self, llama3_8b_transformed, sample_cnn_data): + """Test that llama3_8b produces properly formatted prompts.""" + prompt = llama3_8b_transformed["prompt"][0] + assert "Summarize" in prompt + assert "news article" in prompt + assert "article" in sample_cnn_data.columns + # The original article should be embedded in the prompt + assert sample_cnn_data["article"][0] in prompt + + @pytest.mark.slow + def test_llama3_8b_sglang_preset_instantiation(self): + """Test that llama3_8b_sglang preset can be instantiated.""" + transforms = CNNDailyMail.PRESETS.llama3_8b_sglang() + assert transforms is not None + assert len(transforms) > 0 + + @pytest.mark.slow + def test_llama3_8b_sglang_transforms_apply(self, llama3_8b_sglang_transformed): + """Test that llama3_8b_sglang transforms apply without errors.""" + assert llama3_8b_sglang_transformed is not None + # SGLang preset 
should still provide a prompt column + assert "prompt" in llama3_8b_sglang_transformed.columns + # Key output for SGLang preset is tokenized input + assert "input_tokens" in llama3_8b_sglang_transformed.columns + input_tokens = llama3_8b_sglang_transformed["input_tokens"].iloc[0] + assert isinstance(input_tokens, list) + assert len(input_tokens) > 0 + assert all(isinstance(token, int) for token in input_tokens) + # harmonized_column is expected to be None for this preset + assert "harmonized_prompt" not in llama3_8b_sglang_transformed.columns + + +class TestAIME25Presets: + """Test AIME25 dataset presets.""" + + @pytest.fixture + def sample_aime_data(self): + """Create minimal sample data matching AIME25 schema.""" + return pd.DataFrame( + { + "question": [ + "If x + 1 = 5, then x = ?", + "What is 2 + 2 * 3?", + ], + "answer": [4, 8], + } + ) + + @pytest.fixture + def gptoss_transformed(self, sample_aime_data): + """Apply gptoss preset transforms to sample data.""" + transforms = AIME25.PRESETS.gptoss() + return apply_transforms(sample_aime_data, transforms) + + def test_gptoss_preset_instantiation(self): + """Test that gptoss preset can be instantiated.""" + transforms = AIME25.PRESETS.gptoss() + assert transforms is not None + assert len(transforms) > 0 + + def test_gptoss_transforms_apply(self, gptoss_transformed): + """Test that gptoss transforms apply without errors.""" + assert gptoss_transformed is not None + assert "prompt" in gptoss_transformed.columns + + def test_gptoss_includes_boxed_answer_format(self, gptoss_transformed): + """Test that gptoss format includes boxed answer format.""" + prompt = gptoss_transformed["prompt"][0] + # AIME preset should instruct to put answer in \boxed{} + assert "boxed" in prompt.lower() or "box" in prompt + + +class TestGPQAPresets: + """Test GPQA dataset presets.""" + + @pytest.fixture + def sample_gpqa_data(self): + """Create minimal sample data matching GPQA schema.""" + return pd.DataFrame( + { + "question": [ + "What 
is the capital of France?", + "Who discovered penicillin?", + ], + "choice1": ["Paris", "Alexander Fleming"], + "choice2": ["London", "Marie Curie"], + "choice3": ["Berlin", "Louis Pasteur"], + "choice4": ["Madrid", "Joseph Lister"], + "correct_choice": ["A", "A"], + } + ) + + @pytest.fixture + def gptoss_transformed(self, sample_gpqa_data): + """Apply gptoss preset transforms to sample data.""" + transforms = GPQA.PRESETS.gptoss() + return apply_transforms(sample_gpqa_data, transforms) + + def test_gptoss_preset_instantiation(self): + """Test that gptoss preset can be instantiated.""" + transforms = GPQA.PRESETS.gptoss() + assert transforms is not None + assert len(transforms) > 0 + + def test_gptoss_transforms_apply(self, gptoss_transformed): + """Test that gptoss transforms apply without errors.""" + assert gptoss_transformed is not None + assert "prompt" in gptoss_transformed.columns + + def test_gptoss_format_includes_choices(self, gptoss_transformed): + """Test that gptoss format includes all multiple choice options.""" + prompt = gptoss_transformed["prompt"][0] + # Should include all four choices formatted as (A), (B), (C), (D) + assert "(A)" in prompt + assert "(B)" in prompt + assert "(C)" in prompt + assert "(D)" in prompt + # Should instruct to express answer as option letter + assert "A" in prompt or "option" in prompt.lower() + + +class TestLiveCodeBenchPresets: + """Test LiveCodeBench dataset presets.""" + + @pytest.fixture + def sample_lcb_data(self): + """Create minimal sample data matching LiveCodeBench schema.""" + return pd.DataFrame( + { + "question": [ + "Write a function that returns the sum of two numbers.", + "Write a function that reverses a string.", + ], + "starter_code": [ + "def add(a, b):\n pass", + "def reverse(s):\n pass", + ], + } + ) + + @pytest.fixture + def gptoss_transformed(self, sample_lcb_data): + """Apply gptoss preset transforms to sample data.""" + transforms = LiveCodeBench.PRESETS.gptoss() + return 
apply_transforms(sample_lcb_data, transforms) + + def test_gptoss_preset_instantiation(self): + """Test that gptoss preset can be instantiated.""" + transforms = LiveCodeBench.PRESETS.gptoss() + assert transforms is not None + assert len(transforms) > 0 + + def test_gptoss_transforms_apply(self, gptoss_transformed): + """Test that gptoss transforms apply without errors.""" + assert gptoss_transformed is not None + assert "prompt" in gptoss_transformed.columns + + def test_gptoss_format_includes_code_delimiters( + self, gptoss_transformed, sample_lcb_data + ): + """Test that gptoss format includes code delimiters.""" + prompt = gptoss_transformed["prompt"][0] + # Should include ```python delimiters for code + assert "```python" in prompt + assert "starter_code" in sample_lcb_data.columns + # Starter code should be included in prompt + assert sample_lcb_data["starter_code"][0] in prompt + + +class TestOpenOrcaPresets: + """Test OpenOrca dataset presets.""" + + @pytest.fixture + def sample_openorca_data(self): + """Create minimal sample data matching OpenOrca schema.""" + return pd.DataFrame( + { + "question": [ + "What is machine learning?", + "Explain neural networks.", + ], + "system_prompt": [ + "You are an AI expert.", + "You are a technical educator.", + ], + "response": [ + "Machine learning is...", + "Neural networks are...", + ], + } + ) + + @pytest.fixture + def llama2_70b_transformed(self, sample_openorca_data): + """Apply llama2_70b preset transforms to sample data.""" + transforms = OpenOrca.PRESETS.llama2_70b() + return apply_transforms(sample_openorca_data, transforms) + + def test_llama2_70b_preset_instantiation(self): + """Test that llama2_70b preset can be instantiated.""" + transforms = OpenOrca.PRESETS.llama2_70b() + assert transforms is not None + assert len(transforms) > 0 + + def test_llama2_70b_transforms_apply(self, llama2_70b_transformed): + """Test that llama2_70b transforms apply without errors.""" + assert llama2_70b_transformed is not 
None + assert "prompt" in llama2_70b_transformed.columns + assert "system" in llama2_70b_transformed.columns + + def test_llama2_70b_remaps_columns( + self, llama2_70b_transformed, sample_openorca_data + ): + """Test that llama2_70b correctly remaps question->prompt and system_prompt->system.""" + # After transformation, original columns should be renamed + assert "prompt" in llama2_70b_transformed.columns + assert "system" in llama2_70b_transformed.columns + # Data should be preserved in renamed columns + assert ( + llama2_70b_transformed["prompt"][0] == sample_openorca_data["question"][0] + ) + assert ( + llama2_70b_transformed["system"][0] + == sample_openorca_data["system_prompt"][0] + ) diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index ab342204..f47a2a29 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -15,10 +15,11 @@ """ Unit tests for the transforms module. -Tests all transform classes and functions except Harmonize. +Tests all transform classes and functions. 
""" from typing import Any +from unittest.mock import patch import pandas as pd import pytest @@ -27,6 +28,7 @@ ColumnFilter, ColumnRemap, FusedRowProcessor, + Harmonize, MakeAdapterCompatible, RowProcessor, Transform, @@ -825,3 +827,50 @@ def test_no_matching_columns(self): # Should not raise error or create prompt column assert "prompt" not in result.columns assert "unrelated" in result.columns + + +class TestHarmonize: + """Test Harmonize transform with mocked Harmonizer to avoid tokenizer downloads.""" + + @patch("inference_endpoint.dataset_manager.transforms.Harmonizer") + def test_harmonize_invalid_mode_raises(self, mock_harmonizer): + """Test that invalid mode raises ValueError.""" + with pytest.raises(ValueError, match="Invalid harmonize mode"): + Harmonize(mode="invalid") + + @patch("inference_endpoint.dataset_manager.transforms.Harmonizer") + def test_harmonize_row_skip_existing_tokens(self, mock_harmonizer): + """Test that pre-existing tokens aren't overwritten (fusion safety).""" + # This is critical: when fused, the __call__ check is bypassed + # so the row-level check must prevent overwriting preset-added tokens + harmonize = Harmonize() + row = { + "prompt": "test", + "input_tokens": [1, 2, 3], # Already added by preset + } + result = harmonize.process_row(row) + assert result["input_tokens"] == [1, 2, 3] + # Harmonizer should NOT be called + mock_harmonizer.return_value.assert_not_called() + + @patch("inference_endpoint.dataset_manager.transforms.Harmonizer") + def test_harmonize_plain_mode(self, mock_harmonizer): + """Test plain mode uses to_tokens() instead of full Harmonizer call.""" + mock_harmonizer.return_value.to_tokens.return_value = [1, 2, 3] + harmonize = Harmonize(mode="plain") + row = {"prompt": "test prompt"} + result = harmonize.process_row(row) + + assert result["input_tokens"] == [1, 2, 3] + mock_harmonizer.return_value.to_tokens.assert_called_once_with("test prompt") + + 
@patch("inference_endpoint.dataset_manager.transforms.Harmonizer") + def test_harmonize_harmony_mode(self, mock_harmonizer): + """Test harmony mode calls full Harmonizer.""" + mock_harmonizer.return_value.return_value = [4, 5, 6] + harmonize = Harmonize(mode="harmony") + row = {"prompt": "test prompt"} + result = harmonize.process_row(row) + + assert result["input_tokens"] == [4, 5, 6] + mock_harmonizer.return_value.assert_called_once_with("test prompt")