54 changes: 40 additions & 14 deletions omnisafe/adapter/online_adapter.py
@@ -189,20 +189,46 @@ def step(
torch.Tensor,
dict[str, Any],
]:
"""Run one timestep of the environment's dynamics using the agent actions.

Args:
action (torch.Tensor): The action from the agent or random.

Returns:
observation: The agent's observation of the current environment.
reward: The amount of reward returned after previous action.
cost: The amount of cost returned after previous action.
terminated: Whether the episode has ended.
truncated: Whether the episode has been truncated due to a time limit.
info: Some information logged by the environment.
"""
return self._env.step(action)
"""Run one timestep of the environment's dynamics using the agent actions."""


#print(f"[DEBUG OnlineAdapter.step] Action stats - min: {action.min():.6f}, max: {action.max():.6f}, mean: {action.mean():.6f}")
if torch.any(torch.isnan(action)):
#print(f"[ERROR OnlineAdapter.step] Input action contains NaN!")
#print(f"Action values: {action}")
action = torch.nan_to_num(action, nan=0.0)

try:
obs, reward, cost, terminated, truncated, info = self._env.step(action)
except Exception as e:
#print(f"[ERROR OnlineAdapter.step] Exception in env.step(): {e}")
raise

        # Replace any NaN values returned by the environment so they cannot
        # propagate into the buffer, the logger, or the policy update.
        reward = torch.nan_to_num(reward, nan=0.0)
        cost = torch.nan_to_num(cost, nan=0.0)
        obs = torch.nan_to_num(obs, nan=0.0)

        return obs, reward, cost, terminated, truncated, info

def reset(
self,
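A note on the torch.nan_to_num guards above: with only nan=0.0 specified, NaNs become 0.0, while +inf and -inf are still replaced by the dtype's extreme finite values, so these guards neutralize NaNs but leave infinities saturated rather than zeroed. A minimal standalone sketch (not part of the PR):

import torch

x = torch.tensor([1.0, float('nan'), float('inf'), -float('inf')])
print(torch.nan_to_num(x, nan=0.0))
# tensor([ 1.0000e+00,  0.0000e+00,  3.4028e+38, -3.4028e+38])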
141 changes: 125 additions & 16 deletions omnisafe/adapter/onpolicy_adapter.py
@@ -55,6 +55,8 @@ def __init__( # pylint: disable=too-many-arguments
super().__init__(env_id, num_envs, seed, cfgs)
self._reset_log()


def rollout( # pylint: disable=too-many-locals
self,
steps_per_epoch: int,
@@ -85,6 +87,19 @@ def rollout( # pylint: disable=too-many-locals
act, value_r, value_c, logp = agent.step(obs)
next_obs, reward, cost, terminated, truncated, info = self.step(act)

self._log_value(reward=reward, cost=cost, info=info)

if self._cfgs.algo_cfgs.use_cost:
@@ -125,14 +140,10 @@
last_value_r = last_value_r.unsqueeze(0)
last_value_c = last_value_c.unsqueeze(0)

if done or time_out:
if done or time_out or epoch_end:  # also flush at epoch end, so the stats window never stays empty
self._log_metrics(logger, idx)
self._reset_log(idx)

self._ep_ret[idx] = 0.0
self._ep_cost[idx] = 0.0
self._ep_len[idx] = 0.0

buffer.finish_path(last_value_r, last_value_c, idx)

def _log_value(
@@ -152,26 +163,112 @@ def _log_value(
cost (torch.Tensor): The immediate step cost.
info (dict[str, Any]): Some information logged by the environment.
"""
self._ep_ret += info.get('original_reward', reward).cpu()
self._ep_cost += info.get('original_cost', cost).cpu()

        # FIX BEGIN: guard every quantity that feeds the episode statistics so
        # that a single NaN cannot poison the logged metrics.
        raw_reward = info.get('original_reward', reward)
        raw_cost = info.get('original_cost', cost)

        if torch.any(torch.isnan(raw_reward)):
            raw_reward = torch.nan_to_num(raw_reward, nan=0.0)

        if torch.any(torch.isnan(raw_cost)):
            raw_cost = torch.nan_to_num(raw_cost, nan=0.0)

        if torch.any(torch.isnan(self._ep_ret)):
            self._ep_ret = torch.zeros_like(self._ep_ret)

        if torch.any(torch.isnan(self._ep_cost)):
            self._ep_cost = torch.zeros_like(self._ep_cost)

        if torch.any(torch.isnan(self._ep_len)):
            self._ep_len = torch.zeros_like(self._ep_len)

        self._ep_ret += raw_reward.cpu()
        self._ep_cost += raw_cost.cpu()
        self._ep_len += 1
        # FIX END

def _log_metrics(self, logger: Logger, idx: int) -> None:
"""Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``.

Args:
logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
idx (int): The index of the environment.
"""
"""Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``."""
if hasattr(self._env, 'spec_log'):
self._env.spec_log(logger)

        ep_ret_val = self._ep_ret[idx]
        ep_cost_val = self._ep_cost[idx]
        ep_len_val = self._ep_len[idx]

        # Fall back to 0.0 whenever a statistic is NaN or infinite, so the
        # logger never receives an invalid value.
        if torch.isnan(ep_ret_val) or torch.isinf(ep_ret_val):
            ep_ret_val = torch.tensor(0.0, dtype=torch.float32)

        if torch.isnan(ep_cost_val) or torch.isinf(ep_cost_val):
            ep_cost_val = torch.tensor(0.0, dtype=torch.float32)

        if torch.isnan(ep_len_val) or torch.isinf(ep_len_val):
            ep_len_val = torch.tensor(0.0, dtype=torch.float32)

logger.store(
{
'Metrics/EpRet': self._ep_ret[idx],
'Metrics/EpCost': self._ep_cost[idx],
'Metrics/EpLen': self._ep_len[idx],
'Metrics/EpRet': ep_ret_val,
'Metrics/EpCost': ep_cost_val,
'Metrics/EpLen': ep_len_val,
},
)

#print(f"[DEBUG _log_metrics] Stored values - EpRet: {ep_ret_val}, EpCost: {ep_cost_val}, EpLen: {ep_len_val}")

def _reset_log(self, idx: int | None = None) -> None:
"""Reset the episode return, episode cost and episode length.
@@ -188,3 +285,15 @@ def _reset_log(self, idx: int | None = None) -> None:
self._ep_ret[idx] = 0.0
self._ep_cost[idx] = 0.0
self._ep_len[idx] = 0.0

        # Re-create a statistics buffer from scratch if it was ever corrupted
        # by NaNs.
        if torch.any(torch.isnan(self._ep_ret)):
            self._ep_ret = torch.zeros_like(self._ep_ret)

        if torch.any(torch.isnan(self._ep_cost)):
            self._ep_cost = torch.zeros_like(self._ep_cost)

        if torch.any(torch.isnan(self._ep_len)):
            self._ep_len = torch.zeros_like(self._ep_len)
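Why logging at epoch_end matters in rollout(): if no episode terminates during an epoch, nothing is stored under Metrics/EpCost, and a statistic computed over an empty window comes back as NaN, which is exactly the value PPOLag._update() later reads. A minimal illustration of that failure mode (an assumption about the logger's windowed-mean behavior, not its exact internals):

import numpy as np

window: list[float] = []  # no episode finished during this epoch

# A mean over an empty window is NaN; this is the value that would reach
# Jc = self._logger.get_stats('Metrics/EpCost')[0] in PPOLag._update().
ep_cost_mean = float(np.mean(window)) if window else float('nan')
print(ep_cost_mean)  # nan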
18 changes: 18 additions & 0 deletions omnisafe/algorithms/on_policy/base/policy_gradient.py
@@ -33,6 +33,7 @@
from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic
from omnisafe.utils import distributed

import numpy as np

@registry.register
# pylint: disable-next=too-many-instance-attributes,too-few-public-methods,line-too-long
@@ -264,6 +265,23 @@ def learn(self) -> tuple[float, float, float]:
)
self._logger.store({'Time/Rollout': time.time() - rollout_time})

        # Optional debugging aid (disabled): scan the episode metrics for NaNs
        # after each rollout.
        # for key in ('Metrics/EpRet', 'Metrics/EpCost', 'Metrics/EpLen'):
        #     stats = self._logger.get_stats(key)
        #     if len(stats) > 0 and np.any(np.isnan(stats)):
        #         print(f'WARNING: {key} contains NaN: {stats}')

update_time = time.time()
self._update()
self._logger.store({'Time/Update': time.time() - update_time})
16 changes: 14 additions & 2 deletions omnisafe/algorithms/on_policy/naive_lagrange/ppo_lag.py
@@ -70,8 +70,20 @@ def _update(self) -> None:
where :math:`\lambda` is the Lagrange multiplier parameter.
"""
# Note that the logger already aggregates MPI statistics across all processes.
Jc = self._logger.get_stats('Metrics/EpCost')[0]
assert not np.isnan(Jc), 'cost for updating lagrange multiplier is nan'


        # FIX BEGIN: tolerate a NaN episode cost instead of failing the
        # original assert outright.
        costs = self._logger.get_stats('Metrics/EpCost')
        Jc = costs[0]
        if np.isnan(Jc):
            print('[PPO-Lag Warning] NaN cost detected, using 1e-8 instead')
            Jc = 1e-8

        assert not np.isnan(Jc), f'cost is nan after protection: {Jc}'
        # FIX END

# first update Lagrange multiplier parameter
self._lagrange.update_lagrange_multiplier(Jc)
# then update the policy and value function
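For context on why a NaN Jc is fatal in _update(): the Lagrange multiplier is moved by the gap between the measured episode cost and cost_limit, and that multiplier then scales the policy loss, so a single NaN propagates into every later update. A hedged sketch of such an update step (illustrative names, learning rate, and projection; not OmniSafe's exact Lagrange implementation):

# Illustrative only: lambda_, lr and the max(0, .) projection are assumptions.
cost_limit = 0.0
lambda_, lr = 0.1, 0.035
Jc = 1e-8  # the guarded value substituted by the fix above
lambda_ = max(0.0, lambda_ + lr * (Jc - cost_limit))
print(lambda_)  # stays finite, so the scaled policy loss stays finite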
45 changes: 45 additions & 0 deletions reproduce_nan_issue.py
@@ -0,0 +1,45 @@
"""
Reproduce NaN cost assertion in PPOLag with zero cost_limit.
Expected error: AssertionError: cost for updating lagrange multiplier is nan
"""

import omnisafe

def reproduce_nan_bug():
"""Reproduce the NaN assertion bug in PPOLag."""

print("=" * 60)
print("Reproducing PPOLag NaN assertion bug")
print("=" * 60)

agent = omnisafe.Agent(
algo='PPOLag',
env_id='SafetyPointGoal1-v0', # A relatively safe environment
train_terminal_cfgs={
'total_steps': 2000, # Enough to trigger within first epoch
'parallel': 1,
'device': 'cpu',
},
custom_cfgs={
'algo_cfgs': {
'steps_per_epoch': 200,
},
'lagrange_cfgs': {
'cost_limit': 0.0 # This is the key trigger
}
}
)

print("Configuration:")
print(f" Algorithm: PPOLag")
print(f" Environment: SafetyPointGoal1-v0")
print(f" Cost limit: 0.0")
print()
print("Expected behavior: Will crash with AssertionError")
print("Actual behavior: ...")
print()

agent.learn()

if __name__ == '__main__':
reproduce_nan_bug()
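To run the reproduction (assuming omnisafe and its Safety-Gymnasium environments are installed):

python reproduce_nan_issue.py

On an unpatched checkout this is expected to stop with the assertion quoted in the module docstring; with the guards in this PR applied, training should proceed and print the warning instead.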