Các tác giả trong bài báo lập luận về tầm quan trọng của việc học phân phối lợi nhuận thay vì lợi nhuận kỳ vọng. Và họ đề xuất mô hình hóa các phân phối như vậy với khối lượng xác suất được đặt trên một hỗ trợ rời rạc
[imath]z[/imath]
trong đó
[imath]z[/imath]
là một vectơ với
[imath]N_{atoms} \in \mathbb{N}^+
[/imath]
nguyên tử (atom). xác định bởi
[imath]z_i = V_{min} + (i-1) \frac{V_{max} - V_{min}}{N-1}
[/imath]
với
[imath]i \in \{1, ..., N_{atoms}\}
[/imath]
The key insight là phân phối trả về thỏa mãn một biến thể của phương trình Bellman. Đối với một trạng thái nhất định
[imath]S_t[/imath]
và hành động
[imath]A_t[/imath]
. Phân phối lợi tức theo chính sách tối ưu
[imath]π∗[/imath]
phải khớp với phân phối mục tiêu được xác định bằng cách lấy phân phối cho trạng thái tiếp theo
[imath]S_{t + 1}[/imath]
và action
[imath]a^{*}_{t+1} = \pi^{*}(S_{t+1})
[/imath]
quy ước nó về 0 theo chiết khấu và chuyển nó theo phần thưởng (hoặc phân phối phần thưởng, trong trường hợp ngẫu nhiên). Một biến thể phân phối của Q-learning sau đó được tạo ra bằng cách xây dựng một hỗ trợ mới cho phân phối đích, và sau đó giảm thiểu sự phân kỳ Kullbeck-Leibler giữa phân phối
[imath]d_t[/imath]
và phân phối đích
[math]
d_t' = (R_{t+1} + \gamma_{t+1} z, p_{\hat{\theta}} (S_{t+1}, \hat{a}^{*}_{t+1})), \\
D_{KL} (\phi_z d_t' \| d_t).
[/math]
Ở đây
[imath]ϕ_z[/imath]
là phép chiếu L2 của phân phối mục tiêu lên hỗ trợ cố định
[imath]z[/imath]
và
[imath]\hat{a}^*_{t+1} = \arg\max_{a} q_{\hat{\theta}} (S_{t+1}, a)
[/imath]
là hành động tham lam đối với các giá trị hành động có ý nghĩa.
[imath]q_{\hat{\theta}} (S_{t+1}, a) = z^{T}p_{\theta}(S_{t+1}, a)
[/imath]
ở trạng thái
[imath]S_{t+1}[/imath]
Network
Phân phối được tham số hóa có thể được biểu diễn bằng mạng nơ-ron, như trong DQN, nhưng với đầu ra atom_size x out_dim
. Một softmax được áp dụng độc lập cho từng action dimension của đầu ra để đảm bảo rằng phân phối cho từng hành động được chuẩn hóa một cách thích hợp.
Để ước tính giá trị
[imath]q[/imath]
, chúng ta sử dụng sản phẩm bên trong của phân phối và hỗ trợ softmax của mỗi hành động là tập hợp các atom
[imath]\{z_i = V_{min} + i\Delta z: 0 \le i < N\}, \Delta z = \frac{V_{max} - V_{min}}{N-1}
[/imath]
[math]
Q(s_t, a_t) = \sum_i z_i p_i(s_t, a_t), \\
\text{where } p_i \text{ is the probability of } z_i \text{ (the output of softmax)}.
[/math]
class Network(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
atom_size: int,
support: torch.Tensor
):
"""Initialization."""
super(Network, self).__init__()
self.support = support
self.out_dim = out_dim
self.atom_size = atom_size
self.layers = nn.Sequential(
nn.Linear(in_dim, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, out_dim * atom_size)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward method implementation."""
dist = self.dist(x)
q = torch.sum(dist * self.support, dim=2)
return q
def dist(self, x: torch.Tensor) -> torch.Tensor:
"""Get distribution for atoms."""
q_atoms = self.layers(x).view(-1, self.out_dim, self.atom_size)
dist = F.softmax(q_atoms, dim=-1)
dist = dist.clamp(min=1e-3) # for avoiding nans
return dist
Categorical DQN Agent
class DQNAgent:
"""DQN Agent interacting with environment.
Attribute:
env (gym.Env): openAI Gym environment
memory (ReplayBuffer): replay memory to store transitions
batch_size (int): batch size for sampling
epsilon (float): parameter for epsilon greedy policy
epsilon_decay (float): step size to decrease epsilon
max_epsilon (float): max value of epsilon
min_epsilon (float): min value of epsilon
target_update (int): period for target model's hard update
gamma (float): discount factor
dqn (Network): model to train and select actions
dqn_target (Network): target model to update
optimizer (torch.optim): optimizer for training dqn
transition (list): transition information including
state, action, reward, next_state, done
v_min (float): min value of support
v_max (float): max value of support
atom_size (int): the unit number of support
support (torch.Tensor): support for categorical dqn
"""
def __init__(
self,
env: gym.Env,
memory_size: int,
batch_size: int,
target_update: int,
epsilon_decay: float,
max_epsilon: float = 1.0,
min_epsilon: float = 0.1,
gamma: float = 0.99,
# Categorical DQN parameters
v_min: float = 0.0,
v_max: float = 200.0,
atom_size: int = 51,
):
"""Initialization.
Args:
env (gym.Env): openAI Gym environment
memory_size (int): length of memory
batch_size (int): batch size for sampling
target_update (int): period for target model's hard update
epsilon_decay (float): step size to decrease epsilon
lr (float): learning rate
max_epsilon (float): max value of epsilon
min_epsilon (float): min value of epsilon
gamma (float): discount factor
v_min (float): min value of support
v_max (float): max value of support
atom_size (int): the unit number of support
"""
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
self.env = env
self.memory = ReplayBuffer(obs_dim, memory_size, batch_size)
self.batch_size = batch_size
self.epsilon = max_epsilon
self.epsilon_decay = epsilon_decay
self.max_epsilon = max_epsilon
self.min_epsilon = min_epsilon
self.target_update = target_update
self.gamma = gamma
# device: cpu / gpu
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(self.device)
# Categorical DQN parameters
self.v_min = v_min
self.v_max = v_max
self.atom_size = atom_size
self.support = torch.linspace(
self.v_min, self.v_max, self.atom_size
).to(self.device)
# networks: dqn, dqn_target
self.dqn = Network(
obs_dim, action_dim, atom_size, self.support
).to(self.device)
self.dqn_target = Network(
obs_dim, action_dim, atom_size, self.support
).to(self.device)
self.dqn_target.load_state_dict(self.dqn.state_dict())
self.dqn_target.eval()
# optimizer
self.optimizer = optim.Adam(self.dqn.parameters())
# transition to store in memory
self.transition = list()
# mode: train / test
self.is_test = False
def select_action(self, state: np.ndarray) -> np.ndarray:
"""Select an action from the input state."""
# epsilon greedy policy
if self.epsilon > np.random.random():
selected_action = self.env.action_space.sample()
else:
selected_action = self.dqn(
torch.FloatTensor(state).to(self.device),
).argmax()
selected_action = selected_action.detach().cpu().numpy()
if not self.is_test:
self.transition = [state, selected_action]
return selected_action
def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
"""Take an action and return the response of the env."""
next_state, reward, done, _ = self.env.step(action)
if not self.is_test:
self.transition += [reward, next_state, done]
self.memory.store(*self.transition)
return next_state, reward, done
def update_model(self) -> torch.Tensor:
"""Update the model by gradient descent."""
samples = self.memory.sample_batch()
loss = self._compute_dqn_loss(samples)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def train(self, num_frames: int, plotting_interval: int = 200):
"""Train the agent."""
self.is_test = False
state = self.env.reset()
update_cnt = 0
epsilons = []
losses = []
scores = []
score = 0
for frame_idx in range(1, num_frames + 1):
action = self.select_action(state)
next_state, reward, done = self.step(action)
state = next_state
score += reward
# if episode ends
if done:
state = self.env.reset()
scores.append(score)
score = 0
# if training is ready
if len(self.memory) >= self.batch_size:
loss = self.update_model()
losses.append(loss)
update_cnt += 1
# linearly decrease epsilon
self.epsilon = max(
self.min_epsilon, self.epsilon - (
self.max_epsilon - self.min_epsilon
) * self.epsilon_decay
)
epsilons.append(self.epsilon)
# if hard update is needed
if update_cnt % self.target_update == 0:
self._target_hard_update()
# plotting
if frame_idx % plotting_interval == 0:
self._plot(frame_idx, scores, losses, epsilons)
self.env.close()
def test(self, video_folder: str) -> None:
"""Test the agent."""
self.is_test = True
# for recording a video
naive_env = self.env
self.env = gym.wrappers.RecordVideo(self.env, video_folder=video_folder)
state = self.env.reset()
done = False
score = 0
while not done:
action = self.select_action(state)
next_state, reward, done = self.step(action)
state = next_state
score += reward
print("score: ", score)
self.env.close()
# reset
self.env = naive_env
def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
"""Return categorical dqn loss."""
device = self.device # for shortening the following lines
state = torch.FloatTensor(samples["obs"]).to(device)
next_state = torch.FloatTensor(samples["next_obs"]).to(device)
action = torch.LongTensor(samples["acts"]).to(device)
reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
# Categorical DQN algorithm
delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)
with torch.no_grad():
next_action = self.dqn_target(next_state).argmax(1)
next_dist = self.dqn_target.dist(next_state)
next_dist = next_dist[range(self.batch_size), next_action]
t_z = reward + (1 - done) * self.gamma * self.support
t_z = t_z.clamp(min=self.v_min, max=self.v_max)
b = (t_z - self.v_min) / delta_z
l = b.floor().long()
u = b.ceil().long()
offset = (
torch.linspace(
0, (self.batch_size - 1) * self.atom_size, self.batch_size
).long()
.unsqueeze(1)
.expand(self.batch_size, self.atom_size)
.to(self.device)
)
proj_dist = torch.zeros(next_dist.size(), device=self.device)
proj_dist.view(-1).index_add_(
0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
)
proj_dist.view(-1).index_add_(
0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
)
dist = self.dqn.dist(state)
log_p = torch.log(dist[range(self.batch_size), action])
loss = -(proj_dist * log_p).sum(1).mean()
return loss
def _target_hard_update(self):
"""Hard update: target <- local."""
self.dqn_target.load_state_dict(self.dqn.state_dict())
def _plot(
self,
frame_idx: int,
scores: List[float],
losses: List[float],
epsilons: List[float],
):
"""Plot the training progresses."""
clear_output(True)
plt.figure(figsize=(20, 5))
plt.subplot(131)
plt.title('frame %s. score: %s' % (frame_idx, np.mean(scores[-10:])))
plt.plot(scores)
plt.subplot(132)
plt.title('loss')
plt.plot(losses)
plt.subplot(133)
plt.title('epsilons')
plt.plot(epsilons)
plt.show()