diff --git a/slime/rollout/data_source.py b/slime/rollout/data_source.py
index fa14c65f0..620f6403f 100644
--- a/slime/rollout/data_source.py
+++ b/slime/rollout/data_source.py
@@ -75,6 +75,7 @@ def __init__(self, args):
             apply_chat_template=args.apply_chat_template,
             apply_chat_template_kwargs=args.apply_chat_template_kwargs,
             seed=args.rollout_seed,
+            data_loading_workers=args.data_loading_workers,
         )
         if self.args.rollout_shuffle:
             self.dataset.shuffle(self.epoch_id)
diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index a0cc52fb6..151dabb8d 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -467,6 +467,7 @@ async def eval_rollout_single_dataset(
             tool_key=dataset_cfg.tool_key,
             apply_chat_template=args.apply_chat_template,
             apply_chat_template_kwargs=args.apply_chat_template_kwargs,
+            data_loading_workers=args.data_loading_workers,
         )
 
     dataset = EVAL_PROMPT_DATASET[cache_key]
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index 7020fb88d..5dc5c7129 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -534,6 +534,12 @@ def add_data_arguments(parser):
             "When need to add tools during apply_chat_template, you should provide the key for the tools in the prompt dataset."
         ),
     )
+    parser.add_argument(
+        "--data-loading-workers",
+        type=int,
+        default=1,
+        help="Number of parallel workers for data loading. Higher values speed up VLM data loading.",
+    )
 
     parser.add_argument(
         "--start-rollout-id",
diff --git a/slime/utils/data.py b/slime/utils/data.py
index 0b9d6a6db..21d68bd50 100644
--- a/slime/utils/data.py
+++ b/slime/utils/data.py
@@ -4,9 +4,11 @@
 import os
 import random
 import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import numpy as np
 import ray
+from tqdm import tqdm
 
 try:
     import pyarrow.parquet as pq
@@ -181,9 +183,14 @@ def __init__(
         seed=42,
         apply_chat_template=False,
         apply_chat_template_kwargs=None,
+        data_loading_workers=1,
     ):
-        origin_samples = []
-        for data in read_file(path):
+        # Phase 1: Read raw data (sequential, fast)
+        raw_data_list = list(read_file(path))
+        logger.info(f"Read {len(raw_data_list)} raw samples from {path}")
+
+        # Define per-sample processing function
+        def process_single_sample(data):
             # Both chat templates and multimodal inputs require conversation format (list of message dicts)
             as_conversation = apply_chat_template or (multimodal_keys is not None)
             prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
@@ -220,15 +227,28 @@ def __init__(
             else:
                 multimodal_inputs = None
 
-            origin_samples.append(
-                Sample(
-                    prompt=output_prompt,
-                    label=data[label_key] if label_key is not None else None,
-                    metadata=metadata,
-                    multimodal_inputs=multimodal_inputs,
-                )
+            return Sample(
+                prompt=output_prompt,
+                label=data[label_key] if label_key is not None else None,
+                metadata=metadata,
+                multimodal_inputs=multimodal_inputs,
             )
 
+        # Phase 2: Process samples in parallel (default to 1 worker if not specified)
+        logger.info(f"Loading data with {data_loading_workers} workers...")
+        origin_samples = [None] * len(raw_data_list)
+        with ThreadPoolExecutor(max_workers=data_loading_workers) as executor:
+            future_to_idx = {
+                executor.submit(process_single_sample, data): idx for idx, data in enumerate(raw_data_list)
+            }
+            for future in tqdm(as_completed(future_to_idx), total=len(raw_data_list), desc="Loading data"):
+                idx = future_to_idx[future]
+                try:
+                    origin_samples[idx] = future.result()
+                except Exception as e:
+                    logger.error(f"Error processing sample {idx}: {e}")
+                    raise
+
         if max_length is not None:
             self.origin_samples = filter_long_prompt(origin_samples, tokenizer, processor, max_length)
         else: