Skip to content

Commit ad326e0

Browse files
committed
update to checkpointing
1 parent d8b0884 commit ad326e0

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

envs/test.py

Lines changed: 26 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -17,29 +17,38 @@
1717

1818
def train_PPO():
1919
LOG_DIR = "./logs/"
20-
CHECKPOINT_DIR = "../models/ppo_checkpoints/"
20+
CHECKPOINT_DIR = "./models/ppo_checkpoints/"
2121
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
22+
2223
env = gym.make('simpleBiped-v0', render=False)
23-
model = PPO("MlpPolicy", env, verbose=1, learning_rate=0.001, device="auto", tensorboard_log=LOG_DIR)
2424

25-
# Add checkpoint callback
26-
from stable_baselines3.common.callbacks import CheckpointCallback
27-
checkpoint_callback = CheckpointCallback(
28-
save_freq=100_000,
29-
save_path=CHECKPOINT_DIR,
30-
name_prefix="ppo_model"
25+
model = PPO(
26+
"MlpPolicy",
27+
env,
28+
verbose=1,
29+
learning_rate=0.001,
30+
device="auto",
31+
tensorboard_log=LOG_DIR
3132
)
32-
model.learn(total_timesteps=5_000_000, callback=checkpoint_callback)
33-
model.save("../models/humanoid_ppo")
34-
env.close()
3533

36-
def train_DDPG():
37-
LOG_DIR = "./logs/"
38-
env = gym.make('simpleBiped-v0', render_mode=True)
39-
model = DDPG("MlpPolicy", env, verbose=1, device="auto", tensorboard_log=LOG_DIR)
40-
model.learn(total_timesteps=10000000)
41-
model.save("../models/humanoid_ddpg")
34+
from stable_baselines3.common.callbacks import EvalCallback
35+
eval_env = gym.make('simpleBiped-v0', render=False)
36+
eval_callback = EvalCallback(
37+
eval_env,
38+
best_model_save_path=CHECKPOINT_DIR,
39+
log_path=CHECKPOINT_DIR,
40+
eval_freq=50_000, # run evaluation every 50k steps
41+
deterministic=True,
42+
render=False,
43+
verbose=1
44+
)
45+
46+
model.learn(total_timesteps=5_000_000, callback=eval_callback)
47+
48+
model.save("./models/humanoid_ppo")
4249
env.close()
50+
eval_env.close()
51+
4352

4453
def run():
4554
env = DummyVecEnv([lambda: simpleBipedEnv(render=True)])
Binary file not shown.

0 commit comments

Comments (0)