From 2ab78918d7aa21294b491d7d76ff61f9e14442c9 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 6 Feb 2026 17:08:46 +0100 Subject: [PATCH] Added Rubrics support --- docs/source/openenv.md | 23 + examples/notebooks/openenv_wordle_grpo.ipynb | 3674 +++++++++--------- examples/scripts/openenv/echo.py | 1 - examples/scripts/openenv/wordle.py | 69 +- trl/experimental/openenv/__init__.py | 4 +- trl/experimental/openenv/utils.py | 67 + 6 files changed, 1879 insertions(+), 1959 deletions(-) diff --git a/docs/source/openenv.md b/docs/source/openenv.md index 988a64e3b9d..b43df01d24c 100644 --- a/docs/source/openenv.md +++ b/docs/source/openenv.md @@ -114,6 +114,29 @@ args = GRPOConfig( # CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen3-1.7B --tensor-parallel-size 4 ``` +## OpenEnv Rubric System + +OpenEnv environments use Rubrics to expose composable reward components. Instead of a single scalar reward, rubrics provide named component scores (e.g., `"wordle.greens"`, `"wordle.yellows"`, `"wordle.correct"`) for logging and custom reward shaping. + +### Using Rubrics in TRL + +Extract rubric scores from observations using `get_rubric_scores()`: + +```python +from trl.experimental.openenv import get_rubric_scores + +# In your rollout function +result = env.step(action) +scores = get_rubric_scores(observation=result.observation) +# Example: {"wordle.greens": 0.8, "wordle.yellows": 0.4, "wordle.correct": 1.0} + +# Use in reward functions +def reward_correct(completions, **kwargs): + return [float(r) for r in kwargs.get("rubric_correct", [0.0] * len(completions))] +``` + +See [`examples/scripts/openenv/wordle.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/wordle.py) for a complete example using TextArena rubrics. + ## Running the Environments You can run OpenEnv environments in three different ways: diff --git a/examples/notebooks/openenv_wordle_grpo.ipynb b/examples/notebooks/openenv_wordle_grpo.ipynb index 6711024fb68..b0906d1a8dd 100644 --- a/examples/notebooks/openenv_wordle_grpo.ipynb +++ b/examples/notebooks/openenv_wordle_grpo.ipynb @@ -1,2036 +1,1850 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b", - "metadata": { - "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b" - }, - "source": [ - "# OpenEnv Wordle with GRPO using TRL\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb)\n", - "\n", - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n", - "\n", - "\n", - "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a model that learns to **play Wordle**, a word-guessing game, through interaction and reinforcement.\n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n", - "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n", - "\n", - "\n", - "An **agentic environment** is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error.\n", - "In this case, the agent interacts with the **Wordle** environment through the [**OpenEnv**](https://github.com/meta-pytorch/OpenEnv) framework, which standardizes multi-agent and RL-style text environments.\n", - "\n", - "[Wordle](https://en.wikipedia.org/wiki/Wordle) is a popular word puzzle where the player must guess a secret five-letter word within six tries. \n", - "After each guess, feedback indicates whether each letter is:\n", - "- 🟩 **Correct and in the right position**\n", - "- 🟨 **Present but in the wrong position**\n", - "- ⬛ **Not in the word**\n", - "\n", - "This feedback loop makes Wordle a perfect environment for **RL with LLMs**, where the goal is to maximize the probability of guessing the correct word efficiently.\n", - "\n", - "\n", - "We'll fine-tune a model using **GRPO** (Group Relative Policy Optimization) via TRL. \n", - "The agent will:\n", - "1. Generate guesses based on the game state and feedback.\n", - "2. Receive structured feedback from the environment after each guess.\n", - "3. Learn to improve its guessing strategy over time through reward signals.\n", - "\n", - "\n", - "## Install dependencies\n", - "\n", - "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n", - "We'll also install the **OpenEnv** framework via the remote deployent env at [sergiopaniego/wordle](https://huggingface.co/spaces/sergiopaniego/wordle) (for the environment), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b4812fbf-3f61-481e-9a64-95277eada9c9", - "metadata": { - "id": "b4812fbf-3f61-481e-9a64-95277eada9c9" - }, - "outputs": [], - "source": [ - "!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/sergiopaniego/wordle trackio" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b", + "metadata": { + "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b" + }, + "source": [ + "# OpenEnv Wordle with GRPO using TRL\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb)\n", + "\n", + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n", + "\n", + "\n", + "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a model that learns to **play Wordle**, a word-guessing game, through interaction and reinforcement.\n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n", + "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n", + "\n", + "\n", + "An **agentic environment** is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error.\n", + "In this case, the agent interacts with the **Wordle** environment through the [**OpenEnv**](https://github.com/meta-pytorch/OpenEnv) framework, which standardizes multi-agent and RL-style text environments.\n", + "\n", + "[Wordle](https://en.wikipedia.org/wiki/Wordle) is a popular word puzzle where the player must guess a secret five-letter word within six tries. \n", + "After each guess, feedback indicates whether each letter is:\n", + "- 🟩 **Correct and in the right position**\n", + "- 🟨 **Present but in the wrong position**\n", + "- ⬛ **Not in the word**\n", + "\n", + "This feedback loop makes Wordle a perfect environment for **RL with LLMs**, where the goal is to maximize the probability of guessing the correct word efficiently.\n", + "\n", + "\n", + "We'll fine-tune a model using **GRPO** (Group Relative Policy Optimization) via TRL. \n", + "The agent will:\n", + "1. Generate guesses based on the game state and feedback.\n", + "2. Receive structured feedback from the environment after each guess.\n", + "3. Learn to improve its guessing strategy over time through reward signals.\n", + "\n", + "\n", + "## Install dependencies\n", + "\n", + "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n", + "We'll also install the **OpenEnv** framework via the remote deployent env at [sergiopaniego/wordle](https://huggingface.co/spaces/sergiopaniego/wordle) (for the environment), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4812fbf-3f61-481e-9a64-95277eada9c9", + "metadata": { + "id": "b4812fbf-3f61-481e-9a64-95277eada9c9" + }, + "outputs": [], + "source": [ + "!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/sergiopaniego/wordle trackio" + ] + }, + { + "cell_type": "markdown", + "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148", + "metadata": { + "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148" + }, + "source": [ + "### Log in to Hugging Face\n", + "\n", + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21756ac0-78b2-495d-8137-28dfa9faae6a", + "metadata": { + "id": "21756ac0-78b2-495d-8137-28dfa9faae6a" + }, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "id": "rpFT3PxHT5Uc", + "metadata": { + "id": "rpFT3PxHT5Uc" + }, + "source": [ + "## Initialize the Environment\n", + "\n", + "Let's begin by setting up the environment that will be used during training. \n", + "For this task, we'll rely on the **TextArena** environment from **OpenEnv**, which exposes a familiar Gymnasium-style API (`reset()`, `step()`, etc.) 
to simplify interaction.\n", + "\n", + "In this example, we'll connect to the hosted environment at [sergiopaniego/wordle](https://huggingface.co/spaces/sergiopaniego/wordle). \n", + "For production use or custom configurations, we **strongly recommend** running the environment locally via Docker. The hosted versions on the Hub currently have limited concurrency support, so duplicating the Space to your own account is the preferred approach in those cases.\n", + "\n", + "For more information, refer to the [TRL-OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rZimqp1UTIV_", + "metadata": { + "id": "rZimqp1UTIV_", + "outputId": "e53c277c-6050-4380-84e1-983857f0b325" + }, + "outputs": [ { - "cell_type": "markdown", - "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148", - "metadata": { - "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148" - }, - "source": [ - "### Log in to Hugging Face\n", - "\n", - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + } + ], + "source": [ + "from textarena_env import TextArenaEnv\n", + "\n", + "wordle_url = \"https://sergiopaniego-wordle.hf.space\" # Duplicate the Space and update this!\n", + "env = TextArenaEnv(base_url=wordle_url)\n", + "# wordle_url = \"sergiopaniego/wordle\"\n", + "# env = TextArenaEnv.from_hub(repo_id=wordle_url)" + ] + }, + { + "cell_type": "markdown", + "id": "hARwiQm8ehw3", + "metadata": { + "id": "hARwiQm8ehw3" + }, + "source": [ + "## Init model and tokenizer\n", + "\n", + "We'll use [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B), a lightweight instruction-tuned model that works well for quick experiments. \n", + "Despite its small size, it can still learn interesting strategies during fine-tuning. \n", + "If you have stronger hardware, you can easily scale up to larger models.\n", + "\n", + "We'll load the **tokenizer** (needed for text processing) here. \n", + "The **model** itself will be handled internally by TRL during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "lR7usp2Dd-JK", + "metadata": { + "id": "lR7usp2Dd-JK", + "outputId": "b8a60feb-e0c0-47c9-839e-2743a502341f" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "id": "21756ac0-78b2-495d-8137-28dfa9faae6a", - "metadata": { - "id": "21756ac0-78b2-495d-8137-28dfa9faae6a" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n", + "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n", + "You are not authenticated with the Hugging Face Hub in this notebook.\n", + "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "model_name = \"Qwen/Qwen3-1.7B\" #\"Qwen/Qwen2.5-0.5B-Instruct\" # \"Qwen/Qwen3-0.6B\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "cell_type": "markdown", + "id": "0oojh2i0ey88", + "metadata": { + "id": "0oojh2i0ey88" + }, + "source": [ + "## Rollout function with helpers\n", + "\n", + "The **rollout function** defines how the agent interacts with the environment during GRPO training.\n", + "It's responsible for generating model completions, collecting feedback (rewards), and returning all necessary information for optimization.\n", + "\n", + "In this setup:\n", + "- The function is called automatically by the **GRPOTrainer** during each training step. \n", + "- It uses the trainer's built-in `generate_rollout_completions()` method for efficient generation with vLLM in colocate mode.\n", + "- Each rollout represents a full interaction loop. The model guesses, receives feedback from Wordle, and updates based on reward signals.\n", + "- The **`env_mask`** tracks which tokens are model-generated vs environment-generated, ensuring only model tokens contribute to the training loss.\n", + "\n", + "The rewards track different aspects of the agent's performance. Helper functions (like `rollout_once`) handle one episode of interaction, keeping the main `rollout_func` clean and modular.\n", + "\n", + "This modular approach allows GRPO to efficiently sample, evaluate, and improve the model's guessing strategy through reinforcement learning.\n", + "\n", + "First, we define the `system_prompt` that guides the model's behavior as an expert Wordle solver with strategic reasoning and structured responses." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "QlUHqvZV6ytz", + "metadata": { + "id": "QlUHqvZV6ytz" + }, + "outputs": [], + "source": [ + "# @title System prompt (click to expand)\n", + "system_prompt = \"\"\"\n", + "You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies.\n", + "\n", + "## GAME RULES\n", + "\n", + "1. The target is a 5-letter English word\n", + "2. You have 6 attempts to guess the correct word\n", + "3. After each guess, you receive color-coded feedback:\n", + " - GREEN: Letter is correct and in the correct position\n", + " - YELLOW: Letter is in the word but in the wrong position\n", + " - GRAY: Letter is not in the word at all\n", + "4. All guesses must be valid 5-letter English words\n", + "5. 
You cannot reuse a word you've already guessed\n", + "\n", + "## RESPONSE FORMAT\n", + "\n", + "Only respond with your next guess in square brackets, e.g., [crane].\n", + "\n", + "Format:\n", + "```\n", + "[guess]\n", + "```\n", + "\n", + "\n", + "## STRATEGIC APPROACH\n", + "\n", + "Do not repeat the same guess twice.\n", + "\n", + "### Opening Strategy\n", + "- Start with words rich in common vowels (A, E, I, O, U) and consonants (R, S, T, L, N)\n", + "- Optimal starters: CRANE, SLATE, STARE, AROSE, IRATE\n", + "- Prioritize words that test the most common letters in different positions\n", + "\n", + "### Mid-Game Strategy\n", + "- Use confirmed GREEN letters in their correct positions\n", + "- Place YELLOW letters in different positions than where they appeared\n", + "- Eliminate GRAY letters entirely from consideration\n", + "- If multiple letters are unknown, prioritize common letter combinations (TH, CH, ST, ER, etc.)\n", + "- Consider letter frequency: E is most common, followed by A, R, I, O, T, N, S\n", + "\n", + "### Vowel Placement\n", + "- Most 5-letter words have 2 vowels\n", + "- Common patterns: vowel-consonant-vowel (like CRANE) or consonant-vowel-vowel-consonant-vowel (like QUEUE)\n", + "- If you have 1-2 vowels confirmed, consider where the others might be\n", + "\n", + "### Advanced Tactics\n", + "- Use \"sacrificial\" guesses to test multiple new letters if you have attempts to spare\n", + "- Avoid repeating letter patterns unless you're certain (e.g., SPEED has two E's)\n", + "- Think about word endings: -ER, -LY, -ED, -ING are common but may not fit the 5-letter constraint\n", + "- Consider less common letters (Q, X, Z, J) only when you've eliminated most common options\n", + "\n", + "### Common Pitfalls to Avoid\n", + "- Don't reuse X letters\n", + "- Don't place Y letters in the same position they appeared\n", + "- Don't ignore confirmed G letters\n", + "- Don't guess words that contradict known information\n", + "\n", + "## EXAMPLES\n", + "\n", + "### Example 1: Opening Guess\n", + "\"Starting with a word that tests common vowels and consonants in varied positions.\"\n", + "[crane]\n", + "\n", + "### Example 2: After Receiving Feedback\n", + "Previous guess: CRANE\n", + "Feedback: C=gray, R=yellow, A=green, N=gray, E=yellow\n", + "\n", + "\"A is confirmed in position 2. R and E are in the word but need different positions. C and N are eliminated. I'll try a word with A in position 2, and test R and E in new positions along with common letters like S and T.\"\n", + "[spare]\n", + "\n", + "### Example 3: Narrowing Down\n", + "Previous guesses: CRANE (C=gray, R=yellow, A=green, N=gray, E=yellow), SPARE (S=gray, P=gray, A=green, R=green, E=green)\n", + "Feedback summary: _ARE_ with R in position 4, A in position 2, E in position 5\n", + "\n", + "\"I have _AR E_ confirmed. Position 1 and 3 are unknown. Common letters to try: T, L, D, B, F, G. Testing with TARED.\"\n", + "[tared]\n", + "\n", + "### Example 4: Final Deduction\n", + "Previous feedback shows: _ARED with position 1 unknown and all common consonants tested\n", + "\n", + "\"Only position 1 remains. I've eliminated S, P, C, N. Common starting consonants left are B, F, G, H. 
BARED is a common word.\"\n", + "[bared]\n", + "\n", + "## LETTER FREQUENCY REFERENCE\n", + "\n", + "Most common letters in 5-letter words (in order):\n", + "S, E, A, O, R, I, L, T, N, U, D, Y, C, P, M, H, G, B, K, F\n", + "\n", + "Most common starting letters:\n", + "S, C, B, T, P, A, F, G, D, M\n", + "\n", + "Most common ending letters:\n", + "E, Y, T, S, R, L, N, D\n", + "\n", + "## IMPORTANT CONSTRAINTS\n", + "\n", + "- Use lowercase only\n", + "- One guess per response\n", + "- Must be exactly 5 letters\n", + "- Must be a real English word from standard dictionaries\n", + "- Never repeat a previous guess\n", + "- Always include brief reasoning before your guess\n", + "\n", + "## YOUR GOAL\n", + "\n", + "Solve the Wordle in as few guesses as possible by strategically using feedback to eliminate impossible words and narrow down the solution space efficiently.\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "rUOAm7o-kJ5U", + "metadata": { + "id": "rUOAm7o-kJ5U" + }, + "source": [ + "Now, let's define the `rollout_func`:\n", + "\n", + "This function orchestrates the interaction between the model and the Wordle environment. For each prompt in the batch, it runs the episode interaction, collecting rewards and model outputs for GRPO optimization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a9e7a62-fff9-4caa-9500-dd278f49ec0f", + "metadata": { + "id": "8a9e7a62-fff9-4caa-9500-dd278f49ec0f" + }, + "outputs": [], + "source": "max_new_tokens = 8\nmax_turns = 6\n\ndef rollout_func(prompts, trainer):\n \"\"\"\n Rollout function for GRPO training with environment interaction.\n\n This function is called by GRPOTrainer to generate completions and compute rewards.\n It uses trainer.generate_rollout_completions() for inference.\n\n Args:\n prompts: List of prompts to generate from\n trainer: GRPOTrainer instance containing context and configuration\n\n Returns:\n Dictionary with prompt_ids, completion_ids, logprobs, env_mask, and reward signals\n \"\"\"\n episode_prompt_ids = []\n episode_completion_ids = []\n episode_logprobs = []\n episode_env_masks = []\n correctness_rewards = []\n position_rewards = []\n format_rewards = []\n rubric_component_scores = {}\n\n for prompt_text in prompts:\n episode = rollout_once(\n trainer=trainer,\n env=env,\n tokenizer=tokenizer,\n dataset_prompt=prompt_text,\n system_prompt=system_prompt,\n max_turns=max_turns,\n max_new_tokens=max_new_tokens,\n )\n episode_prompt_ids.append(episode[\"prompt_ids\"])\n episode_completion_ids.append(episode[\"completion_ids\"])\n episode_logprobs.append(episode[\"logprobs\"])\n episode_env_masks.append(episode[\"env_mask\"])\n correctness_rewards.append(episode[\"correct_reward\"])\n position_rewards.append(episode[\"position_reward\"])\n format_rewards.append(compute_format_reward(episode[\"model_outputs\"]))\n\n # Collect rubric component scores\n for key, value in episode.items():\n if key.startswith(\"rubric_\"):\n component_name = key\n if component_name not in rubric_component_scores:\n rubric_component_scores[component_name] = []\n rubric_component_scores[component_name].append(value)\n\n result = {\n \"prompt_ids\": episode_prompt_ids,\n \"completion_ids\": episode_completion_ids,\n \"logprobs\": episode_logprobs,\n \"env_mask\": episode_env_masks,\n \"correct_reward\": correctness_rewards,\n \"position_reward\": position_rewards,\n \"format_reward\": format_rewards,\n }\n\n # Add rubric component scores to result for logging\n result.update(rubric_component_scores)\n\n return 
result" + }, + { + "cell_type": "markdown", + "id": "mJ4D8zvAkQLh", + "metadata": { + "id": "mJ4D8zvAkQLh" + }, + "source": [ + "### Define `rollout_once`\n", + "\n", + "The `rollout_once` function runs **one full interaction loop** between the model and the Wordle environment using the trainer's generation method. \n", + "It executes a mini episode of gameplay, from generating a guess to receiving and processing feedback.\n", + "\n", + "Here's the step-by-step breakdown:\n", + "\n", + "1. **Environment reset:** Start a new game session and initialize the observation. \n", + "2. **Prompt construction:** Combine the system prompt, current state, and user messages to form the model input. \n", + "3. **Generation:** Use `trl.experimental.openenv.generate_rollout_completions()` to produce the model's guess efficiently. \n", + "4. **Feedback extraction:** Parse the environment's response using helpers like `extract_guess()` and `extract_wordle_feedback()`. \n", + "5. **Reward calculation:** Compute rewards based on correctness, green/yellow feedback, and repetition penalty.\n", + "6. **Return structured rollout data:** Includes prompt/completion IDs, logprobs, `env_mask`, and all computed reward components.\n", + "\n", + "**Important: The `env_mask` mechanism**\n", + "\n", + "In multi-turn environments like Wordle, the completion includes both:\n", + "- **Model-generated tokens** (the guesses): These should contribute to the loss during training.\n", + "- **Environment feedback tokens** (game responses): These should NOT contribute to the loss.\n", + "\n", + "The `env_mask` is a list of 1s and 0s that marks which tokens are model-generated (`1`) vs environment-generated (`0`). \n", + "The GRPOTrainer uses this mask to exclude environment tokens from the loss calculation, ensuring the model only learns from its own outputs.\n", + "\n", + "This modular design ensures that each episode can be processed independently while still providing rich feedback for the **GRPO training loop**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5", + "metadata": { + "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5" + }, + "outputs": [], + "source": "import re\nfrom textarena_env import TextArenaAction\nfrom textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback\nfrom trl.experimental.openenv import generate_rollout_completions, get_rubric_scores\n\ndef rollout_once(trainer, env, tokenizer, dataset_prompt, system_prompt, max_turns, max_new_tokens):\n result = env.reset()\n observation = result.observation\n\n prompt_ids = []\n completion_ids = []\n logprobs = []\n env_mask = [] # 1 for model-generated tokens, 0 for environment tokens\n model_outputs = []\n position_scores = []\n correct_scores = []\n prev_env_output_len = 0 # Track length to only add NEW portion each turn\n\n accumulated_messages: list[dict[str, str]] = [{\"role\": \"system\", \"content\": system_prompt}]\n # Build initial prompt (only once, at the start)\n base_prompt = observation.prompt or dataset_prompt\n initial_user_prompt = make_user_prompt(base_prompt, observation.messages)\n initial_env_output = format_history(observation.messages) if observation.messages else \"\"\n prev_env_output_len = len(initial_env_output)\n initial_messages = accumulated_messages + [{\"role\": \"user\", \"content\": initial_user_prompt}]\n initial_prompt_text = tokenizer.apply_chat_template(\n initial_messages,\n add_generation_prompt=True,\n tokenize=False,\n enable_thinking=False,\n )\n initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False)\n prompt_ids.extend(initial_prompt_ids)\n\n for _turn in range(max_turns):\n if result.done:\n break\n\n base_prompt = observation.prompt or dataset_prompt\n user_prompt = make_user_prompt(base_prompt, observation.messages)\n messages = accumulated_messages + [{\"role\": \"user\", \"content\": user_prompt}]\n prompt_text = tokenizer.apply_chat_template(\n messages,\n add_generation_prompt=True,\n tokenize=False,\n enable_thinking=False,\n )\n\n rollout_outputs = generate_rollout_completions(\n trainer, [prompt_text], generation_overrides={\"max_tokens\": max_new_tokens}\n )[0]\n newline_tokens = tokenizer.encode(\"\\n\", add_special_tokens=False)\n completion_ids.extend(newline_tokens)\n logprobs.extend([0.0] * len(newline_tokens))\n env_mask.extend([1] * len(newline_tokens))\n\n completion_ids.extend(rollout_outputs[\"completion_ids\"])\n logprobs.extend(rollout_outputs[\"logprobs\"])\n env_mask.extend([1] * len(rollout_outputs[\"completion_ids\"]))\n\n completion_ids.extend(newline_tokens)\n logprobs.extend([0.0] * len(newline_tokens))\n env_mask.extend([1] * len(newline_tokens))\n completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n )\n guess = extract_guess(completion_text)\n model_outputs.append(completion_text.strip())\n\n result = env.step(TextArenaAction(message=guess))\n\n observation = result.observation\n correct_score = float(result.reward or 0.0)\n\n # Calculate position score (greens worth 1.0, yellows worth 0.5)\n rubric_scores = get_rubric_scores(observation=observation)\n if rubric_scores:\n position_score = rubric_scores.get(\"wordle.greens\", 0.0) + 0.5 * rubric_scores.get(\"wordle.yellows\", 0.0)\n else:\n feedback = extract_wordle_feedback(observation)\n if not feedback:\n position_score = 0.0\n else:\n green_count, yellow_count = extract_feedback_counts(feedback)\n 
position_score = (green_count + 0.5 * yellow_count) / 5.0\n\n position_scores.append(position_score)\n correct_scores.append(correct_score)\n\n full_env_output = format_history(observation.messages) if observation.messages else \"\"\n new_env_output = full_env_output[prev_env_output_len:].lstrip(\"\\n\")\n prev_env_output_len = len(full_env_output)\n\n if new_env_output:\n env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False)\n completion_ids.extend(env_output_tokens)\n logprobs.extend([0.0] * len(env_output_tokens))\n env_mask.extend([0] * len(env_output_tokens))\n completion_with_env = completion_text + \"\\n\" + new_env_output\n else:\n completion_with_env = completion_text\n\n accumulated_messages.append({\"role\": \"user\", \"content\": user_prompt})\n accumulated_messages.append({\"role\": \"assistant\", \"content\": completion_with_env})\n\n # Final rewards: correct is binary win/lose, position uses last attempt (or 1.0 if won)\n correct_reward_value = correct_scores[-1] if correct_scores else 0.0\n final_position_reward = 1.0 if correct_reward_value >= 1.0 else (position_scores[-1] if position_scores else 0.0)\n\n result_dict = {\n \"prompt_ids\": prompt_ids,\n \"completion_ids\": completion_ids,\n \"logprobs\": logprobs,\n \"env_mask\": env_mask,\n \"correct_reward\": correct_reward_value,\n \"position_reward\": final_position_reward,\n \"model_outputs\": model_outputs,\n }\n\n # Add rubric component scores for logging\n rubric_scores = get_rubric_scores(observation=observation)\n for name, score in rubric_scores.items():\n clean_name = name.replace(\"wordle.\", \"\") if name.startswith(\"wordle.\") else name\n result_dict[f\"rubric_{clean_name}\"] = score\n\n return result_dict" + }, + { + "cell_type": "markdown", + "id": "cipvIDzcoF3C", + "metadata": { + "id": "cipvIDzcoF3C" + }, + "source": [ + "### Helper functions\n", + "\n", + "Supporting utilities used in `rollout_once`:\n", + "\n", + "- **`make_user_prompt`**: builds the user prompt combining the conversation history.\n", + "- **`format_history`**: formats the conversation log for consistent context." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bVeKfbaK7C4z", + "metadata": { + "id": "bVeKfbaK7C4z" + }, + "outputs": [], + "source": [ + "# @title Helpers definition (click to expand)\n", + "def format_history(messages) -> str:\n", + " lines = []\n", + " for message in messages:\n", + " tag = message.category or \"MESSAGE\"\n", + " content = message.content.strip()\n", + " if not content:\n", + " continue\n", + " lines.append(f\"[{tag}] {content}\")\n", + " return \"\\n\".join(lines)\n", + "\n", + "\n", + "def make_user_prompt(prompt_text, messages) -> str:\n", + " history = format_history(messages)\n", + " # Only use messages for conversation history - the prompt is already included as the first message\n", + " history_section = history if history else \"[PROMPT] Awaiting first feedback.\"\n", + " return f\"Conversation so far:\\n{history_section}\\n\\nReply with your next guess enclosed in square brackets.\"" + ] + }, + { + "cell_type": "markdown", + "id": "i3G0x0RheYkL", + "metadata": { + "id": "i3G0x0RheYkL" + }, + "source": [ + "## Define reward functions\n", + "\n", + "To guide the agent's learning process, we define simple reward functions that map the feedback from the environment into numeric signals. 
\n", + "Each function corresponds to a specific aspect of the **Wordle** game:\n", + "\n", + "- āœ… **`reward_correct`**: rewards the model when it guesses the correct word (binary: 0 or 1). \n", + "- šŸŽÆ **`reward_position`**: rewards progress based on letter feedback. Green letters worth 1.0, yellow worth 0.5, normalized by 5. If the model wins, this is set to 1.0.\n", + "- šŸ“ **`reward_format_strict`**: rewards correct output format `[xxxxx]`. Returns proportion of correctly formatted outputs across all turns.\n", + "\n", + "These functions return lists of float values that the **GRPOTrainer** uses during optimization. \n", + "By combining them, the model learns to balance correctness, information gathering, and proper formatting in its guessing strategy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61e454d1-9abc-42a6-868c-a24e9801ac44", + "metadata": { + "id": "61e454d1-9abc-42a6-868c-a24e9801ac44" + }, + "outputs": [], + "source": [ + "def reward_correct(completions, **kwargs):\n", + " \"\"\"Reward from environment (correct answer).\"\"\"\n", + " rewards = kwargs.get(\"correct_reward\") if kwargs else None\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]\n", + "\n", + "\n", + "def reward_position(completions, **kwargs):\n", + " \"\"\"Position reward: green worth 1.0, yellow worth 0.5, normalized by 5.\"\"\"\n", + " rewards = kwargs.get(\"position_reward\") if kwargs else None\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]\n", + "\n", + "\n", + "def compute_format_reward(model_outputs):\n", + " \"\"\"Compute format reward from a list of model outputs (one per turn).\n", + "\n", + " Each output should be exactly [5 letters] with optional whitespace.\n", + " Returns proportion of correctly formatted outputs.\n", + " \"\"\"\n", + " if not model_outputs:\n", + " return 0.0\n", + "\n", + " exact_pattern = re.compile(r\"^\\s*\\[[A-Za-z]{5}\\]\\s*$\")\n", + " correct_count = sum(1 for output in model_outputs if exact_pattern.match(output))\n", + "\n", + " return correct_count / len(model_outputs)\n", + "\n", + "\n", + "def reward_format_strict(completions, **kwargs):\n", + " \"\"\"Format reward - pre-computed in rollout_func.\"\"\"\n", + " rewards = kwargs.get(\"format_reward\") if kwargs else None\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]" + ] + }, + { + "cell_type": "markdown", + "id": "RN5VkehojyOJ", + "metadata": { + "id": "RN5VkehojyOJ" + }, + "source": [ + "## Create dataset\n", + "\n", + "We create a dataset with repeated prompts to control the number of training episodes. \n", + "Each entry in the dataset triggers one rollout episode during training. The `dataset_prompt` provides the initial instruction to the model before each game starts." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "deab8040-9b51-4c52-befe-e48578cdbb53", + "metadata": { + "id": "deab8040-9b51-4c52-befe-e48578cdbb53" + }, + "outputs": [], + "source": [ + "from datasets import Dataset\n", + "\n", + "dataset_size = 3000\n", + "dataset_prompt = \"Play Wordle like an expert.\"\n", + "\n", + "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})" + ] + }, + { + "cell_type": "markdown", + "id": "DnR90-D66Fm_", + "metadata": { + "id": "DnR90-D66Fm_" + }, + "source": [ + "## Set GRPO Config\n", + "\n", + "Next, we define the **GRPOConfig**, which controls all key training parameters. \n", + "This configuration specifies how the model interacts with **vLLM**, manages memory, and logs results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20ac9371-af1a-4b9e-b678-33d6a3bf07cc", + "metadata": { + "id": "20ac9371-af1a-4b9e-b678-33d6a3bf07cc" + }, + "outputs": [], + "source": [ + "from trl import GRPOConfig\n", + "\n", + "output_dir = \"wordle-grpo-Qwen3-1.7B-test\"\n", + "\n", + "grpo_config = GRPOConfig(\n", + " # Training schedule / optimization\n", + " num_train_epochs = 1, # Number of full dataset passes\n", + " learning_rate = 1e-6, # Learning rate for the optimizer\n", + " gradient_accumulation_steps = 64, # Accumulate gradients over multiple steps\n", + " per_device_train_batch_size = 1, # Batch size per GPU (number of prompts processed together)\n", + " warmup_steps = 10, # Steps for learning rate warmup\n", + " optim=\"adamw_torch\", # Optimizer\n", + " max_grad_norm=1.0, # Clip gradients to prevent explosion\n", + "\n", + " # GRPO configuration\n", + " num_generations = 2, # Number of rollout episodes per prompt (for variance reduction)\n", + " max_completion_length=1024, # Full episode length, not per-turn\n", + " log_completions = False, # Log completions for debugging\n", + "\n", + " # vLLM configuration\n", + " use_vllm = True, # Enable vLLM for faster inference during rollouts\n", + " vllm_mode = \"colocate\", # Run vLLM in colocate mode (same process as training)\n", + " vllm_gpu_memory_utilization = 0.15, # Fraction of GPU memory reserved for vLLM inference\n", + " vllm_max_model_length=3072, # Maximum context length for vLLM\n", + " vllm_importance_sampling_correction=False,\n", + "\n", + " # Logging / reporting\n", + " output_dir = output_dir, # Directory for checkpoints and logs\n", + " report_to=\"trackio\", # Experiment tracking tool (integrates with HF Spaces)\n", + " trackio_space_id = output_dir, # HF Space where experiment tracking will be saved\n", + " logging_steps = 1, # Log metrics every N steps\n", + " save_steps = 10, # Interval for saving checkpoints\n", + " save_total_limit=1, # Max number of checkpoints to save\n", + "\n", + " # Memory optimization\n", + " gradient_checkpointing = True, # Enable activation recomputation to save memory\n", + "\n", + " # Hub integration\n", + " push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "Mrs9bAr06H2G", + "metadata": { + "id": "Mrs9bAr06H2G" + }, + "source": [ + "## Create `GRPOTrainer` and start training\n", + "\n", + "Now we initialize the `GRPOTrainer`, which manages the entire reinforcement learning loop.\n", + "\n", + "It takes the model, tokenizer, reward functions, rollout function, and dataset defined earlier. 
\n", + "The trainer coordinates the interaction between the model and the environment, applies the reward signals, and updates the policy.\n", + "\n", + "Finally, we call `trainer.train()` to start the fine-tuning process and let the model learn to play Wordle through feedback and iteration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "FeBMCppH7rAc", + "metadata": { + "id": "FeBMCppH7rAc" + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.stdout.fileno = lambda: 1\n", + "sys.stderr.fileno = lambda: 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", + "metadata": { + "colab": { + "referenced_widgets": [ + "f44d7bb668064bdb80e3904ff92da5ea", + "efa028ffbd704a489729c83af0647d68" + ] }, + "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", + "outputId": "aa6f81a6-662c-4215-f091-bcf422f43f9c" + }, + "outputs": [ { - "cell_type": "markdown", - "id": "rpFT3PxHT5Uc", - "metadata": { - "id": "rpFT3PxHT5Uc" - }, - "source": [ - "## Initialize the Environment\n", - "\n", - "Let's begin by setting up the environment that will be used during training. \n", - "For this task, we'll rely on the **TextArena** environment from **OpenEnv**, which exposes a familiar Gymnasium-style API (`reset()`, `step()`, etc.) to simplify interaction.\n", - "\n", - "In this example, we'll connect to the hosted environment at [sergiopaniego/wordle](https://huggingface.co/spaces/sergiopaniego/wordle). \n", - "For production use or custom configurations, we **strongly recommend** running the environment locally via Docker. The hosted versions on the Hub currently have limited concurrency support, so duplicating the Space to your own account is the preferred approach in those cases.\n", - "\n", - "For more information, refer to the [TRL-OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).\n" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "id": "rZimqp1UTIV_", - "metadata": { - "id": "rZimqp1UTIV_", - "outputId": "e53c277c-6050-4380-84e1-983857f0b325" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f44d7bb668064bdb80e3904ff92da5ea", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n" - ] - } - ], - "source": [ - "from textarena_env import TextArenaEnv\n", - "\n", - "wordle_url = \"https://sergiopaniego-wordle.hf.space\" # Duplicate the Space and update this!\n", - "env = TextArenaEnv(base_url=wordle_url)\n", - "# wordle_url = \"sergiopaniego/wordle\"\n", - "# env = TextArenaEnv.from_hub(repo_id=wordle_url)" + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5", - "metadata": { - "id": "5c585602-5352-4e57-8d35-e5b95e05f6c5" - }, - "outputs": [], - "source": [ - "import re\n", - "from textarena_env import TextArenaAction\n", - "from textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback\n", - "from trl.experimental.openenv import generate_rollout_completions\n", - "\n", - "def rollout_once(trainer, env, tokenizer, dataset_prompt, system_prompt, max_turns, max_new_tokens):\n", - " result = env.reset()\n", - " observation = result.observation\n", - "\n", - " prompt_ids = []\n", - " completion_ids = []\n", - " logprobs = []\n", - " env_mask = [] # 1 for model-generated tokens, 0 for environment tokens\n", - " model_outputs = []\n", - " raw_rewards = []\n", - " position_scores = []\n", - " correct_scores = []\n", - " prev_env_output_len = 0 # Track length to only add NEW portion each turn\n", - "\n", - " accumulated_messages: list[dict[str, str]] = [{\"role\": \"system\", \"content\": system_prompt}]\n", - " # Build initial prompt (only once, at the start)\n", - " # The initial env messages are included in the prompt, not completion\n", - " base_prompt = observation.prompt or dataset_prompt\n", - " initial_user_prompt = make_user_prompt(base_prompt, observation.messages)\n", - " # Track initial env output length so we don't add it again\n", - " initial_env_output = format_history(observation.messages) if observation.messages else \"\"\n", - " prev_env_output_len = len(initial_env_output)\n", - " initial_messages = accumulated_messages + [{\"role\": \"user\", \"content\": initial_user_prompt}]\n", - " initial_prompt_text = tokenizer.apply_chat_template(\n", - " initial_messages,\n", - " add_generation_prompt=True,\n", - " tokenize=False,\n", - " enable_thinking=False,\n", - " )\n", - " # Tokenize initial prompt once - this is the base prompt for the entire episode.\n", - " # GRPO expects one prompt-completion pair per episode, where:\n", - " # - prompt_ids = the initial/base prompt (what the model sees at episode start)\n", - " # - completion_ids = all model responses + env feedback from all turns concatenated\n", - " # Note: The actual prompts used for generation in each turn are longer (include conversation history),\n", - " # but we only count the initial prompt tokens here.\n", - " initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False)\n", - " prompt_ids.extend(initial_prompt_ids)\n", - "\n", - " for _turn in range(max_turns):\n", - " if result.done:\n", - " break\n", - "\n", - " base_prompt = observation.prompt or dataset_prompt\n", - " user_prompt = make_user_prompt(base_prompt, observation.messages)\n", - " messages = accumulated_messages + [{\"role\": \"user\", \"content\": user_prompt}]\n", - " prompt_text = tokenizer.apply_chat_template(\n", - " 
messages,\n", - " add_generation_prompt=True,\n", - " tokenize=False,\n", - " enable_thinking=False,\n", - " )\n", - "\n", - " rollout_outputs = generate_rollout_completions(\n", - " trainer, [prompt_text], generation_overrides={\"max_tokens\": max_new_tokens}\n", - " )[0]\n", - " # Add model-generated completion tokens and logprobs with newlines for readability\n", - " newline_tokens = tokenizer.encode(\"\\n\", add_special_tokens=False)\n", - " completion_ids.extend(newline_tokens) # newline before guess\n", - " logprobs.extend([0.0] * len(newline_tokens))\n", - " env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format\n", - "\n", - " completion_ids.extend(rollout_outputs[\"completion_ids\"])\n", - " logprobs.extend(rollout_outputs[\"logprobs\"])\n", - " env_mask.extend([1] * len(rollout_outputs[\"completion_ids\"])) # model-generated tokens\n", - "\n", - " completion_ids.extend(newline_tokens) # newline after guess\n", - " logprobs.extend([0.0] * len(newline_tokens))\n", - " env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format\n", - " completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n", - " rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n", - " )\n", - " guess = extract_guess(completion_text)\n", - " model_outputs.append(completion_text.strip()) # Store raw model output for format reward\n", - "\n", - " result = env.step(TextArenaAction(message=guess))\n", - "\n", - " raw_rewards.append(float(result.reward or 0.0))\n", - " observation = result.observation\n", - " correct_score = float(result.reward or 0.0)\n", - " feedback = extract_wordle_feedback(observation)\n", - "\n", - " full_env_output = format_history(observation.messages) if observation.messages else \"\"\n", - " new_env_output = full_env_output[prev_env_output_len:].lstrip(\"\\n\")\n", - " prev_env_output_len = len(full_env_output)\n", - "\n", - " if new_env_output:\n", - " env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False)\n", - " completion_ids.extend(env_output_tokens) # Add to completion_ids\n", - " logprobs.extend([0.0] * len(env_output_tokens)) # Placeholder (ignored via env_mask=0)\n", - " env_mask.extend([0] * len(env_output_tokens)) # Environment tokens - mask out from loss\n", - " completion_with_env = completion_text + \"\\n\" + new_env_output\n", - " else:\n", - " completion_with_env = completion_text\n", - "\n", - " accumulated_messages.append({\"role\": \"user\", \"content\": user_prompt})\n", - " accumulated_messages.append({\"role\": \"assistant\", \"content\": completion_with_env})\n", - "\n", - " if not feedback:\n", - " position_score = 0.0\n", - " else:\n", - " green_count, yellow_count = extract_feedback_counts(feedback)\n", - " position_score = (green_count + 0.5 * yellow_count) / 5.0\n", - "\n", - " position_scores.append(position_score)\n", - " correct_scores.append(correct_score)\n", - "\n", - " # Use the final correct reward (win/lose is binary at end)\n", - " correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)\n", - "\n", - " # Position reward as shaping signal:\n", - " # - If model WINS: position_reward = 1.0 (no penalty for winning fast)\n", - " # - If model LOSES: position_reward = last attempt (where it ended up)\n", - " if correct_reward_value >= 1.0:\n", - " final_position_reward = 1.0\n", - " else:\n", - " final_position_reward = position_scores[-1] if position_scores else 0.0\n", - "\n", - " return {\n", - " 
\"prompt_ids\": prompt_ids,\n", - " \"completion_ids\": completion_ids,\n", - " \"logprobs\": logprobs,\n", - " \"env_mask\": env_mask,\n", - " \"raw_rewards\": raw_rewards,\n", - " \"correct_reward\": correct_reward_value,\n", - " \"position_reward\": final_position_reward,\n", - " \"model_outputs\": model_outputs,\n", - " }" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* GPU detected, enabling automatic GPU metrics logging\n", + "* Created new run: sergiopaniego-1770031943\n" + ] }, { - "cell_type": "markdown", - "id": "cipvIDzcoF3C", - "metadata": { - "id": "cipvIDzcoF3C" - }, - "source": [ - "### Helper functions\n", - "\n", - "Supporting utilities used in `rollout_once`:\n", - "\n", - "- **`make_user_prompt`**: builds the user prompt combining the conversation history.\n", - "- **`format_history`**: formats the conversation log for consistent context." + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [93/93 3:18:36, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Step | Training Loss
1 | 0.009800
2 | 0.016400
3 | 0.005600
4 | 0.014700
5 | 0.019500
6 | 0.002300
7 | 0.005300
8 | 0.025100
9 | 0.004500
10 | 0.004200
11 | 0.009600
12 | 0.014900
13 | 0.024500
14 | 0.012200
15 | 0.015500
16 | 0.007400
17 | 0.017500
18 | 0.014900
19 | 0.035600
20 | 0.014900
21 | 0.030000
22 | 0.014300
23 | 0.018000
24 | 0.014000
25 | 0.016600
26 | 0.015600
27 | 0.021300
28 | 0.021000
29 | 0.036900
30 | 0.006400
31 | 0.044800
32 | 0.026400
33 | 0.038700
34 | 0.022000
35 | 0.013400
36 | 0.025000
37 | 0.042900
38 | 0.072700
39 | 0.070100
40 | 0.019900
41 | 0.058700
42 | 0.060100
43 | -0.026700
44 | 0.038900
45 | 0.042400
46 | -0.009100
47 | 0.001300
48 | 0.020200
49 | 0.078700
50 | 0.026300
51 | 0.045700
52 | 0.035300
53 | -0.006700
54 | 0.025300
55 | 0.069500
56 | 0.092800
57 | 0.067900
58 | 0.035000
59 | 0.061300
60 | 0.048800
61 | 0.000600
62 | 0.028400
63 | 0.016200
64 | 0.010700
65 | 0.020200
66 | 0.041800
67 | 0.006800
68 | 0.014800
69 | 0.025100
70 | -0.006600
71 | 0.041000
72 | 0.008300
73 | 0.045300
74 | 0.062800
75 | 0.048200
76 | 0.032800
77 | 0.053000
78 | 0.023100
79 | 0.014900
80 | 0.078200
81 | -0.000700
82 | 0.013400
83 | 0.030200
84 | -0.003600
85 | 0.051700
86 | 0.033500
87 | 0.021800
88 | -0.003400
89 | 0.023200
90 | -0.002900
91 | 0.030900
92 | 0.029200
93 | 0.002500

" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "bVeKfbaK7C4z", - "metadata": { - "id": "bVeKfbaK7C4z" - }, - "outputs": [], - "source": [ - "# @title Helpers definition (click to expand)\n", - "def format_history(messages) -> str:\n", - " lines = []\n", - " for message in messages:\n", - " tag = message.category or \"MESSAGE\"\n", - " content = message.content.strip()\n", - " if not content:\n", - " continue\n", - " lines.append(f\"[{tag}] {content}\")\n", - " return \"\\n\".join(lines)\n", - "\n", - "\n", - "def make_user_prompt(prompt_text, messages) -> str:\n", - " history = format_history(messages)\n", - " # Only use messages for conversation history - the prompt is already included as the first message\n", - " history_section = history if history else \"[PROMPT] Awaiting first feedback.\"\n", - " return f\"Conversation so far:\\n{history_section}\\n\\nReply with your next guess enclosed in square brackets.\"" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] }, { - "cell_type": "markdown", - "id": "i3G0x0RheYkL", - "metadata": { - "id": "i3G0x0RheYkL" - }, - "source": [ - "## Define reward functions\n", - "\n", - "To guide the agent's learning process, we define simple reward functions that map the feedback from the environment into numeric signals. \n", - "Each function corresponds to a specific aspect of the **Wordle** game:\n", - "\n", - "- āœ… **`reward_correct`**: rewards the model when it guesses the correct word (binary: 0 or 1). \n", - "- šŸŽÆ **`reward_position`**: rewards progress based on letter feedback. Green letters worth 1.0, yellow worth 0.5, normalized by 5. If the model wins, this is set to 1.0.\n", - "- šŸ“ **`reward_format_strict`**: rewards correct output format `[xxxxx]`. Returns proportion of correctly formatted outputs across all turns.\n", - "\n", - "These functions return lists of float values that the **GRPOTrainer** uses during optimization. \n", - "By combining them, the model learns to balance correctness, information gathering, and proper formatting in its guessing strategy." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. 
Uploading logs to Trackio (please wait...)\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "id": "61e454d1-9abc-42a6-868c-a24e9801ac44", - "metadata": { - "id": "61e454d1-9abc-42a6-868c-a24e9801ac44" - }, - "outputs": [], - "source": [ - "def reward_correct(completions, **kwargs):\n", - " \"\"\"Reward from environment (correct answer).\"\"\"\n", - " rewards = kwargs.get(\"correct_reward\") if kwargs else None\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]\n", - "\n", - "\n", - "def reward_position(completions, **kwargs):\n", - " \"\"\"Position reward: green worth 1.0, yellow worth 0.5, normalized by 5.\"\"\"\n", - " rewards = kwargs.get(\"position_reward\") if kwargs else None\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]\n", - "\n", - "\n", - "def compute_format_reward(model_outputs):\n", - " \"\"\"Compute format reward from a list of model outputs (one per turn).\n", - "\n", - " Each output should be exactly [5 letters] with optional whitespace.\n", - " Returns proportion of correctly formatted outputs.\n", - " \"\"\"\n", - " if not model_outputs:\n", - " return 0.0\n", - "\n", - " exact_pattern = re.compile(r\"^\\s*\\[[A-Za-z]{5}\\]\\s*$\")\n", - " correct_count = sum(1 for output in model_outputs if exact_pattern.match(output))\n", - "\n", - " return correct_count / len(model_outputs)\n", - "\n", - "\n", - "def reward_format_strict(completions, **kwargs):\n", - " \"\"\"Format reward - pre-computed in rollout_func.\"\"\"\n", - " rewards = kwargs.get(\"format_reward\") if kwargs else None\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]" - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "o-hEO4oK4ZXr", + "metadata": { + "id": "o-hEO4oK4ZXr" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "zuHTwuxAVp8p", + "metadata": { + "id": "zuHTwuxAVp8p", + "outputId": "fce9bdc8-d734-4382-bb26-7e03dbffa7a0" + }, + "outputs": [ { - "cell_type": "markdown", - "id": "RN5VkehojyOJ", - "metadata": { - "id": "RN5VkehojyOJ" - }, - "source": [ - "## Create dataset\n", - "\n", - "We create a dataset with repeated prompts to control the number of training episodes. \n", - "Each entry in the dataset triggers one rollout episode during training. The `dataset_prompt` provides the initial instruction to the model before each game starts." 
- ] + "name": "stdout", + "output_type": "stream", + "text": [ + "12065.8973 seconds used for training.\n", + "201.1 minutes used for training.\n", + "Peak reserved memory = 38.139 GB.\n", + "Peak reserved memory for training = 25.655 GB.\n", + "Peak reserved memory % of max memory = 96.415 %.\n", + "Peak reserved memory for training % of max memory = 64.856 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_training = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_training} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {training_memory_percentage} %.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13e9fd4e-e7a5-468d-a25a-3f7d2794201f", + "metadata": { + "colab": { + "referenced_widgets": [ + "decd9f00c4da42bf92b72c327bd28278", + "2d924050f7bf4e7f88316c8fc202a763", + "d589783221084eb7833ae6cd742d277c", + "0e135c821b5744b287b4de7eeb15d419", + "a1839712ff344a409e6f7f48a1467fd5", + "e9ae0fcd43e34d7e916fe1bda0a38a49", + "75776d6523ef42df930ddfd7048b384e", + "e2e07a449d914bd39653b7cbbc5903e3", + "0eafc3f9bac14807866233f924793380", + "b64c487a9dff4108a66da9eee4e4ed66", + "17a3ba38cf7349269ea54df84faf30b7", + "7382295b99ee4db28de43e1451dd0d17" + ] }, + "id": "13e9fd4e-e7a5-468d-a25a-3f7d2794201f", + "outputId": "7f703ed8-7874-4da1-8490-48222755ae11" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "id": "deab8040-9b51-4c52-befe-e48578cdbb53", - "metadata": { - "id": "deab8040-9b51-4c52-befe-e48578cdbb53" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "decd9f00c4da42bf92b72c327bd28278", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "from datasets import Dataset\n", - "\n", - "dataset_size = 3000\n", - "dataset_prompt = \"Play Wordle like an expert.\"\n", - "\n", - "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})" + "text/plain": [ + "Processing Files (0 / 0) : | | 0.00B / 0.00B " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "id": "DnR90-D66Fm_", - "metadata": { - "id": "DnR90-D66Fm_" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2d924050f7bf4e7f88316c8fc202a763", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "## Set GRPO Config\n", - "\n", - "Next, we define the **GRPOConfig**, which controls all key training parameters. \n", - "This configuration specifies how the model interacts with **vLLM**, manages memory, and logs results." 
+ "text/plain": [ + "New Data Upload : | | 0.00B / 0.00B " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "20ac9371-af1a-4b9e-b678-33d6a3bf07cc", - "metadata": { - "id": "20ac9371-af1a-4b9e-b678-33d6a3bf07cc" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d589783221084eb7833ae6cd742d277c", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "from trl import GRPOConfig\n", - "\n", - "output_dir = \"wordle-grpo-Qwen3-1.7B-test\"\n", - "\n", - "grpo_config = GRPOConfig(\n", - " # Training schedule / optimization\n", - " num_train_epochs = 1, # Number of full dataset passes\n", - " learning_rate = 1e-6, # Learning rate for the optimizer\n", - " gradient_accumulation_steps = 64, # Accumulate gradients over multiple steps\n", - " per_device_train_batch_size = 1, # Batch size per GPU (number of prompts processed together)\n", - " warmup_steps = 10, # Steps for learning rate warmup\n", - " optim=\"adamw_torch\", # Optimizer\n", - " max_grad_norm=1.0, # Clip gradients to prevent explosion\n", - "\n", - " # GRPO configuration\n", - " num_generations = 2, # Number of rollout episodes per prompt (for variance reduction)\n", - " max_completion_length=1024, # Full episode length, not per-turn\n", - " log_completions = False, # Log completions for debugging\n", - "\n", - " # vLLM configuration\n", - " use_vllm = True, # Enable vLLM for faster inference during rollouts\n", - " vllm_mode = \"colocate\", # Run vLLM in colocate mode (same process as training)\n", - " vllm_gpu_memory_utilization = 0.15, # Fraction of GPU memory reserved for vLLM inference\n", - " vllm_max_model_length=3072, # Maximum context length for vLLM\n", - " vllm_importance_sampling_correction=False,\n", - "\n", - " # Logging / reporting\n", - " output_dir = output_dir, # Directory for checkpoints and logs\n", - " report_to=\"trackio\", # Experiment tracking tool (integrates with HF Spaces)\n", - " trackio_space_id = output_dir, # HF Space where experiment tracking will be saved\n", - " logging_steps = 1, # Log metrics every N steps\n", - " save_steps = 10, # Interval for saving checkpoints\n", - " save_total_limit=1, # Max number of checkpoints to save\n", - "\n", - " # Memory optimization\n", - " gradient_checkpointing = True, # Enable activation recomputation to save memory\n", - "\n", - " # Hub integration\n", - " push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n", - ")" + "text/plain": [ + " ...7B-test/training_args.bin: 100%|##########| 7.70kB / 7.70kB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "id": "Mrs9bAr06H2G", - "metadata": { - "id": "Mrs9bAr06H2G" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0e135c821b5744b287b4de7eeb15d419", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "## Create `GRPOTrainer` and start training\n", - "\n", - "Now we initialize the `GRPOTrainer`, which manages the entire reinforcement learning loop.\n", - "\n", - "It takes the model, tokenizer, reward functions, rollout function, and dataset defined earlier. \n", - "The trainer coordinates the interaction between the model and the environment, applies the reward signals, and updates the policy.\n", - "\n", - "Finally, we call `trainer.train()` to start the fine-tuning process and let the model learn to play Wordle through feedback and iteration." 
+ "text/plain": [ + " ...-1.7B-test/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "FeBMCppH7rAc", - "metadata": { - "id": "FeBMCppH7rAc" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a1839712ff344a409e6f7f48a1467fd5", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "import sys\n", - "sys.stdout.fileno = lambda: 1\n", - "sys.stderr.fileno = lambda: 2" + "text/plain": [ + " ...0002-of-00002.safetensors: 2%|1 | 33.5MB / 1.91GB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", - "metadata": { - "colab": { - "referenced_widgets": [ - "f44d7bb668064bdb80e3904ff92da5ea", - "efa028ffbd704a489729c83af0647d68" - ] - }, - "id": "1f7aceb9-fe9e-49ba-b976-a39c1e29d4e5", - "outputId": "aa6f81a6-662c-4215-f091-bcf422f43f9c" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e9ae0fcd43e34d7e916fe1bda0a38a49", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f44d7bb668064bdb80e3904ff92da5ea", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* GPU detected, enabling automatic GPU metrics logging\n", - "* Created new run: sergiopaniego-1770031943\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "

\n", - " \n", - " \n", - " [93/93 3:18:36, Epoch 1/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Step / Training Loss
  1: 0.009800    2: 0.016400    3: 0.005600    4: 0.014700    5: 0.019500
  6: 0.002300    7: 0.005300    8: 0.025100    9: 0.004500   10: 0.004200
 11: 0.009600   12: 0.014900   13: 0.024500   14: 0.012200   15: 0.015500
 16: 0.007400   17: 0.017500   18: 0.014900   19: 0.035600   20: 0.014900
 21: 0.030000   22: 0.014300   23: 0.018000   24: 0.014000   25: 0.016600
 26: 0.015600   27: 0.021300   28: 0.021000   29: 0.036900   30: 0.006400
 31: 0.044800   32: 0.026400   33: 0.038700   34: 0.022000   35: 0.013400
 36: 0.025000   37: 0.042900   38: 0.072700   39: 0.070100   40: 0.019900
 41: 0.058700   42: 0.060100   43: -0.026700   44: 0.038900   45: 0.042400
 46: -0.009100   47: 0.001300   48: 0.020200   49: 0.078700   50: 0.026300
 51: 0.045700   52: 0.035300   53: -0.006700   54: 0.025300   55: 0.069500
 56: 0.092800   57: 0.067900   58: 0.035000   59: 0.061300   60: 0.048800
 61: 0.000600   62: 0.028400   63: 0.016200   64: 0.010700   65: 0.020200
 66: 0.041800   67: 0.006800   68: 0.014800   69: 0.025100   70: -0.006600
 71: 0.041000   72: 0.008300   73: 0.045300   74: 0.062800   75: 0.048200
 76: 0.032800   77: 0.053000   78: 0.023100   79: 0.014900   80: 0.078200
 81: -0.000700   82: 0.013400   83: 0.030200   84: -0.003600   85: 0.051700
 86: 0.033500   87: 0.021800   88: -0.003400   89: 0.023200   90: -0.002900
 91: 0.030900   92: 0.029200   93: 0.002500

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n", - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", - " return datetime.utcnow().replace(tzinfo=utc)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" + "text/plain": [ + " ...7B-test/training_args.bin: 100%|##########| 7.70kB / 7.70kB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "id": "o-hEO4oK4ZXr", - "metadata": { - "id": "o-hEO4oK4ZXr" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b64c487a9dff4108a66da9eee4e4ed66", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "Show memory stats after training" + "text/plain": [ + " ...-1.7B-test/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "zuHTwuxAVp8p", - "metadata": { - "id": "zuHTwuxAVp8p", - "outputId": "fce9bdc8-d734-4382-bb26-7e03dbffa7a0" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "17a3ba38cf7349269ea54df84faf30b7", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12065.8973 seconds used for training.\n", - "201.1 minutes used for training.\n", - "Peak reserved memory = 38.139 GB.\n", - "Peak reserved memory for training = 25.655 GB.\n", - "Peak reserved memory % of max memory = 96.415 %.\n", - "Peak reserved memory for training % of max memory = 64.856 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_training = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_training} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {training_memory_percentage} %.\")" + "text/plain": [ + " ...0001-of-00002.safetensors: 1%| | 33.5MB / 4.97GB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "13e9fd4e-e7a5-468d-a25a-3f7d2794201f", - "metadata": { - "colab": { - "referenced_widgets": [ - "decd9f00c4da42bf92b72c327bd28278", - "2d924050f7bf4e7f88316c8fc202a763", - "d589783221084eb7833ae6cd742d277c", - "0e135c821b5744b287b4de7eeb15d419", - "a1839712ff344a409e6f7f48a1467fd5", - "e9ae0fcd43e34d7e916fe1bda0a38a49", - "75776d6523ef42df930ddfd7048b384e", - "e2e07a449d914bd39653b7cbbc5903e3", - "0eafc3f9bac14807866233f924793380", - 
"b64c487a9dff4108a66da9eee4e4ed66", - "17a3ba38cf7349269ea54df84faf30b7", - "7382295b99ee4db28de43e1451dd0d17" - ] - }, - "id": "13e9fd4e-e7a5-468d-a25a-3f7d2794201f", - "outputId": "7f703ed8-7874-4da1-8490-48222755ae11" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7382295b99ee4db28de43e1451dd0d17", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "decd9f00c4da42bf92b72c327bd28278", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing Files (0 / 0) : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2d924050f7bf4e7f88316c8fc202a763", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "New Data Upload : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d589783221084eb7833ae6cd742d277c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...7B-test/training_args.bin: 100%|##########| 7.70kB / 7.70kB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0e135c821b5744b287b4de7eeb15d419", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...-1.7B-test/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a1839712ff344a409e6f7f48a1467fd5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...0002-of-00002.safetensors: 2%|1 | 33.5MB / 1.91GB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e9ae0fcd43e34d7e916fe1bda0a38a49", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...0001-of-00002.safetensors: 1%| | 33.5MB / 4.97GB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No files have been modified since last commit. Skipping to prevent empty commit.\n", - "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. 
Skipping to prevent empty commit.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "75776d6523ef42df930ddfd7048b384e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing Files (0 / 0) : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e2e07a449d914bd39653b7cbbc5903e3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "New Data Upload : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0eafc3f9bac14807866233f924793380", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...7B-test/training_args.bin: 100%|##########| 7.70kB / 7.70kB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b64c487a9dff4108a66da9eee4e4ed66", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...-1.7B-test/tokenizer.json: 100%|##########| 11.4MB / 11.4MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "17a3ba38cf7349269ea54df84faf30b7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...0001-of-00002.safetensors: 1%| | 33.5MB / 4.97GB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7382295b99ee4db28de43e1451dd0d17", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...0002-of-00002.safetensors: 2%|1 | 33.5MB / 1.91GB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No files have been modified since last commit. Skipping to prevent empty commit.\n", - "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n" - ] - }, - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B-test/commit/2d7a27066ef244796a079cbf08fa6656af426145', commit_message='End of training', commit_description='', oid='2d7a27066ef244796a079cbf08fa6656af426145', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/wordle-grpo-Qwen3-1.7B-test', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/wordle-grpo-Qwen3-1.7B-test'), pr_revision=None, pr_num=None)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "env.close()\n", - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub()" + "text/plain": [ + " ...0002-of-00002.safetensors: 2%|1 | 33.5MB / 1.91GB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "id": "wQyVb1nAxWld", - "metadata": { - "id": "wQyVb1nAxWld" - }, - "source": [ - "## Load the Fine-Tuned Model and Run Inference\n", - "\n", - "Now let's test our fine-tuned model by loading the **adapter** and running **inference**. \n", - "We begin by loading the **base model**, attaching the adapter, and obtaining the final fine-tuned model ready for evaluation." 
- ] + "name": "stderr", + "output_type": "stream", + "text": [ + "No files have been modified since last commit. Skipping to prevent empty commit.\n", + "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "id": "JcTeeSBXxWWF", - "metadata": { - "colab": { - "referenced_widgets": [ - "281b1cf074fd4d60bb754906a0764865", - "e129fb465f1a41c1bdf2495d14143458" - ] - }, - "id": "JcTeeSBXxWWF", - "outputId": "86efafc3-1161-471b-86b1-14c43e95908f" + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n", - "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n", - "You are not authenticated with the Hugging Face Hub in this notebook.\n", - "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "281b1cf074fd4d60bb754906a0764865", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 2 files: 0%| | 0/2 [00:00 {generated_text}\")\n", - " print(f\" Parsed guess: {guess}\")\n", - "\n", - " result = env.step(TextArenaAction(message=guess))\n", - " observation = result.observation\n", - "\n", - " print(\" Feedback messages:\")\n", - " for message in observation.messages:\n", - " print(f\" [{message.category}] {message.content}\")\n", - "\n", - " print(\"\\nāœ… Game finished\")\n", - " print(f\" Reward: {result.reward}\")\n", - " print(f\" Done: {result.done}\")" + "text/plain": [ + "Fetching 2 files: 0%| | 0/2 [00:00 {generated_text}\")\n", + " print(f\" Parsed guess: {guess}\")\n", + "\n", + " result = env.step(TextArenaAction(message=guess))\n", + " observation = result.observation\n", + "\n", + " print(\" Feedback messages:\")\n", + " for message in observation.messages:\n", + " print(f\" [{message.category}] {message.content}\")\n", + "\n", + " print(\"\\nāœ… Game finished\")\n", + " print(f\" Reward: {result.reward}\")\n", + " print(f\" Done: {result.done}\")" + ] + }, + { + "cell_type": "markdown", + "id": "MjIxHOHK4PVe", + "metadata": { + "id": "MjIxHOHK4PVe" + }, + "source": [ + "Let's play the game!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "JjOzWexUXmfW", + "metadata": { + "id": "JjOzWexUXmfW", + "outputId": "1c6130af-fe89-4930-e53a-7329e0483ef0" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "id": "JjOzWexUXmfW", - "metadata": { - "id": "JjOzWexUXmfW", - "outputId": "1c6130af-fe89-4930-e53a-7329e0483ef0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "šŸ“œ Initial Prompt:\n", - "You are Player 0 in Wordle.\n", - "A secret 5-letter word has been chosen. 
You have 6 attempts to guess it.\n", - "For each guess, wrap your word in square brackets (e.g., [apple]).\n", - "Feedback for each letter will be given as follows:\n", - " - G (green): correct letter in the correct position\n", - " - Y (yellow): letter exists in the word but in the wrong position\n", - " - X (wrong): letter is not in the word\n", - "Enter your guess to begin.\n", - "\n", - "šŸŽÆ Turn 0: model replied with -> [crane]\n", - " Parsed guess: [crane]\n", - " Feedback messages:\n", - " [MESSAGE] [crane]\n", - " [MESSAGE] Player 0 submitted [crane].\n", - "Feedback:\n", - "C R A N E\n", - "X Y X X X\n", - "\n", - "You have 5 guesses left.\n", - "\n", - "šŸŽÆ Turn 1: model replied with -> [spare]\n", - " Parsed guess: [spare]\n", - " Feedback messages:\n", - " [MESSAGE] [spare]\n", - " [MESSAGE] Player 0 submitted [spare].\n", - "Feedback:\n", - "C R A N E\n", - "X Y X X X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "You have 4 guesses left.\n", - "\n", - "šŸŽÆ Turn 2: model replied with -> [spare]\n", - " Parsed guess: [spare]\n", - " Feedback messages:\n", - " [MESSAGE] [spare]\n", - " [MESSAGE] Player 0 submitted [spare].\n", - "Feedback:\n", - "C R A N E\n", - "X Y X X X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "You have 3 guesses left.\n", - "\n", - "šŸŽÆ Turn 3: model replied with -> [spare]\n", - " Parsed guess: [spare]\n", - " Feedback messages:\n", - " [MESSAGE] [spare]\n", - " [MESSAGE] Player 0 submitted [spare].\n", - "Feedback:\n", - "C R A N E\n", - "X Y X X X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "You have 2 guesses left.\n", - "\n", - "šŸŽÆ Turn 4: model replied with -> [spare]\n", - " Parsed guess: [spare]\n", - " Feedback messages:\n", - " [MESSAGE] [spare]\n", - " [MESSAGE] Player 0 submitted [spare].\n", - "Feedback:\n", - "C R A N E\n", - "X Y X X X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "You have 1 guesses left.\n", - "\n", - "šŸŽÆ Turn 5: model replied with -> [spare]\n", - " Parsed guess: [spare]\n", - " Feedback messages:\n", - " [MESSAGE] [spare]\n", - " [MESSAGE] Player 0 submitted [spare].\n", - "Feedback:\n", - "C R A N E\n", - "X Y X X X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "S P A R E\n", - "G X X G X\n", - "\n", - "You have 0 guesses left.\n", - " [MESSAGE] The game ended in a draw. Reason: Turn limit reached.\n", - "\n", - "āœ… Game finished\n", - " Reward: 0.0\n", - " Done: True\n" - ] - } - ], - "source": [ - "try:\n", - " play_wordle(env, fine_tuned_model, tokenizer)\n", - "finally:\n", - " env.close()" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "šŸ“œ Initial Prompt:\n", + "You are Player 0 in Wordle.\n", + "A secret 5-letter word has been chosen. 
You have 6 attempts to guess it.\n", + "For each guess, wrap your word in square brackets (e.g., [apple]).\n", + "Feedback for each letter will be given as follows:\n", + " - G (green): correct letter in the correct position\n", + " - Y (yellow): letter exists in the word but in the wrong position\n", + " - X (wrong): letter is not in the word\n", + "Enter your guess to begin.\n", + "\n", + "šŸŽÆ Turn 0: model replied with -> [crane]\n", + " Parsed guess: [crane]\n", + " Feedback messages:\n", + " [MESSAGE] [crane]\n", + " [MESSAGE] Player 0 submitted [crane].\n", + "Feedback:\n", + "C R A N E\n", + "X Y X X X\n", + "\n", + "You have 5 guesses left.\n", + "\n", + "šŸŽÆ Turn 1: model replied with -> [spare]\n", + " Parsed guess: [spare]\n", + " Feedback messages:\n", + " [MESSAGE] [spare]\n", + " [MESSAGE] Player 0 submitted [spare].\n", + "Feedback:\n", + "C R A N E\n", + "X Y X X X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "You have 4 guesses left.\n", + "\n", + "šŸŽÆ Turn 2: model replied with -> [spare]\n", + " Parsed guess: [spare]\n", + " Feedback messages:\n", + " [MESSAGE] [spare]\n", + " [MESSAGE] Player 0 submitted [spare].\n", + "Feedback:\n", + "C R A N E\n", + "X Y X X X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "You have 3 guesses left.\n", + "\n", + "šŸŽÆ Turn 3: model replied with -> [spare]\n", + " Parsed guess: [spare]\n", + " Feedback messages:\n", + " [MESSAGE] [spare]\n", + " [MESSAGE] Player 0 submitted [spare].\n", + "Feedback:\n", + "C R A N E\n", + "X Y X X X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "You have 2 guesses left.\n", + "\n", + "šŸŽÆ Turn 4: model replied with -> [spare]\n", + " Parsed guess: [spare]\n", + " Feedback messages:\n", + " [MESSAGE] [spare]\n", + " [MESSAGE] Player 0 submitted [spare].\n", + "Feedback:\n", + "C R A N E\n", + "X Y X X X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "You have 1 guesses left.\n", + "\n", + "šŸŽÆ Turn 5: model replied with -> [spare]\n", + " Parsed guess: [spare]\n", + " Feedback messages:\n", + " [MESSAGE] [spare]\n", + " [MESSAGE] Player 0 submitted [spare].\n", + "Feedback:\n", + "C R A N E\n", + "X Y X X X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "S P A R E\n", + "G X X G X\n", + "\n", + "You have 0 guesses left.\n", + " [MESSAGE] The game ended in a draw. 
Reason: Turn limit reached.\n", + "\n", + "āœ… Game finished\n", + " Reward: 0.0\n", + " Done: True\n" + ] } - ], - "metadata": { - "language_info": { - "name": "python" - }, - "colab": { - "provenance": [], - "gpuType": "A100" - }, - "accelerator": "GPU" + ], + "source": [ + "try:\n", + " play_wordle(env, fine_tuned_model, tokenizer)\n", + "finally:\n", + " env.close()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "A100" }, - "nbformat": 4, - "nbformat_minor": 5 + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 5 } \ No newline at end of file diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py index f52a7a1850d..6ec164ccd23 100644 --- a/examples/scripts/openenv/echo.py +++ b/examples/scripts/openenv/echo.py @@ -211,7 +211,6 @@ def main(): def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: outputs = generate_rollout_completions(trainer, prompts) tokenizer = trainer.processing_class - completions_text = [tokenizer.decode(output["completion_ids"], skip_special_tokens=True) for output in outputs] env_result = client.reset() diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index e35802a12af..cf025ab58ea 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -105,7 +105,7 @@ from transformers import AutoTokenizer from trl import GRPOConfig, GRPOTrainer -from trl.experimental.openenv import generate_rollout_completions +from trl.experimental.openenv import generate_rollout_completions, get_rubric_scores # Ensure src/ is on the path @@ -325,7 +325,6 @@ def rollout_once( logprobs: list[float] = [] env_mask: list[int] = [] # 1 for model-generated tokens, 0 for environment tokens model_outputs: list[str] = [] - raw_rewards: list[float] = [] position_scores: list[float] = [] correct_scores: list[float] = [] prev_env_output_len: int = 0 # Track length to only add NEW portion each turn @@ -392,10 +391,23 @@ def rollout_once( result = env.step(TextArenaAction(message=guess)) - raw_rewards.append(float(result.reward or 0.0)) observation = result.observation correct_score = float(result.reward or 0.0) - feedback = extract_wordle_feedback(observation) + + # Calculate position score (greens worth 1.0, yellows worth 0.5) + rubric_scores = get_rubric_scores(observation=observation) + if rubric_scores: + position_score = rubric_scores.get("wordle.greens", 0.0) + 0.5 * rubric_scores.get("wordle.yellows", 0.0) + else: + feedback = extract_wordle_feedback(observation) + if not feedback: + position_score = 0.0 + else: + green_count, yellow_count = extract_feedback_counts(feedback) + position_score = (green_count + 0.5 * yellow_count) / 5.0 + + position_scores.append(position_score) + correct_scores.append(correct_score) full_env_output = format_history(observation.messages) if observation.messages else "" new_env_output = full_env_output[prev_env_output_len:].lstrip("\n") @@ -413,37 +425,28 @@ def rollout_once( accumulated_messages.append({"role": "user", "content": user_prompt}) accumulated_messages.append({"role": "assistant", "content": completion_with_env}) - if not feedback: - position_score = 0.0 - else: - green_count, yellow_count = extract_feedback_counts(feedback) - position_score = (green_count + 0.5 * yellow_count) / 5.0 - - position_scores.append(position_score) - correct_scores.append(correct_score) - - # Use the final correct reward (win/lose is binary at end) - correct_reward_value = 
correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0) + # Final rewards: correct is binary win/lose, position uses last attempt (or 1.0 if won) + correct_reward_value = correct_scores[-1] if correct_scores else 0.0 + final_position_reward = 1.0 if correct_reward_value >= 1.0 else (position_scores[-1] if position_scores else 0.0) - # Position reward as shaping signal: - # - If model WINS: position_reward = 1.0 (no penalty for winning fast) - # - If model LOSES: position_reward = last attempt (where it ended up) - if correct_reward_value >= 1.0: - final_position_reward = 1.0 - else: - final_position_reward = position_scores[-1] if position_scores else 0.0 - - return { + result_dict = { "prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs, "env_mask": env_mask, - "raw_rewards": raw_rewards, "correct_reward": correct_reward_value, "position_reward": final_position_reward, "model_outputs": model_outputs, } + # Add rubric component scores for logging + rubric_scores = get_rubric_scores(observation=observation) + for name, score in rubric_scores.items(): + clean_name = name.replace("wordle.", "") if name.startswith("wordle.") else name + result_dict[f"rubric_{clean_name}"] = score + + return result_dict + # --------------------------------------------------------------------------- # Rewards @@ -552,6 +555,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: correctness_rewards: list[float] = [] position_rewards: list[float] = [] format_rewards: list[float] = [] + rubric_component_scores: dict[str, list[float]] = {} for prompt_text in prompts: episode = rollout_once( @@ -571,7 +575,15 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: position_rewards.append(episode["position_reward"]) format_rewards.append(compute_format_reward(episode["model_outputs"])) - return { + # Collect rubric component scores + for key, value in episode.items(): + if key.startswith("rubric_"): + component_name = key + if component_name not in rubric_component_scores: + rubric_component_scores[component_name] = [] + rubric_component_scores[component_name].append(value) + + result = { "prompt_ids": episode_prompt_ids, "completion_ids": episode_completion_ids, "logprobs": episode_logprobs, @@ -581,6 +593,11 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: "format_reward": format_rewards, } + # Add rubric component scores to result for logging + result.update(rubric_component_scores) + + return result + trainer = GRPOTrainer( model=args.model_id, processing_class=tokenizer, diff --git a/trl/experimental/openenv/__init__.py b/trl/experimental/openenv/__init__.py index 4325e17f284..8c66e3cde20 100644 --- a/trl/experimental/openenv/__init__.py +++ b/trl/experimental/openenv/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .utils import generate_rollout_completions +from .utils import generate_rollout_completions, get_rubric_scores -__all__ = ["generate_rollout_completions"] +__all__ = ["generate_rollout_completions", "get_rubric_scores"] diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py index 5c4c710132b..e3221bbbc49 100644 --- a/trl/experimental/openenv/utils.py +++ b/trl/experimental/openenv/utils.py @@ -207,3 +207,70 @@ def _generate_rollout_completions_colocate( trainer.vllm_generation.llm.sleep(level=2) return results + + +def get_rubric_scores(env: Any = None, observation: Any = None) -> dict[str, float]: + """ + Extract named rubric component scores from an environment or observation. + + This function supports multiple ways of accessing rubric scores: + 1. From observation.info["reward_signals"] (TextArena pattern) + 2. From observation.metadata["reward_signals"] (TextArena pattern) + 3. From env.rubric.named_rubrics() (direct rubric introspection) + 4. From env.rubric.last_score (single rubric) + + Args: + env: An OpenEnv environment instance with optional rubric support + observation: An observation object that may contain reward_signals + + Returns: + Dictionary mapping rubric component names to float scores. + Returns empty dict if no rubric scores are available. + + Example: + >>> # From observation (TextArena pattern) + >>> scores = get_rubric_scores(observation=observation) + >>> # scores might be: {"wordle.greens": 0.8, "wordle.yellows": 0.4, ...} + >>> + >>> # From environment rubric + >>> scores = get_rubric_scores(env=env) + >>> # scores might be: {"greens": 0.8, "yellows": 0.4, ...} + """ + scores = {} + + # First, try to extract from observation (TextArena pattern) + if observation is not None: + try: + # Try observation.info["reward_signals"] + if hasattr(observation, "info") and isinstance(observation.info, dict): + reward_signals = observation.info.get("reward_signals", {}) + if reward_signals: + scores.update({k: float(v) for k, v in reward_signals.items()}) + return scores + + # Try observation.metadata["reward_signals"] + if hasattr(observation, "metadata") and isinstance(observation.metadata, dict): + reward_signals = observation.metadata.get("reward_signals", {}) + if reward_signals: + scores.update({k: float(v) for k, v in reward_signals.items()}) + return scores + except Exception: + pass + + # Fall back to environment rubric introspection + if env is not None: + try: + # Check if env has rubric attribute + if hasattr(env, "rubric") and env.rubric is not None: + # Check if rubric has named_rubrics method (composable rubrics) + if hasattr(env.rubric, "named_rubrics"): + for name, rubric in env.rubric.named_rubrics(): + if hasattr(rubric, "last_score"): + scores[name] = float(rubric.last_score) + # Single rubric with last_score + elif hasattr(env.rubric, "last_score"): + scores["rubric"] = float(env.rubric.last_score) + except Exception: + pass + + return scores
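
For reviewers who want to try the new helper end to end, the sketch below shows how the component scores returned by `get_rubric_scores` can be flattened into `rubric_<component>` entries and consumed by a GRPO reward function, mirroring what the `wordle.py` changes in this patch do. It is a minimal illustration rather than part of the patch: the `observation` argument and the `wordle.*` key names are assumptions based on the TextArena-style environment used here.

```python
# Minimal sketch (not part of this patch) of wiring rubric scores into rewards.
# Assumes a TextArena-style observation whose rubric exposes "wordle.*" components.
from trl.experimental.openenv import get_rubric_scores


def rubric_kwargs_from_observation(observation) -> dict[str, float]:
    """Flatten rubric components into rubric_<name> keys, as wordle.py's rollout does."""
    scores = get_rubric_scores(observation=observation)
    return {
        f"rubric_{name.removeprefix('wordle.')}": float(score)
        for name, score in scores.items()
    }


def reward_greens(completions, **kwargs):
    """Reward function reading one rubric component forwarded by the rollout_func."""
    greens = kwargs.get("rubric_greens", [0.0] * len(completions))
    return [float(g) for g in greens]
```

Because `get_rubric_scores` returns an empty dict when no rubric information is present, the reward function above falls back to zeros instead of failing, which keeps the rubric-based shaping signal optional.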