[Reinforcement Learning] "Easy RL" - Q-learning - CliffWalking (cliff walking) code interpretation
0. Preface
The code in this post comes from the Cliff Walking hands-on part of the Q-learning chapter of the "mushroom book" 《Easy RL》. I give a complete walkthrough of the code as I learned it; if there are any mistakes, please point them out.
Easy-RL GitHub: https://github.com/datawhalechina/easy-rl
This part of the code has two core files:
- qlearning.py
- task0.py
We start with the task0 part.
1. Hyperparameters
A machine learning model generally has two kinds of parameters. One kind is learned and estimated from the data and is called the model parameters, i.e., the parameters of the model itself. The other kind is the tuning parameters of the learning algorithm, which have to be set manually and are called hyperparameters.
class Config:
    """Hyperparameters"""
    def __init__(self):
        ################################ Environment hyperparameters ################################
        self.algo_name = 'Q-learning'  # algorithm name: we use the Q-learning algorithm
        self.env_name = 'CliffWalking-v0'  # environment name: cliff walking
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")  # detect the GPU; defaults to CPU if CUDA is not installed
        self.seed = 10  # random seed; 0 means no seed is set. All randomness in training follows this seed, which makes the results reproducible
        self.train_eps = 400  # number of training episodes
        self.test_eps = 30  # number of test episodes
        ################################################################################
        ################################ Algorithm hyperparameters ################################
        self.gamma = 0.90  # discount factor in reinforcement learning
        self.epsilon_start = 0.95  # initial epsilon in the epsilon-greedy strategy; lowering it reduces random exploration at the start of training
        self.epsilon_end = 0.01  # final epsilon in the epsilon-greedy strategy; the smaller it is, the less random exploration remains at the end of training
        self.epsilon_decay = 300  # decay constant of epsilon in the epsilon-greedy strategy; the larger the value, the slower the decay
        self.lr = 0.1  # learning rate
        ################################################################################
        ################################ Result-saving parameters ################################
        self.result_path = curr_path + "/outputs/" + self.env_name + \
            '/' + curr_time + '/results/'  # path for saving results
        self.model_path = curr_path + "/outputs/" + self.env_name + \
            '/' + curr_time + '/models/'  # path for saving models
        self.save_fig = True  # whether to save figures (note: renamed to save_fig here)
        ################################################################################
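The curr_path and curr_time variables used above are defined elsewhere in task0.py and are not shown in this excerpt. A minimal sketch of what they presumably look like (the exact timestamp format is an assumption):

import os
import datetime

curr_path = os.path.dirname(os.path.abspath(__file__))  # directory of the current script
curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")  # timestamp used to tag each run's output folder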
2. Training
def train(cfg, env, agent):
    print('Start training!')
    print(f'Environment: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
    rewards = []  # record the reward of every episode, used to track and analyze how the reward changes
    ma_rewards = []  # the episode reward may oscillate, so a moving average is used to show the trend of the reward
    # start episode-by-episode training
    for i_ep in range(cfg.train_eps):
        ep_reward = 0  # accumulate the reward of the current episode
        state = env.reset()  # reset the environment and start a new episode
        # keep stepping in the current episode until it terminates
        while True:
            action = agent.choose_action(state)  # choose an action according to the algorithm
            next_state, reward, done, _ = env.step(action)  # one interaction step with the environment
            agent.update(state, action, reward, next_state, done)  # Q-learning update
            state = next_state  # update the state
            ep_reward += reward
            if done:
                break
        rewards.append(ep_reward)
        if ma_rewards:
            ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)
        else:
            ma_rewards.append(ep_reward)
        print("Episode: {}/{}, Reward: {:.1f}".format(i_ep + 1, cfg.train_eps, ep_reward))
    print('Finish training!')
    return rewards, ma_rewards
2.1 Initialize the environment and agent
def env_agent_config(cfg, seed=1):
    """Create the environment and the agent.

    Args:
        cfg: configuration object holding the hyperparameters
        seed (int, optional): random seed. Defaults to 1.

    Returns:
        env: the environment
        agent: the agent
    """
    env = gym.make(cfg.env_name)
    env = CliffWalkingWapper(env)  # wrap the environment with the custom wrapper
    env.seed(seed)  # set the random seed; each seed corresponds to one random outcome. This is only for exact reproducibility and can normally be removed
    n_states = env.observation_space.n  # state dimension, i.e. 48 states
    n_actions = env.action_space.n  # action dimension, i.e. 4 actions
    agent = QLearning(n_states, n_actions, cfg)  # pass the parameters to the agent
    return env, agent
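For reference, the entry point in task0.py presumably ties these pieces together roughly as follows (a minimal usage sketch, assuming the Config, env_agent_config and train definitions shown above):

cfg = Config()  # load the hyperparameters
env, agent = env_agent_config(cfg, seed=cfg.seed)  # create the environment and the agent
rewards, ma_rewards = train(cfg, env, agent)  # run the training loop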
2.2 The agent chooses an action
For action = agent.choose_action(state) in the code above,
the method is implemented as follows:
def choose_action(self, state):
    self.sample_count += 1
    self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
        math.exp(-1. * self.sample_count / self.epsilon_decay)  # epsilon decreases over time; an exponential decay is used here
    # epsilon-greedy strategy
    if np.random.uniform(0, 1) > self.epsilon:
        action = np.argmax(self.Q_table[str(state)])  # choose the action with the maximum Q(s, a)
    else:
        action = np.random.choice(self.n_actions)  # choose a random action
    return action
The epsilon decay formula used by this epsilon-greedy strategy is

epsilon = epsilon_end + (epsilon_start - epsilon_end) * exp(-sample_count / epsilon_decay)

As training proceeds, epsilon decays exponentially until it approaches epsilon_end. When the randomly drawn number is greater than epsilon, i.e. with probability 1 - epsilon, the action with the maximum Q(s, a) is chosen; otherwise an action is chosen at random.
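To get a feel for the decay, we can evaluate this formula at a few sample counts with the hyperparameters above (epsilon_start = 0.95, epsilon_end = 0.01, epsilon_decay = 300); this small snippet is only an illustration:

import math

epsilon_start, epsilon_end, epsilon_decay = 0.95, 0.01, 300
for sample_count in (1, 100, 300, 1000, 3000):
    epsilon = epsilon_end + (epsilon_start - epsilon_end) * math.exp(-sample_count / epsilon_decay)
    print(f"step {sample_count:>4}: epsilon = {epsilon:.3f}")

Epsilon starts at roughly 0.95, drops to about 0.36 after 300 steps, and is essentially epsilon_end after a few thousand steps.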
Now, let's try printing the Q values of the current state: print(self.Q_table[str(state)])
The output is: [ -7.45800334 -78.37958986 -7.46127197 -7.48193639]
The four values in this array are the estimated values of the four actions in the current state.
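The Q table is indexed by str(state) and returns one value per action. In the QLearning class it is presumably initialized as a defaultdict that maps each stringified state to a zero array of length n_actions; a minimal sketch under that assumption:

from collections import defaultdict
import numpy as np

n_actions = 4
Q_table = defaultdict(lambda: np.zeros(n_actions))  # unseen states start with all-zero action values

print(Q_table[str(36)])     # [0. 0. 0. 0.] for a state that has never been visited
Q_table[str(36)][0] += 0.5  # entries are then updated in place by the learning rule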
2.3 The environment receives the action and returns the next state and reward
After the action is selected, we use it for one interaction step with the environment:
next_state, reward, done, _ = env.step(action)
Given an action, the map determines the next state and the reward.
- For example, executing UP = 0 in the starting cell 36 leads to the next state 24 with reward -1;
- the map boundaries also have to be handled: executing LEFT = 3 at the starting point leaves the state at 36 with reward -1;
- executing RIGHT = 1 steps onto the cliff, so the agent is sent back to state 36 with reward -100.
The concrete transition logic can be inspected in your gym installation, e.g. C:\Python310\Lib\site-packages\gym\envs\toy_text\cliffwalking.py.
The flag done indicates whether the episode has terminated, i.e. whether the goal has been reached.
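As a quick sanity check of these transitions, one can create the raw environment and step through them manually (a small illustrative snippet, assuming the older gym API used throughout this post, where reset() returns the state and step() returns four values):

import gym

env = gym.make('CliffWalking-v0')
state = env.reset()                         # the start state is 36 (bottom-left corner of the 4 x 12 grid)
print(state)                                # 36

next_state, reward, done, _ = env.step(0)   # UP from the start
print(next_state, reward)                   # 24 -1

env.reset()
next_state, reward, done, _ = env.step(1)   # RIGHT from the start walks onto the cliff
print(next_state, reward)                   # 36 -100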
2.4 The agent performs a policy update (learning)
Now that we have the current state, the chosen action, the reward and the next state, the agent can use them to update the Q table with the Q-learning algorithm:
agent.update(state, action, reward, next_state, done)  # Q-learning update
The method is implemented as follows:
def update(self, state, action, reward, next_state, done):
    Q_predict = self.Q_table[str(state)][action]  # read the current estimate
    if done:  # terminal-state check
        Q_target = reward  # in the terminal state there is no next action, so Q_target is just the reward
    else:
        Q_target = reward + self.gamma * np.max(self.Q_table[str(next_state)])
    self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)
This corresponds to the incremental Q-learning update given in the book's pseudocode:

Q(s, a) <- Q(s, a) + lr * [r + gamma * max_a' Q(s', a') - Q(s, a)]

In this way, the value of the chosen action in the current state is updated, which is the policy update.
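As a worked example with the hyperparameters above (lr = 0.1, gamma = 0.90), suppose Q_predict = -7.46, reward = -1 and the maximum Q value of the next state is -7.0 (all values are made up purely for illustration):

lr, gamma = 0.1, 0.90
Q_predict = -7.46      # hypothetical current estimate Q(s, a)
reward = -1            # reward of one regular step
max_next_Q = -7.0      # hypothetical max over a' of Q(s', a')

Q_target = reward + gamma * max_next_Q           # -1 + 0.9 * (-7.0) = -7.3
new_Q = Q_predict + lr * (Q_target - Q_predict)  # -7.46 + 0.1 * 0.16 = -7.444
print(new_Q)                                     # approximately -7.444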
3. Processing the results
In the above we completed one episode of learning. After each episode ends, we record its reward for later visualization:
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)
else:
ma_rewards.append(ep_reward)
Because the episode reward can oscillate, we use a moving average to show the trend: the new moving-average value is computed from the previous moving-average value (weight 0.9) and the current episode reward (weight 0.1) and appended to the list.
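For example (numbers chosen only for illustration), starting from an empty list:

ma_rewards = []
for ep_reward in (-100, -80, -60):
    if ma_rewards:
        ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)
    else:
        ma_rewards.append(ep_reward)
print(ma_rewards)  # [-100, -98.0, -94.2] (up to floating-point rounding)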
3.1 Saving the model
After all episodes have been executed, we save the trained model:
make_dir(cfg.result_path, cfg.model_path)  # create the folders for the result and model paths
agent.save(path=cfg.model_path)  # save the model
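The make_dir helper is not shown in this excerpt; presumably it just creates any missing directories. A minimal sketch under that assumption:

import os

def make_dir(*paths):
    """Create each of the given directories if it does not already exist."""
    for path in paths:
        os.makedirs(path, exist_ok=True)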
The implementation of save:
def save(self, path):
    import dill
    torch.save(
        obj=self.Q_table,
        f=path + "Qlearning_model.pkl",
        pickle_module=dill
    )
    print("Model saved successfully!")
The dill module: https://pypi.org/project/dill/
dill extends python's pickle module for serializing and de-serializing python objects to the majority of the built-in python types. Serialization is the process of converting an object to a byte stream, and the inverse of which is converting a byte stream back to a python object hierarchy.
dill provides the user the same interface as the pickle module, and also includes some additional features. In addition to pickling python objects, dill provides the ability to save the state of an interpreter session in a single command. Hence, it would be feasible to save an interpreter session, close the interpreter, ship the pickled file to another computer, open a new interpreter, unpickle the session and thus continue from the 'saved' state of the original interpreter session.
We use a .pkl file (this storage format can hold temporary variables used in a Python project, as well as strings, lists, dictionaries and other data that need to be extracted or cached) to save the trained model, i.e. the Q table. The serialization itself is handled by the dill module.
torch.save()
Saves a serialized object to disk. The function uses Python's pickle utility for serialization. Models, tensors and dictionaries are all object types that can be saved with this function.
3.2 Loading the model
def load(self, path):
    import dill
    self.Q_table = torch.load(f=path + 'Qlearning_model.pkl', pickle_module=dill)
    print("Model loaded successfully!")
Similar to saving, torch.load() reads the model back, which loads the trained Q table.
3.3 Testing the model
Testing works basically the same way as training; the only difference is that the Q table is no longer updated, i.e. the following line of code is dropped:
agent.update(state, action, reward, next_state, done)  # Q-learning update
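For completeness, a test loop under these assumptions might look like the sketch below (the names mirror the training code above; the exact test function in task0.py may differ slightly):

def test(cfg, env, agent):
    print('Start testing!')
    rewards = []  # reward of every test episode
    for i_ep in range(cfg.test_eps):
        ep_reward = 0
        state = env.reset()
        while True:
            action = agent.choose_action(state)  # epsilon has already decayed, so this is almost greedy
            next_state, reward, done, _ = env.step(action)
            state = next_state
            ep_reward += reward
            if done:
                break
        rewards.append(ep_reward)
        print("Episode: {}/{}, Reward: {:.1f}".format(i_ep + 1, cfg.test_eps, ep_reward))
    print('Finish testing!')
    return rewards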