From 3b15aa52cfa3a45a04ced1d9b00022a2a17e9e34 Mon Sep 17 00:00:00 2001 From: Jin Zhu Date: Mon, 15 Jul 2024 21:24:28 +0100 Subject: [PATCH 1/3] Discrete IQL --- d3rlpy/algos/qlearning/iql.py | 102 +++++++++++++++++++- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 110 ++++++++++++++++++++++ d3rlpy/algos/qlearning/torch/iql_impl.py | 106 ++++++++++++++++++++- 3 files changed, 314 insertions(+), 4 deletions(-) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 4f1ce04c..8150ebdf 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -4,17 +4,20 @@ from ...constants import ActionSpace from ...models.builders import ( create_continuous_q_function, + create_discrete_q_function, + create_categorical_policy, create_normal_policy, create_value_function, ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field from ...models.q_functions import MeanQFunctionFactory +from ...models.q_functions import QFunctionFactory, make_q_func_field from ...types import Shape from .base import QLearningAlgoBase -from .torch.iql_impl import IQLImpl, IQLModules +from .torch.iql_impl import IQLImpl, IQLModules, DiscreteIQLImpl, DiscreteIQLModules -__all__ = ["IQLConfig", "IQL"] +__all__ = ["IQLConfig", "IQL", "DiscreteIQLConfig", "DiscreteIQL"] @dataclasses.dataclass() @@ -175,5 +178,100 @@ def inner_create_impl( def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS +@dataclasses.dataclass() +class DiscreteIQLConfig(LearnableConfig): + actor_learning_rate: float = 3e-4 + critic_learning_rate: float = 3e-4 + + q_func_factory: QFunctionFactory = make_q_func_field() + encoder_factory: EncoderFactory = make_encoder_field() + value_encoder_factory: EncoderFactory = make_encoder_field() + critic_optim_factory: OptimizerFactory = make_optimizer_field() + + actor_encoder_factory: EncoderFactory = make_encoder_field() + actor_optim_factory: OptimizerFactory = make_optimizer_field() + + batch_size: int = 256 + gamma: float = 0.99 + tau: float = 0.005 + n_critics: int = 2 + expectile: float = 0.7 + weight_temp: float = 3.0 + max_weight: float = 100.0 + + def create(self, device: DeviceArg = False) -> "DiscreteIQL": + return DiscreteIQL(self, device) + + @staticmethod + def get_type() -> str: + return "discrete_iql" + +class DiscreteIQL(QLearningAlgoBase[DiscreteIQLImpl, DiscreteIQLConfig]): + def inner_create_impl( + self, observation_shape: Shape, action_size: int + ) -> None: + policy = create_categorical_policy( + observation_shape, + action_size, + self._config.actor_encoder_factory, + self._device, + ) + q_funcs, q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( + observation_shape, + action_size, + self._config.encoder_factory, + self._config.q_func_factory, + n_ensembles=self._config.n_critics, + device=self._device, + ) + value_func = create_value_function( + observation_shape, + self._config.value_encoder_factory, + device=self._device, + ) + + q_func_params = list(q_funcs.named_modules()) + v_func_params = list(value_func.named_modules()) + critic_optim = self._config.critic_optim_factory.create( + q_func_params + v_func_params, lr=self._config.critic_learning_rate + ) + actor_optim = self._config.actor_optim_factory.create( + 
policy.named_modules(), lr=self._config.actor_learning_rate + ) + + modules = DiscreteIQLModules( + policy=policy, + q_funcs=q_funcs, + targ_q_funcs=targ_q_funcs, + value_func=value_func, + actor_optim=actor_optim, + critic_optim=critic_optim, + ) + + self._impl = DiscreteIQLImpl( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=self._config.gamma, + tau=self._config.tau, + expectile=self._config.expectile, + weight_temp=self._config.weight_temp, + max_weight=self._config.max_weight, + device=self._device, + ) + + def get_action_type(self) -> ActionSpace: + return ActionSpace.DISCRETE register_learnable(IQLConfig) +register_learnable(DiscreteIQLConfig) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 1ce7181f..8d613720 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -157,6 +157,116 @@ def q_function_optim(self) -> Optimizer: return self._modules.critic_optim +class DiscreteDDPGBaseImpl( + ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta +): + _modules: DDPGBaseModules + _gamma: float + _tau: float + _q_func_forwarder: ContinuousEnsembleQFunctionForwarder + _targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder + + def __init__( + self, + observation_shape: Shape, + action_size: int, + modules: DDPGBaseModules, + q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + gamma: float, + tau: float, + device: str, + ): + super().__init__( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + device=device, + ) + self._gamma = gamma + self._tau = tau + self._q_func_forwarder = q_func_forwarder + self._targ_q_func_forwarder = targ_q_func_forwarder + hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) + + def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + self._modules.critic_optim.zero_grad() + q_tpn = self.compute_target(batch) + loss = self.compute_critic_loss(batch, q_tpn) + loss.critic_loss.backward() + self._modules.critic_optim.step() + return asdict_as_float(loss) + + def compute_critic_loss( + self, batch: TorchMiniBatch, q_tpn: torch.Tensor + ) -> DDPGBaseCriticLoss: + loss = self._q_func_forwarder.compute_error( + observations=batch.observations, + actions=batch.actions, + rewards=batch.rewards, + target=q_tpn, + terminals=batch.terminals, + gamma=self._gamma**batch.intervals, + ) + return DDPGBaseCriticLoss(loss) + + def update_actor( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> Dict[str, float]: + # Q function should be inference mode for stability + self._modules.q_funcs.eval() + self._modules.actor_optim.zero_grad() + loss = self.compute_actor_loss(batch, action) + loss.actor_loss.backward() + self._modules.actor_optim.step() + return asdict_as_float(loss) + + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + metrics = {} + action = self._modules.policy(batch.observations) + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch, action)) + self.update_critic_target() + return metrics + + @abstractmethod + def compute_actor_loss( + self, batch: TorchMiniBatch, action: ActionOutput + ) -> DDPGBaseActorLoss: + pass + + @abstractmethod + def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: + pass + + def 
inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: + return torch.argmax(self._modules.policy(x).mean).unsqueeze(0) + + @abstractmethod + def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: + pass + + def update_critic_target(self) -> None: + soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau) + + @property + def policy(self) -> Policy: + return self._modules.policy + + @property + def policy_optim(self) -> Optimizer: + return self._modules.actor_optim + + @property + def q_function(self) -> nn.ModuleList: + return self._modules.q_funcs + + @property + def q_function_optim(self) -> Optimizer: + return self._modules.critic_optim + @dataclasses.dataclass(frozen=True) class DDPGModules(DDPGBaseModules): targ_policy: Policy diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 1dfd29d0..0fb21462 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -1,11 +1,14 @@ import dataclasses import torch +import torch.nn.functional as F from ....models.torch import ( ActionOutput, ContinuousEnsembleQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, NormalPolicy, + CategoricalPolicy, ValueFunction, build_gaussian_distribution, ) @@ -15,11 +18,11 @@ DDPGBaseActorLoss, DDPGBaseCriticLoss, DDPGBaseImpl, + DiscreteDDPGBaseImpl, DDPGBaseModules, ) -__all__ = ["IQLImpl", "IQLModules"] - +__all__ = ["IQLImpl", "IQLModules", "DiscreteIQLImpl", "DiscreteIQLModules"] @dataclasses.dataclass(frozen=True) class IQLModules(DDPGBaseModules): @@ -120,3 +123,102 @@ def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: dist = build_gaussian_distribution(self._modules.policy(x)) return dist.sample() + + +@dataclasses.dataclass(frozen=True) +class DiscreteIQLModules(DDPGBaseModules): + policy: CategoricalPolicy + value_func: ValueFunction + + +@dataclasses.dataclass(frozen=True) +class DiscreteIQLCriticLoss(DDPGBaseCriticLoss): + q_loss: torch.Tensor + v_loss: torch.Tensor + + +class DiscreteIQLImpl(DiscreteDDPGBaseImpl): + _modules: DiscreteIQLModules + _expectile: float + _weight_temp: float + _max_weight: float + + def __init__( + self, + observation_shape: Shape, + action_size: int, + modules: DiscreteIQLModules, + q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + gamma: float, + tau: float, + expectile: float, + weight_temp: float, + max_weight: float, + device: str, + ): + super().__init__( + observation_shape=observation_shape, + action_size=action_size, + modules=modules, + q_func_forwarder=q_func_forwarder, + targ_q_func_forwarder=targ_q_func_forwarder, + gamma=gamma, + tau=tau, + device=device, + ) + self._expectile = expectile + self._weight_temp = weight_temp + self._max_weight = max_weight + + def compute_critic_loss( + self, batch: TorchMiniBatch, q_tpn: torch.Tensor + ) -> IQLCriticLoss: + q_loss = self._q_func_forwarder.compute_error( + observations=batch.observations, + actions=batch.actions.long(), + rewards=batch.rewards, + target=q_tpn, + terminals=batch.terminals, + gamma=self._gamma**batch.intervals, + ) + v_loss = self.compute_value_loss(batch) + return IQLCriticLoss( + critic_loss=q_loss + v_loss, + q_loss=q_loss, + v_loss=v_loss, + ) + + def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: + with torch.no_grad(): + return self._modules.value_func(batch.next_observations) + + def 
compute_actor_loss(self, batch: TorchMiniBatch, action) -> DDPGBaseActorLoss: + assert self._modules.policy + # compute weight + with torch.no_grad(): + v = self._modules.value_func(batch.observations) + min_Q = self._targ_q_func_forwarder.compute_target(batch.observations, reduction="min").gather( + 1, batch.actions.long() + ) + + exp_a = torch.exp((min_Q - v) * self._weight_temp).clamp(max=self._max_weight) + # compute log probability + dist = self._modules.policy(batch.observations) + log_probs = dist.log_prob(batch.actions.squeeze(-1)).unsqueeze(1) + + return DDPGBaseActorLoss(-(exp_a * log_probs).mean()) + + def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: + q_t = self._targ_q_func_forwarder.compute_expected_q(batch.observations) + one_hot = F.one_hot(batch.actions.long().view(-1), num_classes=self.action_size) + q_t = (q_t * one_hot).sum(dim=1, keepdim=True) + + v_t = self._modules.value_func(batch.observations) + diff = q_t.detach() - v_t + weight = (self._expectile - (diff < 0.0).float()).abs().detach() + return (weight * (diff**2)).mean() + + def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: + dist = self._modules.policy(x) + return dist.sample() From 252021103346675463a62064b9ede1d2d77eb286 Mon Sep 17 00:00:00 2001 From: Jin Zhu Date: Mon, 15 Jul 2024 21:53:34 +0100 Subject: [PATCH 2/3] fix bug on prediction --- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 8d613720..73541851 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -242,7 +242,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: pass def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: - return torch.argmax(self._modules.policy(x).mean).unsqueeze(0) + return torch.argmax(self._modules.policy(x).probs).unsqueeze(0) @abstractmethod def inner_sample_action(self, x: TorchObservation) -> torch.Tensor: From f4cd4331455aee02e58cbb4da4bb5503f404d6d2 Mon Sep 17 00:00:00 2001 From: Jin Zhu Date: Mon, 22 Jul 2024 20:47:38 +0100 Subject: [PATCH 3/3] update code --- d3rlpy/algos/qlearning/iql.py | 86 +++++++++++++++++++++-- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 18 ++--- d3rlpy/algos/qlearning/torch/iql_impl.py | 23 +++--- tests/algos/qlearning/test_iql.py | 28 +++++++- 4 files changed, 132 insertions(+), 23 deletions(-) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 8150ebdf..cbdcb9cd 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -3,19 +3,27 @@ from ...base import DeviceArg, LearnableConfig, register_learnable from ...constants import ActionSpace from ...models.builders import ( + create_categorical_policy, create_continuous_q_function, create_discrete_q_function, - create_categorical_policy, create_normal_policy, create_value_function, ) from ...models.encoders import EncoderFactory, make_encoder_field from ...models.optimizers import OptimizerFactory, make_optimizer_field -from ...models.q_functions import MeanQFunctionFactory -from ...models.q_functions import QFunctionFactory, make_q_func_field +from ...models.q_functions import ( + MeanQFunctionFactory, + QFunctionFactory, + make_q_func_field, +) from ...types import Shape from .base import QLearningAlgoBase -from .torch.iql_impl import IQLImpl, IQLModules, DiscreteIQLImpl, DiscreteIQLModules +from .torch.iql_impl import ( 
+ DiscreteIQLImpl, + DiscreteIQLModules, + IQLImpl, + IQLModules, +) __all__ = ["IQLConfig", "IQL", "DiscreteIQLConfig", "DiscreteIQL"] @@ -178,8 +186,72 @@ def inner_create_impl( def get_action_type(self) -> ActionSpace: return ActionSpace.CONTINUOUS + @dataclasses.dataclass() class DiscreteIQLConfig(LearnableConfig): + r"""Implicit Q-Learning algorithm. + + IQL is the offline RL algorithm that avoids ever querying values of unseen + actions while still being able to perform multi-step dynamic programming + updates. + + There are three functions to train in IQL. First the state-value function + is trained via expectile regression. + + .. math:: + + L_V(\psi) = \mathbb{E}_{(s, a) \sim D} + [L_2^\tau (Q_\theta (s, a) - V_\psi (s))] + + where :math:`L_2^\tau (u) = |\tau - \mathbb{1}(u < 0)|u^2`. + + The Q-function is trained with the state-value function to avoid query the + actions. + + .. math:: + + L_Q(\theta) = \mathbb{E}_{(s, a, r, s') \sim D} + [(r + \gamma V_\psi(s') - Q_\theta(s, a))^2] + + Finally, the policy function is trained by using advantage weighted + regression compared with `IQL`, here we use a categorical policy. + + .. math:: + + L_\pi (\phi) = \mathbb{E}_{(s, a) \sim D} + [\exp(\beta (Q_\theta - V_\psi(s))) \log \pi_\phi(a|s)] + + References: + * `Kostrikov et al., Offline Reinforcement Learning with Implicit + Q-Learning. `_ + + Args: + observation_scaler (d3rlpy.preprocessing.ObservationScaler): + Observation preprocessor. + action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor. + reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor. + actor_learning_rate (float): Learning rate for policy function. + critic_learning_rate (float): Learning rate for Q functions. + actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory): + Optimizer factory for the actor. + critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory): + Optimizer factory for the critic. + actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for the actor. + critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for the critic. + value_encoder_factory (d3rlpy.models.encoders.EncoderFactory): + Encoder factory for the value function. + batch_size (int): Mini-batch size. + gamma (float): Discount factor. + tau (float): Target network synchronization coefficiency. + n_critics (int): Number of Q functions for ensemble. + expectile (float): Expectile value for value function training. + weight_temp (float): Inverse temperature value represented as + :math:`\beta`. + max_weight (float): Maximum advantage weight value to clip. 
+ """ + actor_learning_rate: float = 3e-4 critic_learning_rate: float = 3e-4 @@ -187,10 +259,10 @@ class DiscreteIQLConfig(LearnableConfig): encoder_factory: EncoderFactory = make_encoder_field() value_encoder_factory: EncoderFactory = make_encoder_field() critic_optim_factory: OptimizerFactory = make_optimizer_field() - + actor_encoder_factory: EncoderFactory = make_encoder_field() actor_optim_factory: OptimizerFactory = make_optimizer_field() - + batch_size: int = 256 gamma: float = 0.99 tau: float = 0.005 @@ -206,6 +278,7 @@ def create(self, device: DeviceArg = False) -> "DiscreteIQL": def get_type() -> str: return "discrete_iql" + class DiscreteIQL(QLearningAlgoBase[DiscreteIQLImpl, DiscreteIQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int @@ -273,5 +346,6 @@ def inner_create_impl( def get_action_type(self) -> ActionSpace: return ActionSpace.DISCRETE + register_learnable(IQLConfig) register_learnable(DiscreteIQLConfig) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 73541851..e3587a6a 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -10,12 +10,13 @@ from ....models.torch import ( ActionOutput, ContinuousEnsembleQFunctionForwarder, + DiscreteEnsembleQFunctionForwarder, Policy, ) from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync from ....types import Shape, TorchObservation from ..base import QLearningAlgoImplBase -from .utility import ContinuousQFunctionMixin +from .utility import ContinuousQFunctionMixin, DiscreteQFunctionMixin __all__ = [ "DDPGImpl", @@ -158,21 +159,21 @@ def q_function_optim(self) -> Optimizer: class DiscreteDDPGBaseImpl( - ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta + DiscreteQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): _modules: DDPGBaseModules _gamma: float _tau: float - _q_func_forwarder: ContinuousEnsembleQFunctionForwarder - _targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder + _q_func_forwarder: DiscreteEnsembleQFunctionForwarder + _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder def __init__( self, observation_shape: Shape, action_size: int, modules: DDPGBaseModules, - q_func_forwarder: ContinuousEnsembleQFunctionForwarder, - targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, + q_func_forwarder: DiscreteEnsembleQFunctionForwarder, + targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, gamma: float, tau: float, device: str, @@ -216,7 +217,7 @@ def update_actor( # Q function should be inference mode for stability self._modules.q_funcs.eval() self._modules.actor_optim.zero_grad() - loss = self.compute_actor_loss(batch, action) + loss = self.compute_actor_loss(batch, None) loss.actor_loss.backward() self._modules.actor_optim.step() return asdict_as_float(loss) @@ -233,7 +234,7 @@ def inner_update( @abstractmethod def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: None ) -> DDPGBaseActorLoss: pass @@ -267,6 +268,7 @@ def q_function(self) -> nn.ModuleList: def q_function_optim(self) -> Optimizer: return self._modules.critic_optim + @dataclasses.dataclass(frozen=True) class DDPGModules(DDPGBaseModules): targ_policy: Policy diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 0fb21462..b2c54cb0 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -5,10 +5,10 @@ 
from ....models.torch import ( ActionOutput, + CategoricalPolicy, ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, NormalPolicy, - CategoricalPolicy, ValueFunction, build_gaussian_distribution, ) @@ -18,12 +18,13 @@ DDPGBaseActorLoss, DDPGBaseCriticLoss, DDPGBaseImpl, - DiscreteDDPGBaseImpl, DDPGBaseModules, + DiscreteDDPGBaseImpl, ) __all__ = ["IQLImpl", "IQLModules", "DiscreteIQLImpl", "DiscreteIQLModules"] + @dataclasses.dataclass(frozen=True) class IQLModules(DDPGBaseModules): policy: NormalPolicy @@ -193,16 +194,20 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): return self._modules.value_func(batch.next_observations) - def compute_actor_loss(self, batch: TorchMiniBatch, action) -> DDPGBaseActorLoss: + def compute_actor_loss( + self, batch: TorchMiniBatch, action: None + ) -> DDPGBaseActorLoss: assert self._modules.policy # compute weight with torch.no_grad(): v = self._modules.value_func(batch.observations) - min_Q = self._targ_q_func_forwarder.compute_target(batch.observations, reduction="min").gather( - 1, batch.actions.long() - ) + min_Q = self._targ_q_func_forwarder.compute_target( + batch.observations, reduction="min" + ).gather(1, batch.actions.long()) - exp_a = torch.exp((min_Q - v) * self._weight_temp).clamp(max=self._max_weight) + exp_a = torch.exp((min_Q - v) * self._weight_temp).clamp( + max=self._max_weight + ) # compute log probability dist = self._modules.policy(batch.observations) log_probs = dist.log_prob(batch.actions.squeeze(-1)).unsqueeze(1) @@ -211,7 +216,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch, action) -> DDPGBaseActorLoss def compute_value_loss(self, batch: TorchMiniBatch) -> torch.Tensor: q_t = self._targ_q_func_forwarder.compute_expected_q(batch.observations) - one_hot = F.one_hot(batch.actions.long().view(-1), num_classes=self.action_size) + one_hot = F.one_hot( + batch.actions.long().view(-1), num_classes=self.action_size + ) q_t = (q_t * one_hot).sum(dim=1, keepdim=True) v_t = self._modules.value_func(batch.observations) diff --git a/tests/algos/qlearning/test_iql.py b/tests/algos/qlearning/test_iql.py index 32c36864..c7598029 100644 --- a/tests/algos/qlearning/test_iql.py +++ b/tests/algos/qlearning/test_iql.py @@ -2,8 +2,11 @@ import pytest -from d3rlpy.algos.qlearning.iql import IQLConfig +from d3rlpy.algos.qlearning.iql import DiscreteIQLConfig, IQLConfig from d3rlpy.types import Shape +from d3rlpy.models import ( + QFunctionFactory, +) from ...models.torch.model_test import DummyEncoderFactory from ...testing_utils import create_scaler_tuple @@ -28,3 +31,26 @@ def test_iql(observation_shape: Shape, scalers: Optional[str]) -> None: ) iql = config.create() algo_tester(iql, observation_shape) # type: ignore + + +@pytest.mark.parametrize( + "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))] +) +@pytest.mark.parametrize("scalers", [None, "min_max"]) +def test_discrete_iql( + observation_shape: Shape, + q_func_factory: QFunctionFactory, + scalers: Optional[str]) -> None: + observation_scaler, _, reward_scaler = create_scaler_tuple( + scalers, observation_shape + ) + config = DiscreteIQLConfig( + actor_encoder_factory=DummyEncoderFactory(), + encoder_factory=DummyEncoderFactory(), + value_encoder_factory=DummyEncoderFactory(), + q_func_factory=q_func_factory, + observation_scaler=observation_scaler, + reward_scaler=reward_scaler, + ) + iql = config.create() + algo_tester(iql, observation_shape) # type: ignore
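
Usage sketch (not part of the patches above): assuming the three patches are applied and DiscreteIQLConfig ends up re-exported from d3rlpy.algos the same way IQLConfig is, training on a discrete-action dataset would look roughly as follows. The CartPole dataset helper, device choice, and hyperparameter values are illustrative only, not taken from the PR.

    # Hedged sketch: DiscreteIQLConfig/DiscreteIQL come from the patches above;
    # everything else is the stock d3rlpy v2 API.
    import d3rlpy

    # CartPole replay buffer shipped with d3rlpy (discrete action space).
    dataset, env = d3rlpy.datasets.get_cartpole()

    config = d3rlpy.algos.DiscreteIQLConfig(
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        expectile=0.7,     # L2^tau expectile for the value function
        weight_temp=3.0,   # inverse temperature beta for advantage weighting
        max_weight=100.0,  # clip on exp(beta * advantage)
        n_critics=2,
    )
    discrete_iql = config.create(device=False)  # set device="cuda:0" for GPU

    # Offline training; the evaluator rolls out the greedy policy in the env.
    discrete_iql.fit(
        dataset,
        n_steps=100000,
        n_steps_per_epoch=1000,
        evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
    )

    # Greedy action for a single observation (argmax over categorical probs).
    observation = env.observation_space.sample()
    action = discrete_iql.predict(observation.reshape(1, -1))[0]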