|
17 | 17 |
|
18 | 18 | def train_PPO(): |
19 | 19 | LOG_DIR = "./logs/" |
20 | | - CHECKPOINT_DIR = "../models/ppo_checkpoints/" |
| 20 | + CHECKPOINT_DIR = "./models/ppo_checkpoints/" |
21 | 21 | os.makedirs(CHECKPOINT_DIR, exist_ok=True) |
| 22 | + |
22 | 23 | env = gym.make('simpleBiped-v0', render=False) |
23 | | - model = PPO("MlpPolicy", env, verbose=1, learning_rate=0.001, device="auto", tensorboard_log=LOG_DIR) |
24 | 24 |
|
25 | | - # Add checkpoint callback |
26 | | - from stable_baselines3.common.callbacks import CheckpointCallback |
27 | | - checkpoint_callback = CheckpointCallback( |
28 | | - save_freq=100_000, |
29 | | - save_path=CHECKPOINT_DIR, |
30 | | - name_prefix="ppo_model" |
| 25 | + model = PPO( |
| 26 | + "MlpPolicy", |
| 27 | + env, |
| 28 | + verbose=1, |
| 29 | + learning_rate=0.001, |
| 30 | + device="auto", |
| 31 | + tensorboard_log=LOG_DIR |
31 | 32 | ) |
32 | | - model.learn(total_timesteps=5_000_000, callback=checkpoint_callback) |
33 | | - model.save("../models/humanoid_ppo") |
34 | | - env.close() |
35 | 33 |
|
36 | | -def train_DDPG(): |
37 | | - LOG_DIR = "./logs/" |
38 | | - env = gym.make('simpleBiped-v0', render_mode=True) |
39 | | - model = DDPG("MlpPolicy", env, verbose=1, device="auto", tensorboard_log=LOG_DIR) |
40 | | - model.learn(total_timesteps=10000000) |
41 | | - model.save("../models/humanoid_ddpg") |
| 34 | + from stable_baselines3.common.callbacks import EvalCallback |
| 35 | + eval_env = gym.make('simpleBiped-v0', render=False) |
| 36 | + eval_callback = EvalCallback( |
| 37 | + eval_env, |
| 38 | + best_model_save_path=CHECKPOINT_DIR, |
| 39 | + log_path=CHECKPOINT_DIR, |
| 40 | + eval_freq=50_000, # run evaluation every 50k steps |
| 41 | + deterministic=True, |
| 42 | + render=False, |
| 43 | + verbose=1 |
| 44 | + ) |
| 45 | + |
| 46 | + model.learn(total_timesteps=5_000_000, callback=eval_callback) |
| 47 | + |
| 48 | + model.save("./models/humanoid_ppo") |
42 | 49 | env.close() |
| 50 | + eval_env.close() |
| 51 | + |
43 | 52 |
|
44 | 53 | def run(): |
45 | 54 | env = DummyVecEnv([lambda: simpleBipedEnv(render=True)]) |
|
0 commit comments