diff --git a/omnisafe/adapter/online_adapter.py b/omnisafe/adapter/online_adapter.py
index 468e73cd1..e3340c5cd 100644
--- a/omnisafe/adapter/online_adapter.py
+++ b/omnisafe/adapter/online_adapter.py
@@ -189,20 +189,46 @@ def step(
         torch.Tensor,
         dict[str, Any],
     ]:
-        """Run one timestep of the environment's dynamics using the agent actions.
-
-        Args:
-            action (torch.Tensor): The action from the agent or random.
-
-        Returns:
-            observation: The agent's observation of the current environment.
-            reward: The amount of reward returned after previous action.
-            cost: The amount of cost returned after previous action.
-            terminated: Whether the episode has ended.
-            truncated: Whether the episode has been truncated due to a time limit.
-            info: Some information logged by the environment.
-        """
-        return self._env.step(action)
+        """Run one timestep of the environment's dynamics using the agent actions."""
+
+
+        #print(f"[DEBUG OnlineAdapter.step] Action stats - min: {action.min():.6f}, max: {action.max():.6f}, mean: {action.mean():.6f}")
+        if torch.any(torch.isnan(action)):
+            #print(f"[ERROR OnlineAdapter.step] Input action contains NaN!")
+            #print(f"Action values: {action}")
+            action = torch.nan_to_num(action, nan=0.0)
+
+        try:
+            obs, reward, cost, terminated, truncated, info = self._env.step(action)
+        except Exception as e:
+            #print(f"[ERROR OnlineAdapter.step] Exception in env.step(): {e}")
+            raise
+
+        '''
+        print(f"[DEBUG OnlineAdapter.step] Reward: {reward}, Cost: {cost}")
+        print(f"[DEBUG OnlineAdapter.step] Terminated: {terminated}, Truncated: {truncated}")
+
+        if torch.any(torch.isnan(obs)):
+            print(f"[ERROR OnlineAdapter.step] Obs contains NaN!")
+
+        if torch.any(torch.isnan(reward)):
+            print(f"[CRITICAL OnlineAdapter.step] REWARD IS NaN! Type: {type(reward)}, Shape: {reward.shape if hasattr(reward, 'shape') else 'N/A'}")
+            print(f"Reward values: {reward}")
+            if hasattr(self._env, 'last_original_reward'):
+                print(f"Last original reward: {self._env.last_original_reward}")
+
+        if torch.any(torch.isnan(cost)):
+            print(f"[CRITICAL OnlineAdapter.step] COST IS NaN! Type: {type(cost)}, Shape: {cost.shape if hasattr(cost, 'shape') else 'N/A'}")
+            print(f"Cost values: {cost}")
+            if hasattr(self._env, 'last_original_cost'):
+                print(f"Last original cost: {self._env.last_original_cost}")
+        '''
+
+        reward = torch.nan_to_num(reward, nan=0.0)
+        cost = torch.nan_to_num(cost, nan=0.0)
+        obs = torch.nan_to_num(obs, nan=0.0)
+
+        return obs, reward, cost, terminated, truncated, info
 
     def reset(
         self,
diff --git a/omnisafe/adapter/onpolicy_adapter.py b/omnisafe/adapter/onpolicy_adapter.py
index e0c8f7f34..300276747 100644
--- a/omnisafe/adapter/onpolicy_adapter.py
+++ b/omnisafe/adapter/onpolicy_adapter.py
@@ -55,6 +55,8 @@ def __init__(  # pylint: disable=too-many-arguments
         super().__init__(env_id, num_envs, seed, cfgs)
         self._reset_log()
 
+        self._debug_logged = False
+
     def rollout(  # pylint: disable=too-many-locals
         self,
         steps_per_epoch: int,
@@ -85,6 +87,19 @@ def rollout(  # pylint: disable=too-many-locals
             act, value_r, value_c, logp = agent.step(obs)
             next_obs, reward, cost, terminated, truncated, info = self.step(act)
 
+            '''debug
+            if torch.any(torch.isnan(reward)):
+                print(f"[DEBUG Rollout] Step {step}: NaN reward from env.step(): {reward}")
+                print(f"[DEBUG Rollout] Info keys: {list(info.keys())}")
+                if 'original_reward' in info:
+                    print(f"[DEBUG Rollout] Original reward in info: {info['original_reward']}")
+
+            if torch.any(torch.isnan(cost)):
+                print(f"[DEBUG Rollout] Step {step}: NaN cost from env.step(): {cost}")
+                if 'original_cost' in info:
+                    print(f"[DEBUG Rollout] Original cost in info: {info['original_cost']}")
+            '''
+
             self._log_value(reward=reward, cost=cost, info=info)
 
             if self._cfgs.algo_cfgs.use_cost:
@@ -125,14 +140,10 @@ def rollout(  # pylint: disable=too-many-locals
                         last_value_r = last_value_r.unsqueeze(0)
                         last_value_c = last_value_c.unsqueeze(0)
 
-                if done or time_out:
+                if done or time_out or epoch_end:  # also log metrics at epoch end so the final, unfinished episode is recorded
                     self._log_metrics(logger, idx)
                     self._reset_log(idx)
 
-                    self._ep_ret[idx] = 0.0
-                    self._ep_cost[idx] = 0.0
-                    self._ep_len[idx] = 0.0
-
                 buffer.finish_path(last_value_r, last_value_c, idx)
 
     def _log_value(
@@ -152,26 +163,112 @@ def _log_value(
             cost (torch.Tensor): The immediate step cost.
            info (dict[str, Any]): Some information logged by the environment.
         """
-        self._ep_ret += info.get('original_reward', reward).cpu()
-        self._ep_cost += info.get('original_cost', cost).cpu()
+
+        '''
+        if hasattr(self, '_debug_logged') and not self._debug_logged:
+            print(f"[DEBUG _log_value] First call debug:")
+            print(f"  reward shape: {reward.shape}, value: {reward}")
+            print(f"  cost shape: {cost.shape}, value: {cost}")
+            print(f"  info keys: {list(info.keys())}")
+            if 'original_reward' in info:
+                print(f"  original_reward: {info['original_reward']}")
+            if 'original_cost' in info:
+                print(f"  original_cost: {info['original_cost']}")
+            print(f"  _ep_ret before: {self._ep_ret}")
+            print(f"  _ep_cost before: {self._ep_cost}")
+            print(f"  _ep_len before: {self._ep_len}")
+            self._debug_logged = True
+        '''
+
+        # FIX: sanitize reward/cost before accumulating episode statistics
+        raw_reward = info.get('original_reward', reward)
+        raw_cost = info.get('original_cost', cost)
+
+        if torch.any(torch.isnan(raw_reward)):
+            #print(f"[CRITICAL _log_value] NaN raw_reward detected! raw_reward={raw_reward}, reward={reward}")
+            #if 'original_reward' in info:
+                #print(f"  original_reward in info: {info['original_reward']}")
+            raw_reward = torch.nan_to_num(raw_reward, nan=0.0)
+
+        if torch.any(torch.isnan(raw_cost)):
+            #print(f"[CRITICAL _log_value] NaN raw_cost detected! raw_cost={raw_cost}, cost={cost}")
+            #if 'original_cost' in info:
+                #print(f"  original_cost in info: {info['original_cost']}")
+            raw_cost = torch.nan_to_num(raw_cost, nan=0.0)
+
+        if torch.any(torch.isnan(self._ep_ret)):
+            #print(f"[CRITICAL _log_value] _ep_ret is NaN before addition! Value: {self._ep_ret}")
+            self._ep_ret = torch.zeros_like(self._ep_ret)
+
+        if torch.any(torch.isnan(self._ep_cost)):
+            #print(f"[CRITICAL _log_value] _ep_cost is NaN before addition! Value: {self._ep_cost}")
+            self._ep_cost = torch.zeros_like(self._ep_cost)
+
+        if torch.any(torch.isnan(self._ep_len)):
+            #print(f"[CRITICAL _log_value] _ep_len is NaN before addition! Value: {self._ep_len}")
+            self._ep_len = torch.zeros_like(self._ep_len)
+
+        self._ep_ret += raw_reward.cpu()
+        self._ep_cost += raw_cost.cpu()
         self._ep_len += 1
+
+        '''
+        if torch.any(torch.isnan(self._ep_ret)):
+            print(f"[CRITICAL _log_value] _ep_ret became NaN after addition!")
+        if torch.any(torch.isnan(self._ep_cost)):
+            print(f"[CRITICAL _log_value] _ep_cost became NaN after addition!")
+        if torch.any(torch.isnan(self._ep_len)):
+            print(f"[CRITICAL _log_value] _ep_len became NaN after addition!")
+        '''
 
     def _log_metrics(self, logger: Logger, idx: int) -> None:
-        """Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``.
-
-        Args:
-            logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
-            idx (int): The index of the environment.
-        """
+        """Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``."""
         if hasattr(self._env, 'spec_log'):
             self._env.spec_log(logger)
+
+        '''
+        print(f"[DEBUG _log_metrics] Called for idx={idx}")
+        print(f"  _ep_ret: {self._ep_ret}, type: {type(self._ep_ret)}")
+        print(f"  _ep_cost: {self._ep_cost}, type: {type(self._ep_cost)}")
+        print(f"  _ep_len: {self._ep_len}, type: {type(self._ep_len)}")
+
+        if torch.any(torch.isnan(self._ep_ret)):
+            print(f"[ERROR _log_metrics] _ep_ret contains NaN! Values: {self._ep_ret}")
+        if torch.any(torch.isnan(self._ep_cost)):
+            print(f"[ERROR _log_metrics] _ep_cost contains NaN! Values: {self._ep_cost}")
+        if torch.any(torch.isnan(self._ep_len)):
+            print(f"[ERROR _log_metrics] _ep_len contains NaN! Values: {self._ep_len}")
+        '''
+
+        ep_ret_val = self._ep_ret[idx]
+        ep_cost_val = self._ep_cost[idx]
+        ep_len_val = self._ep_len[idx]
+
+        #print(f"  ep_ret_val: {ep_ret_val}, type: {type(ep_ret_val)}")
+        #print(f"  ep_cost_val: {ep_cost_val}, type: {type(ep_cost_val)}")
+        #print(f"  ep_len_val: {ep_len_val}, type: {type(ep_len_val)}")
+
+        if torch.isnan(ep_ret_val) or torch.isinf(ep_ret_val):
+            #print(f"[FIXING _log_metrics] EpRet[{idx}] = {ep_ret_val}, setting to 0.0")
+            ep_ret_val = torch.tensor(0.0, dtype=torch.float32)
+
+        if torch.isnan(ep_cost_val) or torch.isinf(ep_cost_val):
+            #print(f"[FIXING _log_metrics] EpCost[{idx}] = {ep_cost_val}, setting to 0.0")
+            ep_cost_val = torch.tensor(0.0, dtype=torch.float32)
+
+        if torch.isnan(ep_len_val) or torch.isinf(ep_len_val):
+            #print(f"[FIXING _log_metrics] EpLen[{idx}] = {ep_len_val}, setting to 0.0")
+            ep_len_val = torch.tensor(0.0, dtype=torch.float32)
+
         logger.store(
             {
-                'Metrics/EpRet': self._ep_ret[idx],
-                'Metrics/EpCost': self._ep_cost[idx],
-                'Metrics/EpLen': self._ep_len[idx],
+                'Metrics/EpRet': ep_ret_val,
+                'Metrics/EpCost': ep_cost_val,
+                'Metrics/EpLen': ep_len_val,
             },
         )
+
+        #print(f"[DEBUG _log_metrics] Stored values - EpRet: {ep_ret_val}, EpCost: {ep_cost_val}, EpLen: {ep_len_val}")
 
     def _reset_log(self, idx: int | None = None) -> None:
         """Reset the episode return, episode cost and episode length.
@@ -188,3 +285,15 @@ def _reset_log(self, idx: int | None = None) -> None: self._ep_ret[idx] = 0.0 self._ep_cost[idx] = 0.0 self._ep_len[idx] = 0.0 + + if torch.any(torch.isnan(self._ep_ret)): + #print(f"[ERROR _reset_log] _ep_ret initialized as NaN!") + self._ep_ret = torch.zeros_like(self._ep_ret) + + if torch.any(torch.isnan(self._ep_cost)): + #print(f"[ERROR _reset_log] _ep_cost initialized as NaN!") + self._ep_cost = torch.zeros_like(self._ep_cost) + + if torch.any(torch.isnan(self._ep_len)): + #print(f"[ERROR _reset_log] _ep_len initialized as NaN!") + self._ep_len = torch.zeros_like(self._ep_len) diff --git a/omnisafe/algorithms/on_policy/base/policy_gradient.py b/omnisafe/algorithms/on_policy/base/policy_gradient.py index e0792d6ab..20270854b 100644 --- a/omnisafe/algorithms/on_policy/base/policy_gradient.py +++ b/omnisafe/algorithms/on_policy/base/policy_gradient.py @@ -33,6 +33,7 @@ from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic from omnisafe.utils import distributed +import numpy as np @registry.register # pylint: disable-next=too-many-instance-attributes,too-few-public-methods,line-too-long @@ -264,6 +265,23 @@ def learn(self) -> tuple[float, float, float]: ) self._logger.store({'Time/Rollout': time.time() - rollout_time}) + ''' + print(f"[DEBUG] Epoch {epoch}: Checking stats after rollout") + ep_ret_stats = self._logger.get_stats('Metrics/EpRet') + ep_cost_stats = self._logger.get_stats('Metrics/EpCost') + ep_len_stats = self._logger.get_stats('Metrics/EpLen') + print(f" EpRet: {ep_ret_stats}") + print(f" EpCost: {ep_cost_stats}") + print(f" EpLen: {ep_len_stats}") + + if len(ep_ret_stats) > 0 and np.any(np.isnan(ep_ret_stats)): + print(f" WARNING: EpRet contains NaN!") + if len(ep_cost_stats) > 0 and np.any(np.isnan(ep_cost_stats)): + print(f" WARNING: EpCost contains NaN!") + if len(ep_len_stats) > 0 and np.any(np.isnan(ep_len_stats)): + print(f" WARNING: EpLen contains NaN!") + ''' + update_time = time.time() self._update() self._logger.store({'Time/Update': time.time() - update_time}) diff --git a/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py b/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py index 930ab3e53..9b093a725 100644 --- a/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py +++ b/omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py @@ -70,8 +70,20 @@ def _update(self) -> None: where :math:`\lambda` is the Lagrange multiplier parameter. """ # note that logger already uses MPI statistics across all processes.. - Jc = self._logger.get_stats('Metrics/EpCost')[0] - assert not np.isnan(Jc), 'cost for updating lagrange multiplier is nan' + + + #FIX FROM HERE + #assert not np.isnan(Jc), 'cost for updating lagrange multiplier is nan' + costs = self._logger.get_stats('Metrics/EpCost') + + Jc = costs[0] + if np.isnan(Jc): + print(f"[PPO-Lag Warning] NaN cost detected, using 1e-8 instead") + Jc = 1e-8 + + assert not np.isnan(Jc), f'cost is nan after protection: {Jc}' + #FIX END + # first update Lagrange multiplier parameter self._lagrange.update_lagrange_multiplier(Jc) # then update the policy and value function diff --git a/reproduce_nan_issue.py b/reproduce_nan_issue.py new file mode 100644 index 000000000..881ad3541 --- /dev/null +++ b/reproduce_nan_issue.py @@ -0,0 +1,45 @@ +""" +Reproduce NaN cost assertion in PPOLag with zero cost_limit. 
+Expected error: AssertionError: cost for updating lagrange multiplier is nan
+"""
+
+import omnisafe
+
+def reproduce_nan_bug():
+    """Reproduce the NaN assertion bug in PPOLag."""
+
+    print("=" * 60)
+    print("Reproducing PPOLag NaN assertion bug")
+    print("=" * 60)
+
+    agent = omnisafe.Agent(
+        algo='PPOLag',
+        env_id='SafetyPointGoal1-v0',  # A relatively safe environment
+        train_terminal_cfgs={
+            'total_steps': 2000,  # Enough to trigger within first epoch
+            'parallel': 1,
+            'device': 'cpu',
+        },
+        custom_cfgs={
+            'algo_cfgs': {
+                'steps_per_epoch': 200,
+            },
+            'lagrange_cfgs': {
+                'cost_limit': 0.0  # This is the key trigger
+            }
+        }
+    )
+
+    print("Configuration:")
+    print(f"  Algorithm: PPOLag")
+    print(f"  Environment: SafetyPointGoal1-v0")
+    print(f"  Cost limit: 0.0")
+    print()
+    print("Expected behavior: Will crash with AssertionError")
+    print("Actual behavior: ...")
+    print()
+
+    agent.learn()
+
+if __name__ == '__main__':
+    reproduce_nan_bug()
\ No newline at end of file
diff --git a/train_with_risk.py b/train_with_risk.py
new file mode 100644
index 000000000..cd35499f6
--- /dev/null
+++ b/train_with_risk.py
@@ -0,0 +1,67 @@
+import omnisafe
+
+print("=== Using environments that are more likely to incur cost ===")
+
+# Try these environments (they violate constraints more easily):
+# 1. SafetyAntVelocity-v1 (velocity control, easy to exceed the speed limit)
+# 2. SafetyPointCircle1-v0 (circle task, easy to hit the boundary)
+# 3. SafetyCarButton1-v0 (button task, contains obstacles)
+
+env_id = 'SafetyAntVelocity-v1'  # Recommended: the velocity-control task incurs cost easily
+print(f"Environment: {env_id}")
+
+agent = omnisafe.Agent(
+    algo='PPOLag',
+    env_id=env_id,
+    train_terminal_cfgs={
+        'parallel': 1,
+        'total_steps': 5000,
+        'device': 'cpu',
+        'vector_env_nums': 1,
+        'torch_threads': 1,
+    },
+    custom_cfgs={
+        'algo_cfgs': {
+            'steps_per_epoch': 500,
+        },
+        'lagrange_cfgs': {
+            'cost_limit': 10.0,
+        }
+    }
+)
+
+print("Starting training...")
+try:
+    agent.learn()
+    print("✅ Training succeeded!")
+except Exception as e:
+    print(f"❌ Training failed: {e}")
+    import traceback
+    traceback.print_exc()
+
+    print("\nTrying a more aggressive environment...")
+    env_id = 'SafetyPointCircle1-v0'  # Circle task, even more likely to hit the boundary
+    print(f"New environment: {env_id}")
+
+    agent2 = omnisafe.Agent(
+        algo='PPOLag',
+        env_id=env_id,
+        train_terminal_cfgs={
+            'parallel': 1,
+            'total_steps': 3000,
+            'device': 'cpu',
+            'vector_env_nums': 1,
+            'torch_threads': 1,
+        },
+        custom_cfgs={
+            'algo_cfgs': {
+                'steps_per_epoch': 300,
+            },
+            'lagrange_cfgs': {
+                'cost_limit': 5.0,
+            }
+        }
+    )
+
+    print("Starting second training run...")
+    agent2.learn()