Việc sử dụng bộ nhớ phát lại dẫn đến các lựa chọn thiết kế ở hai cấp độ: trải nghiệm nào cần lưu trữ và trải nghiệm nào sẽ phát lại (và cách thực hiện). Trọng tâm của Prioritized Experience Replay là tiêu chí để đo lường tầm quan trọng của mỗi quá trình chuyển đổi trạng thái. Một cách tiếp cận hợp lý là sử dụng độ lớn TD error của quá trình chuyển đổi
[imath]δ[/imath]
, nó cho biết mức độ "đáng ngạc nhiên" hoặc không mong đợi của quá trình chuyển đổi. Thuật toán này lưu trữ TD error gặp phải lần cuối cùng với mỗi lần chuyển đổi trạng thái trong bộ nhớ phát lại.
Quá trình chuyển đổi với TD error tuyệt đối được phát lại từ bộ nhớ. Q-learning update được áp dụng cho quá trình chuyển đổi này, cập nhật trọng số tương ứng với TD error. Một điều cần lưu ý rằng các chuyển đổi mới đến mà không có TD-error đã biết, vì vậy nó đặt chúng ở mức ưu tiên tối đa để đảm bảo rằng tất cả trải nghiệm đều được nhìn thấy ít nhất một lần. (xem store method)
Chúng ta có thể sử dụng 2 ý tưởng để giải quyết TD error: 1. ưu tiên tham lam(greedy), 2. ưu tiên ngẫu nhiên. Tuy nhiên, ưu tiên TD error tham lam có một nhược điểm nghiêm trọng. Ưu tiên tham lam tập trung vào một tập hợp con nhỏ của trải nghiệm: error giảm dần, đặc biệt là khi sử dụng sấp sỉ hàm, có nghĩa là các chuyển đổi có lỗi cao ban đầu thường xuyên được phát lại. Sự thiếu đa dạng này làm cho hệ thống dễ bị overfit. Để khắc phục vấn đề này, chúng ta sẽ sử dụng phương pháp lấy mẫu ngẫu nhiên để nội suy giữa ưu tiên tham lam thuần túy và lấy mẫu ngẫu nhiên đồng nhất.
[math]
P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}
[/math]
trong đó
[imath] p_i > 0 [/imath]
là mức độ ưu tiên của quá trình chuyển đổi
[imath]i[/imath]
. Số mũ
[imath]α[/imath]
xác định mức độ ưu tiên được sử dụng, với
[imath]α = 0[/imath]
tương ứng với trường hợp đồng nhất. Trong thực tế, chúng ta sử dụng thuật ngữ bổ sung
[imath]ϵ[/imath]
để đảm bảo tất cả các chuyển đổi có thể được lấy mẫu:
[imath]p_i = | δ_i | + ϵ[/imath]
, trong đó
[imath]ϵ[/imath]
là hằng số dương nhỏ.
Một lần nữa. Hãy nhớ lại một trong những ý tưởng chính của DQN. Để loại bỏ mối tương quan của các quan sát, nó sử dụng lấy mẫu ngẫu nhiên đồng nhất từ bộ đệm phát lại. Phát lại được ưu tiên giới thiệu xu hướng vì nó không lấy mẫu trải nghiệm ngẫu nhiên một cách đồng nhất do tỷ lệ lấy mẫu tương ứng với TD-error. Chúng ta có thể sửa sai lệch này bằng cách sử dụng trọng số lấy mẫu theo mức độ quan trọng (IS)
[math]
w_i = \big( \frac{1}{N} \cdot \frac{1}{P(i)} \big)^\beta
[/math]
điều đó hoàn toàn bù đắp cho các xác suất không đồng nhất
[imath]P(i)[/imath]
nếu
[imath]β = 1[/imath]
. Các trọng số này có thể được gói lại khi update Q-learning bằng cách sử dụng
[imath]w_iδ_i [/imath]
thay vì
[imath]δ_i[/imath]
. Trong các tình huống học tập củng cố điển hình, bản chất không thiên vị của các lần update là quan trọng nhất gần hội tụ ở cuối quá trình đào tạo, Do đó, chúng ta khai thác tính linh hoạt của việc điều chỉnh lấy mẫu quan trọng theo thời gian, bằng cách xác định lịch trình trên số mũ
[imath] β[/imath]
tiến đến 1 duy nhất khi kết thúc quá trình học.
Cài đặt
Import các thư viện cần thiết
import os
import random
from typing import Dict, List, Tuple
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
Chúng ta sử dụng một cấu trúc dữ liệu Segment tree để biểu diễn mức độ quan trọng của các trạng thái tốt hơn.
"""Segment tree for Prioritized Replay Buffer."""
import operator
from typing import Callable
class SegmentTree:
""" Create SegmentTree.
Taken from OpenAI baselines github repository:
https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
Attributes:
capacity (int)
tree (list)
operation (function)
"""
def __init__(self, capacity: int, operation: Callable, init_value: float):
"""Initialization.
Args:
capacity (int)
operation (function)
init_value (float)
"""
assert (
capacity > 0 and capacity & (capacity - 1) == 0
), "capacity must be positive and a power of 2."
self.capacity = capacity
self.tree = [init_value for _ in range(2 * capacity)]
self.operation = operation
def _operate_helper(
self, start: int, end: int, node: int, node_start: int, node_end: int
) -> float:
"""Returns result of operation in segment."""
if start == node_start and end == node_end:
return self.tree[node]
mid = (node_start + node_end) // 2
if end <= mid:
return self._operate_helper(start, end, 2 * node, node_start, mid)
else:
if mid + 1 <= start:
return self._operate_helper(start, end, 2 * node + 1, mid + 1, node_end)
else:
return self.operation(
self._operate_helper(start, mid, 2 * node, node_start, mid),
self._operate_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end),
)
def operate(self, start: int = 0, end: int = 0) -> float:
"""Returns result of applying `self.operation`."""
if end <= 0:
end += self.capacity
end -= 1
return self._operate_helper(start, end, 1, 0, self.capacity - 1)
def __setitem__(self, idx: int, val: float):
"""Set value in tree."""
idx += self.capacity
self.tree[idx] = val
idx //= 2
while idx >= 1:
self.tree[idx] = self.operation(self.tree[2 * idx], self.tree[2 * idx + 1])
idx //= 2
def __getitem__(self, idx: int) -> float:
"""Get real value in leaf node of tree."""
assert 0 <= idx < self.capacity
return self.tree[self.capacity + idx]
class SumSegmentTree(SegmentTree):
""" Create SumSegmentTree.
Taken from OpenAI baselines github repository:
https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
"""
def __init__(self, capacity: int):
"""Initialization.
Args:
capacity (int)
"""
super(SumSegmentTree, self).__init__(
capacity=capacity, operation=operator.add, init_value=0.0
)
def sum(self, start: int = 0, end: int = 0) -> float:
"""Returns arr[start] + ... + arr[end]."""
return super(SumSegmentTree, self).operate(start, end)
def retrieve(self, upperbound: float) -> int:
"""Find the highest index `i` about upper bound in the tree"""
# TODO: Check assert case and fix bug
assert 0 <= upperbound <= self.sum() + 1e-5, "upperbound: {}".format(upperbound)
idx = 1
while idx < self.capacity: # while non-leaf
left = 2 * idx
right = left + 1
if self.tree[left] > upperbound:
idx = 2 * idx
else:
upperbound -= self.tree[left]
idx = right
return idx - self.capacity
class MinSegmentTree(SegmentTree):
""" Create SegmentTree.
Taken from OpenAI baselines github repository:
https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
"""
def __init__(self, capacity: int):
"""Initialization.
Args:
capacity (int)
"""
super(MinSegmentTree, self).__init__(
capacity=capacity, operation=min, init_value=float("inf")
)
def min(self, start: int = 0, end: int = 0) -> float:
"""Returns min(arr[start], ..., arr[end])."""
return super(MinSegmentTree, self).operate(start, end)
Replay buffer
class ReplayBuffer:
"""A simple numpy replay buffer."""
def __init__(self, obs_dim: int, size: int, batch_size: int = 32):
self.obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
self.next_obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
self.acts_buf = np.zeros([size], dtype=np.float32)
self.rews_buf = np.zeros([size], dtype=np.float32)
self.done_buf = np.zeros(size, dtype=np.float32)
self.max_size, self.batch_size = size, batch_size
self.ptr, self.size, = 0, 0
def store(
self,
obs: np.ndarray,
act: np.ndarray,
rew: float,
next_obs: np.ndarray,
done: bool,
):
self.obs_buf[self.ptr] = obs
self.next_obs_buf[self.ptr] = next_obs
self.acts_buf[self.ptr] = act
self.rews_buf[self.ptr] = rew
self.done_buf[self.ptr] = done
self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample_batch(self) -> Dict[str, np.ndarray]:
idxs = np.random.choice(self.size, size=self.batch_size, replace=False)
return dict(obs=self.obs_buf[idxs],
next_obs=self.next_obs_buf[idxs],
acts=self.acts_buf[idxs],
rews=self.rews_buf[idxs],
done=self.done_buf[idxs])
def __len__(self) -> int:
return self.size
Prioritized replay Buffer
Khái niệm chính về việc triển khai PER là Cây phân đoạn. Nó lưu trữ và lấy mẫu hiệu quả các quá trình chuyển đổi trong khi quản lý các mức độ ưu tiên của chúng.
class PrioritizedReplayBuffer(ReplayBuffer):
"""Prioritized Replay buffer.
Attributes:
max_priority (float): max priority
tree_ptr (int): next index of tree
alpha (float): alpha parameter for prioritized replay buffer
sum_tree (SumSegmentTree): sum tree for prior
min_tree (MinSegmentTree): min tree for min prior to get max weight
"""
def __init__(
self,
obs_dim: int,
size: int,
batch_size: int = 32,
alpha: float = 0.6
):
"""Initialization."""
assert alpha >= 0
super(PrioritizedReplayBuffer, self).__init__(obs_dim, size, batch_size)
self.max_priority, self.tree_ptr = 1.0, 0
self.alpha = alpha
# capacity must be positive and a power of 2.
tree_capacity = 1
while tree_capacity < self.max_size:
tree_capacity *= 2
self.sum_tree = SumSegmentTree(tree_capacity)
self.min_tree = MinSegmentTree(tree_capacity)
def store(
self,
obs: np.ndarray,
act: int,
rew: float,
next_obs: np.ndarray,
done: bool
):
"""Store experience and priority."""
super().store(obs, act, rew, next_obs, done)
self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
self.tree_ptr = (self.tree_ptr + 1) % self.max_size
def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]:
"""Sample a batch of experiences."""
assert len(self) >= self.batch_size
assert beta > 0
indices = self._sample_proportional()
obs = self.obs_buf[indices]
next_obs = self.next_obs_buf[indices]
acts = self.acts_buf[indices]
rews = self.rews_buf[indices]
done = self.done_buf[indices]
weights = np.array([self._calculate_weight(i, beta) for i in indices])
return dict(
obs=obs,
next_obs=next_obs,
acts=acts,
rews=rews,
done=done,
weights=weights,
indices=indices,
)
def update_priorities(self, indices: List[int], priorities: np.ndarray):
"""Update priorities of sampled transitions."""
assert len(indices) == len(priorities)
for idx, priority in zip(indices, priorities):
assert priority > 0
assert 0 <= idx < len(self)
self.sum_tree[idx] = priority ** self.alpha
self.min_tree[idx] = priority ** self.alpha
self.max_priority = max(self.max_priority, priority)
def _sample_proportional(self) -> List[int]:
"""Sample indices based on proportions."""
indices = []
p_total = self.sum_tree.sum(0, len(self) - 1)
segment = p_total / self.batch_size
for i in range(self.batch_size):
a = segment * i
b = segment * (i + 1)
upperbound = random.uniform(a, b)
idx = self.sum_tree.retrieve(upperbound)
indices.append(idx)
return indices
def _calculate_weight(self, idx: int, beta: float):
"""Calculate the weight of the experience at idx."""
# get max weight
p_min = self.min_tree.min() / self.sum_tree.sum()
max_weight = (p_min * len(self)) ** (-beta)
# calculate weights
p_sample = self.sum_tree[idx] / self.sum_tree.sum()
weight = (p_sample * len(self)) ** (-beta)
weight = weight / max_weight
return weight
Network
class Network(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
"""Initialization."""
super(Network, self).__init__()
self.layers = nn.Sequential(
nn.Linear(in_dim, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, out_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward method implementation."""
return self.layers(x)
DQN + PER Agent
init
Ở đây, chúng ta sử dụng PrioritizedReplayBuffer, thay vì ReplayBuffer và sử dụng giữ thêm 2 tham số beta và epsilon ưu tiên được sử dụng để tính trọng số và mức độ ưu tiên mới tương ứng.
compute_dqn_loss & update_model
Nó trả về mọi loss trên mỗi mẫu để lấy mẫu mức độ quan trọng trước mức trung bình. Sau khi cập nhật mới, cần phải cập nhật mức độ ưu tiên của tất cả các trải nghiệm được lấy mẫu
train
beta tuyến tính tăng lên 1 ở mỗi bước đào tạo.
class DQNAgent:
"""DQN Agent interacting with environment.
Attribute:
env (gym.Env): openAI Gym environment
memory (ReplayBuffer): replay memory to store transitions
batch_size (int): batch size for sampling
epsilon (float): parameter for epsilon greedy policy
epsilon_decay (float): step size to decrease epsilon
max_epsilon (float): max value of epsilon
min_epsilon (float): min value of epsilon
target_update (int): period for target model's hard update
gamma (float): discount factor
dqn (Network): model to train and select actions
dqn_target (Network): target model to update
optimizer (torch.optim): optimizer for training dqn
transition (list): transition information including
state, action, reward, next_state, done
beta (float): determines how much importance sampling is used
prior_eps (float): guarantees every transition can be sampled
"""
def __init__(
self,
env: gym.Env,
memory_size: int,
batch_size: int,
target_update: int,
epsilon_decay: float,
max_epsilon: float = 1.0,
min_epsilon: float = 0.1,
gamma: float = 0.99,
# PER parameters
alpha: float = 0.2,
beta: float = 0.6,
prior_eps: float = 1e-6,
):
"""Initialization.
Args:
env (gym.Env): openAI Gym environment
memory_size (int): length of memory
batch_size (int): batch size for sampling
target_update (int): period for target model's hard update
epsilon_decay (float): step size to decrease epsilon
lr (float): learning rate
max_epsilon (float): max value of epsilon
min_epsilon (float): min value of epsilon
gamma (float): discount factor
alpha (float): determines how much prioritization is used
beta (float): determines how much importance sampling is used
prior_eps (float): guarantees every transition can be sampled
"""
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
self.env = env
self.batch_size = batch_size
self.epsilon = max_epsilon
self.epsilon_decay = epsilon_decay
self.max_epsilon = max_epsilon
self.min_epsilon = min_epsilon
self.target_update = target_update
self.gamma = gamma
# device: cpu / gpu
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
print(self.device)
# PER
# In DQN, We used "ReplayBuffer(obs_dim, memory_size, batch_size)"
self.beta = beta
self.prior_eps = prior_eps
self.memory = PrioritizedReplayBuffer(
obs_dim, memory_size, batch_size, alpha
)
# networks: dqn, dqn_target
self.dqn = Network(obs_dim, action_dim).to(self.device)
self.dqn_target = Network(obs_dim, action_dim).to(self.device)
self.dqn_target.load_state_dict(self.dqn.state_dict())
self.dqn_target.eval()
# optimizer
self.optimizer = optim.Adam(self.dqn.parameters())
# transition to store in memory
self.transition = list()
# mode: train / test
self.is_test = False
def select_action(self, state: np.ndarray) -> np.ndarray:
"""Select an action from the input state."""
# epsilon greedy policy
if self.epsilon > np.random.random():
selected_action = self.env.action_space.sample()
else:
selected_action = self.dqn(
torch.FloatTensor(state).to(self.device)
).argmax()
selected_action = selected_action.detach().cpu().numpy()
if not self.is_test:
self.transition = [state, selected_action]
return selected_action
def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
"""Take an action and return the response of the env."""
next_state, reward, done, _ = self.env.step(action)
if not self.is_test:
self.transition += [reward, next_state, done]
self.memory.store(*self.transition)
return next_state, reward, done
def update_model(self) -> torch.Tensor:
"""Update the model by gradient descent."""
# PER needs beta to calculate weights
samples = self.memory.sample_batch(self.beta)
weights = torch.FloatTensor(
samples["weights"].reshape(-1, 1)
).to(self.device)
indices = samples["indices"]
# PER: importance sampling before average
elementwise_loss = self._compute_dqn_loss(samples)
loss = torch.mean(elementwise_loss * weights)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# PER: update priorities
loss_for_prior = elementwise_loss.detach().cpu().numpy()
new_priorities = loss_for_prior + self.prior_eps
self.memory.update_priorities(indices, new_priorities)
return loss.item()
def train(self, num_frames: int, plotting_interval: int = 200):
"""Train the agent."""
self.is_test = False
state = self.env.reset()
update_cnt = 0
epsilons = []
losses = []
scores = []
score = 0
for frame_idx in range(1, num_frames + 1):
action = self.select_action(state)
next_state, reward, done = self.step(action)
state = next_state
score += reward
# PER: increase beta
fraction = min(frame_idx / num_frames, 1.0)
self.beta = self.beta + fraction * (1.0 - self.beta)
# if episode ends
if done:
state = self.env.reset()
scores.append(score)
score = 0
# if training is ready
if len(self.memory) >= self.batch_size:
loss = self.update_model()
losses.append(loss)
update_cnt += 1
# linearly decrease epsilon
self.epsilon = max(
self.min_epsilon, self.epsilon - (
self.max_epsilon - self.min_epsilon
) * self.epsilon_decay
)
epsilons.append(self.epsilon)
# if hard update is needed
if update_cnt % self.target_update == 0:
self._target_hard_update()
# plotting
if frame_idx % plotting_interval == 0:
self._plot(frame_idx, scores, losses, epsilons)
self.env.close()
def test(self, video_folder: str) -> None:
"""Test the agent."""
self.is_test = True
# for recording a video
naive_env = self.env
self.env = gym.wrappers.RecordVideo(self.env, video_folder=video_folder)
state = self.env.reset()
done = False
score = 0
while not done:
action = self.select_action(state)
next_state, reward, done = self.step(action)
state = next_state
score += reward
print("score: ", score)
self.env.close()
# reset
self.env = naive_env
def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
"""Return dqn loss."""
device = self.device # for shortening the following lines
state = torch.FloatTensor(samples["obs"]).to(device)
next_state = torch.FloatTensor(samples["next_obs"]).to(device)
action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
# G_t = r + gamma * v(s_{t+1}) if state != Terminal
# = r otherwise
curr_q_value = self.dqn(state).gather(1, action)
next_q_value = self.dqn_target(
next_state
).max(dim=1, keepdim=True)[0].detach()
mask = 1 - done
target = (reward + self.gamma * next_q_value * mask).to(self.device)
# calculate element-wise dqn loss
elementwise_loss = F.smooth_l1_loss(curr_q_value, target, reduction="none")
return elementwise_loss
def _target_hard_update(self):
"""Hard update: target <- local."""
self.dqn_target.load_state_dict(self.dqn.state_dict())
def _plot(
self,
frame_idx: int,
scores: List[float],
losses: List[float],
epsilons: List[float],
):
"""Plot the training progresses."""
clear_output(True)
plt.figure(figsize=(20, 5))
plt.subplot(131)
plt.title('frame %s. score: %s' % (frame_idx, np.mean(scores[-10:])))
plt.plot(scores)
plt.subplot(132)
plt.title('loss')
plt.plot(losses)
plt.subplot(133)
plt.title('epsilons')
plt.plot(epsilons)
plt.show()