misaghsoltani/DGTS

Diffusion-Guided Tree Search (DGTS)

DGTS is a research codebase for uncertainty-aware planning in offline reinforcement learning. It couples:

  • A trajectory diffusion world model that proposes multi-step rollouts
  • A value ensemble and optional reward model that score rollouts and estimate epistemic uncertainty
  • A diffusion-guided tree search that selects actions with an uncertainty-aware UCB rule
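As a rough illustration of what an uncertainty-aware UCB rule can look like, the sketch below scores a child node from a list of ensemble value estimates. It is not taken from the DGTS source: the function names, the `beta` and `c_explore` parameters, and the choice to subtract ensemble disagreement (a pessimistic penalty, common in offline RL) are all assumptions made for illustration.

```python
import math
from statistics import mean, pstdev


def ucb_score(value_estimates, parent_visits, child_visits,
              c_explore=1.0, beta=1.0):
    """Hypothetical uncertainty-aware UCB score for one child node.

    value_estimates: one value prediction per ensemble member.
    The ensemble mean is the exploitation term, the ensemble standard
    deviation stands in for epistemic uncertainty (assumed here to be a
    penalty), and the last term is a UCB1-style visit-count bonus.
    """
    q_mean = mean(value_estimates)
    q_std = pstdev(value_estimates)
    bonus = c_explore * math.sqrt(
        math.log(parent_visits + 1) / (child_visits + 1)
    )
    return q_mean - beta * q_std + bonus


def select_child(children, parent_visits, **kwargs):
    """Return the child dict with the highest score."""
    return max(
        children,
        key=lambda ch: ucb_score(ch["values"], parent_visits,
                                 ch["visits"], **kwargs),
    )


children = [
    {"name": "a", "values": [1.0, 1.0, 1.0], "visits": 3},  # ensemble agrees
    {"name": "b", "values": [0.0, 1.0, 2.0], "visits": 3},  # ensemble disagrees
]
best = select_child(children, parent_visits=6)
```

With equal means and visit counts, the child whose ensemble members agree wins, which is the behaviour a pessimistic planner would want on out-of-distribution rollouts.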

Note

This repository currently provides an end-to-end pipeline and diagnostics for the Minari dataset D4RL/pointmaze/umaze-dense-v2.

Installation

Recommended: Pixi (reproducible Python and system dependencies)

pixi install

Alternative: editable install with pip

python -m pip install -e .

Quickstart

Train the diffusion model and the value/reward models

pixi run dgts train-all train.out_dir=outputs train.seed=42

Run the full evaluation suite (component diagnostics plus policy rollouts)

pixi run dgts eval-all train.out_dir=outputs visualization.eval_dir=outputs/visuals/eval

Record GIFs

pixi run dgts viz-rollout train.out_dir=outputs visualization.gif_dir=outputs/visuals/rollouts
pixi run dgts eval train.out_dir=outputs visualization.eval_dir=outputs/visuals/eval visualization.eval_gif_episodes=5

To see all available commands:

pixi run dgts --help

Configuration

All entry points use Hydra. You can override any config field from the command line:

pixi run dgts train-all train.steps=50000 diffusion.inference_steps=50 mcts.simulations=300

Default configs live in configs/.
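To make the override semantics concrete, here is a small pure-Python sketch of how dotted `key=value` arguments map onto a nested config. It only mimics the basic behaviour. Hydra itself handles this through OmegaConf and supports a richer override grammar. The default values below are hypothetical and are not the values shipped in configs/.

```python
import ast
import copy


def apply_overrides(config, overrides):
    """Apply 'a.b=value' overrides to a nested dict and return a copy.

    Unknown keys raise KeyError, loosely mirroring Hydra's default
    refusal to override a field that does not exist in the config.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node[name]
        if leaf not in node:
            raise KeyError(dotted_key)
        try:
            # Parse numbers and booleans; anything else stays a string.
            node[leaf] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            node[leaf] = raw
    return result


# Hypothetical defaults, for illustration only.
defaults = {
    "train": {"steps": 10000, "out_dir": "outputs"},
    "diffusion": {"inference_steps": 20},
    "mcts": {"simulations": 100},
}
cfg = apply_overrides(
    defaults,
    ["train.steps=50000", "diffusion.inference_steps=50",
     "mcts.simulations=300"],
)
```

Fields that are not overridden, such as `train.out_dir` here, keep their default values.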

Outputs and diagnostics

  • Checkpoints: ${train.out_dir}/diffusion/ and ${train.out_dir}/value_reward/
  • Resolved configs and run logs: ${train.out_dir}/hydra/
  • Evaluation artifacts: ${visualization.eval_dir}/
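For scripting around a run, the layout above can be expressed with `pathlib`. This is a convenience sketch, not part of the DGTS package: the helper name `artifact_paths` and the dictionary keys are hypothetical, while the subdirectory names are the ones listed above.

```python
from pathlib import Path


def artifact_paths(out_dir, eval_dir):
    """Map the documented output locations to Path objects."""
    out = Path(out_dir)
    return {
        "diffusion_checkpoints": out / "diffusion",
        "value_reward_checkpoints": out / "value_reward",
        "hydra": out / "hydra",
        "eval": Path(eval_dir),
    }


paths = artifact_paths("outputs", "outputs/visuals/eval")
# Names of expected directories that do not exist yet, e.g. before training.
missing = [name for name, path in paths.items() if not path.is_dir()]
```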

Metric definitions and artifact descriptions are documented in EVALUATION.md.

Trajectory hollows (single-image rollout summaries)

This repo includes scripts that compress rollout GIFs into a single 2x2 PDF figure by accumulating the agent footprint over time.
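The idea of accumulating a footprint over time can be sketched as follows. This is an assumption about the general technique, not the repository's implementation: frames are taken to be already decoded into 2D lists of grayscale values, the static background is estimated as the per-pixel median, and the footprint is the largest deviation from that background seen at each pixel.

```python
from statistics import median


def accumulate_footprint(frames):
    """Collapse a list of grayscale frames into one footprint image.

    Each frame is a list of rows of pixel intensities. Pixels the agent
    passed through at any time deviate from the median background, so
    the per-pixel maximum deviation traces the whole trajectory.
    """
    height, width = len(frames[0]), len(frames[0][0])
    footprint = [[0] * width for _ in range(height)]
    for r in range(height):
        for c in range(width):
            pixels = [frame[r][c] for frame in frames]
            background = median(pixels)
            footprint[r][c] = max(abs(p - background) for p in pixels)
    return footprint


# A bright "agent" (9) moving left to right across a dark 1x3 image.
frames = [[[9, 0, 0]], [[0, 9, 0]], [[0, 0, 9]]]
trail = accumulate_footprint(frames)
```

In the toy example every pixel the agent visited ends up bright in `trail`, so a single image shows the full path.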

First, create solved demonstration GIFs by replaying dataset actions:

pixi run dgts render-minari-episode-gifs \
  --dataset-id D4RL/pointmaze/umaze-dense-v2 \
  --episodes 0 7 \
  --out-dir outputs/visuals/gifs \
  --prefix demo

Then, select any two solved GIFs and any two unsolved GIFs to build the hollow panel:

pixi run dgts make-trajectory-hollows \
  --gif-dir outputs/visuals/gifs \
  --out-pdf outputs/visuals/trajectory_hollows.pdf \
  --solved-gif outputs/visuals/gifs/demo_ep0000.gif \
  --solved-gif outputs/visuals/gifs/demo_ep0007.gif \
  --unsolved-gif outputs/visuals/eval/D4RL_pointmaze_umaze-dense-v2_eval_ep0006.gif \
  --unsolved-gif outputs/visuals/eval/D4RL_pointmaze_umaze-dense-v2_eval_ep0011.gif

License

MIT, see LICENSE.

About

Diffusion-Guided Tree Search: Uncertainty-Aware Planning with Learned World Models
