Kiến trúc mạng được đề xuất, được đặt tên là kiến trúc dueling, tách biệt rõ ràng việc biểu diễn các giá trị trạng thái và các lợi thế hành động
Mạng dueling tự động tạo ra các ước tính riêng biệt về chức năng giá trị trạng thái và chức năng lợi thế mà không cần bất kỳ sự giám sát bổ sung nào. Một cách trực quan, kiến trúc dueling có thể tìm hiểu trạng thái nào là có giá trị mà không cần phải tìm hiểu tác động của từng hành động đối với mỗi trạng thái. Điều này đặc biệt hữu ích ở những nơi mà các hành động của nó không ảnh hưởng đến môi trường theo bất kỳ cách nào có liên quan.
Kiến trúc kết hợp biểu thị cả giá trị
[imath]V (s)[/imath]
và lợi thế
[imath]A (s, a)[/imath]
chức năng với một mô hình sâu duy nhất mà đầu ra của nó kết hợp cả hai để tạo ra giá trị trạng thái-hành động
[imath]Q(s , a)[/imath]
. Không giống như trong cập nhật lợi thế, biểu diễn và thuật toán được tách biệt theo cấu trúc.
[math]
A^\pi (s, a) = Q^\pi (s, a) - V^\pi (s).
[/math]
Hàm giá trị
[imath]V[/imath]
đo lường mức độ tốt của nó trong một trạng thái cụ thể
[imath]s[/imath]
. Tuy nhiên, hàm
[imath]Q [/imath]
đo lường giá trị của việc chọn một hành động cụ thể khi ở trạng thái này. Bây giờ, sử dụng định nghĩa về lợi thế, chúng ta có thể muốn xây dựng mô-đun tổng hợp như sau:
[math]
Q(s, a; \theta, \alpha, \beta) = V (s; \theta, \beta) + A(s, a; \theta, \alpha),
[/math]
trong đó
[imath]θ[/imath]
biểu thị các tham số của các lớp chập, trong khi
[imath]α[/imath]
và
[imath]β[/imath]
là các tham số của hai luồng của các lớp được kết nối đầy đủ.
Thật không may, phương trình trên là không xác định được theo nghĩa là đã cho
[imath]Q[/imath]
, chúng ta không thể khôi phục
[imath]V[/imath]
và
[imath]A[/imath]
duy nhất; ví dụ, có các cặp
[imath]V[/imath]
và
[imath]A [/imath]
không đếm được làm cho giá trị
[imath]Q[/imath]
bằng không. Để giải quyết vấn đề về khả năng nhận dạng này, chúng ta có thể buộc công cụ ước tính hàm lợi thế không có lợi thế ở hành động đã chọn. Chúng ta để mô-đun cuối cùng của mạng thực hiện ánh xạ chuyển tiếp.
[math]
Q(s, a; \theta, \alpha, \beta) = V (s; \theta, \beta) + \big( A(s, a; \theta, \alpha) - \max_{a' \in |\mathcal{A}|} A(s, a'; \theta, \alpha) \big).
[/math]
Công thức này đảm bảo rằng chúng ta có thể khôi phục
[imath]V[/imath]
và
[imath]A[/imath]
duy nhất, nhưng việc tối ưu hóa không ổn định như vậy vì các lợi thế phải bù đắp bất kỳ thay đổi nào đối với lợi thế của hành động tối ưu. Vì lý do này, một mô-đun thay thế thay thế toán tử max bằng giá trị trung bình được đề xuất:
[math]
Q(s, a; \theta, \alpha, \beta) = V (s; \theta, \beta) + \big( A(s, a; \theta, \alpha) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a'; \theta, \alpha) \big).
[/math]
Không giống như dạng tối đa lợi thế, trong công thức này, lợi thế chỉ cần thay đổi càng nhanh càng tốt, do đó nó làm tăng tính ổn định của tối ưu hóa.
Cài đặt
import thư viện
import os
from typing import Dict, List, Tuple
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
from torch.nn.utils import clip_grad_norm_
Replay buffer
class ReplayBuffer:
"""A simple numpy replay buffer."""
def __init__(self, obs_dim: int, size: int, batch_size: int = 32):
self.obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
self.next_obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
self.acts_buf = np.zeros([size], dtype=np.float32)
self.rews_buf = np.zeros([size], dtype=np.float32)
self.done_buf = np.zeros(size, dtype=np.float32)
self.max_size, self.batch_size = size, batch_size
self.ptr, self.size, = 0, 0
def store(
self,
obs: np.ndarray,
act: np.ndarray,
rew: float,
next_obs: np.ndarray,
done: bool,
):
self.obs_buf[self.ptr] = obs
self.next_obs_buf[self.ptr] = next_obs
self.acts_buf[self.ptr] = act
self.rews_buf[self.ptr] = rew
self.done_buf[self.ptr] = done
self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample_batch(self) -> Dict[str, np.ndarray]:
idxs = np.random.choice(self.size, size=self.batch_size, replace=False)
return dict(obs=self.obs_buf[idxs],
next_obs=self.next_obs_buf[idxs],
acts=self.acts_buf[idxs],
rews=self.rews_buf[idxs],
done=self.done_buf[idxs])
def __len__(self) -> int:
return self.size
Dueling Network
class Network(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
"""Initialization."""
super(Network, self).__init__()
# set common feature layer
self.feature_layer = nn.Sequential(
nn.Linear(in_dim, 128),
nn.ReLU(),
)
# set advantage layer
self.advantage_layer = nn.Sequential(
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, out_dim),
)
# set value layer
self.value_layer = nn.Sequential(
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward method implementation."""
feature = self.feature_layer(x)
value = self.value_layer(feature)
advantage = self.advantage_layer(feature)
q = value + advantage - advantage.mean(dim=-1, keepdim=True)
return q
DQN + DuelingNet Agent (w/o Double-DQN & PER)
việc sử dụng clip_gradnorm để ngăn chặn sự bùng nổ gradient.
class DQNAgent:
"""DQN Agent interacting with environment.
Attribute:
env (gym.Env): openAI Gym environment
memory (ReplayBuffer): replay memory to store transitions
batch_size (int): batch size for sampling
epsilon (float): parameter for epsilon greedy policy
epsilon_decay (float): step size to decrease epsilon
max_epsilon (float): max value of epsilon
min_epsilon (float): min value of epsilon
target_update (int): period for target model's hard update
gamma (float): discount factor
dqn (Network): model to train and select actions
dqn_target (Network): target model to update
optimizer (torch.optim): optimizer for training dqn
transition (list): transition information including
state, action, reward, next_state, done
"""
def __init__(
self,
env: gym.Env,
memory_size: int,
batch_size: int,
target_update: int,
epsilon_decay: float,
max_epsilon: float = 1.0,
min_epsilon: float = 0.1,
gamma: float = 0.99,
):
"""Initialization.
Args:
env (gym.Env): openAI Gym environment
memory_size (int): length of memory
batch_size (int): batch size for sampling
target_update (int): period for target model's hard update
epsilon_decay (float): step size to decrease epsilon
lr (float): learning rate
max_epsilon (float): max value of epsilon
min_epsilon (float): min value of epsilon
gamma (float): discount factor
"""
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
self.env = env
self.memory = ReplayBuffer(obs_dim, memory_size, batch_size)
self.batch_size = batch_size
self.epsilon = max_epsilon
self.epsilon_decay = epsilon_decay
self.max_epsilon = max_epsilon
self.min_epsilon = min_epsilon
self.target_update = target_update
self.gamma = gamma
# device: cpu / gpu
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
print(self.device)
# networks: dqn, dqn_target
self.dqn = Network(obs_dim, action_dim).to(self.device)
self.dqn_target = Network(obs_dim, action_dim).to(self.device)
self.dqn_target.load_state_dict(self.dqn.state_dict())
self.dqn_target.eval()
# optimizer
self.optimizer = optim.Adam(self.dqn.parameters())
# transition to store in memory
self.transition = list()
# mode: train / test
self.is_test = False
def select_action(self, state: np.ndarray) -> np.ndarray:
"""Select an action from the input state."""
# epsilon greedy policy
if self.epsilon > np.random.random():
selected_action = self.env.action_space.sample()
else:
selected_action = self.dqn(
torch.FloatTensor(state).to(self.device)
).argmax()
selected_action = selected_action.detach().cpu().numpy()
if not self.is_test:
self.transition = [state, selected_action]
return selected_action
def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
"""Take an action and return the response of the env."""
next_state, reward, done, _ = self.env.step(action)
if not self.is_test:
self.transition += [reward, next_state, done]
self.memory.store(*self.transition)
return next_state, reward, done
def update_model(self) -> torch.Tensor:
"""Update the model by gradient descent."""
samples = self.memory.sample_batch()
loss = self._compute_dqn_loss(samples)
self.optimizer.zero_grad()
loss.backward()
# DuelingNet: we clip the gradients to have their norm less than or equal to 10.
clip_grad_norm_(self.dqn.parameters(), 10.0)
self.optimizer.step()
return loss.item()
def train(self, num_frames: int, plotting_interval: int = 200):
"""Train the agent."""
self.is_test = False
state = self.env.reset()
update_cnt = 0
epsilons = []
losses = []
scores = []
score = 0
for frame_idx in range(1, num_frames + 1):
action = self.select_action(state)
next_state, reward, done = self.step(action)
state = next_state
score += reward
# if episode ends
if done:
state = self.env.reset()
scores.append(score)
score = 0
# if training is ready
if len(self.memory) >= self.batch_size:
loss = self.update_model()
losses.append(loss)
update_cnt += 1
# linearly decrease epsilon
self.epsilon = max(
self.min_epsilon, self.epsilon - (
self.max_epsilon - self.min_epsilon
) * self.epsilon_decay
)
epsilons.append(self.epsilon)
# if hard update is needed
if update_cnt % self.target_update == 0:
self._target_hard_update()
# plotting
if frame_idx % plotting_interval == 0:
self._plot(frame_idx, scores, losses, epsilons)
self.env.close()
def test(self, video_folder: str) -> None:
"""Test the agent."""
self.is_test = True
# for recording a video
naive_env = self.env
self.env = gym.wrappers.RecordVideo(self.env, video_folder=video_folder)
state = self.env.reset()
done = False
score = 0
while not done:
action = self.select_action(state)
next_state, reward, done = self.step(action)
state = next_state
score += reward
print("score: ", score)
self.env.close()
# reset
self.env = naive_env
def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
"""Return dqn loss."""
device = self.device # for shortening the following lines
state = torch.FloatTensor(samples["obs"]).to(device)
next_state = torch.FloatTensor(samples["next_obs"]).to(device)
action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
# G_t = r + gamma * v(s_{t+1}) if state != Terminal
# = r otherwise
curr_q_value = self.dqn(state).gather(1, action)
next_q_value = self.dqn_target(next_state).max(
dim=1, keepdim=True
)[0].detach()
mask = 1 - done
target = (reward + self.gamma * next_q_value * mask).to(self.device)
# calculate dqn loss
loss = F.smooth_l1_loss(curr_q_value, target)
return loss
def _target_hard_update(self):
"""Hard update: target <- local."""
self.dqn_target.load_state_dict(self.dqn.state_dict())
def _plot(
self,
frame_idx: int,
scores: List[float],
losses: List[float],
epsilons: List[float],
):
"""Plot the training progresses."""
clear_output(True)
plt.figure(figsize=(20, 5))
plt.subplot(131)
plt.title('frame %s. score: %s' % (frame_idx, np.mean(scores[-10:])))
plt.plot(scores)
plt.subplot(132)
plt.title('loss')
plt.plot(losses)
plt.subplot(133)
plt.title('epsilons')
plt.plot(epsilons)
plt.show()