[Reinforcement Learning] "Easy RL" - Q-learning - CliffWalking Code Walkthrough
2022-08-10 06:33:00 【None072】
0. Preface
The code in this post comes from the Cliff Walking hands-on example in the Q-learning chapter of the "mushroom book" Easy RL. I annotated the code in full while studying it; if there are any mistakes, corrections are welcome.
Easy-RL GitHub: https://github.com/datawhalechina/easy-rl
This part of the code has two core files:
- qlearning.py
- task0.py
Let's start with the task0 part.
1. Hyperparameters
Machine learning models generally involve two kinds of parameters. One kind is learned and estimated from the data and is called model parameters (Parameters), i.e. the parameters of the model itself. The other kind are the tuning parameters of the machine learning algorithm, which have to be set manually and are called hyperparameters (Hyperparameters).
class Config:
    """Hyperparameters"""

    def __init__(self):
        ################################## Environment hyperparameters ###################################
        self.algo_name = 'Q-learning'  # name of the algorithm; we use Q-learning
        self.env_name = 'CliffWalking-v0'  # name of the environment: cliff walking
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")  # detect the GPU; defaults to CPU if CUDA is not installed
        self.seed = 10  # random seed; 0 means no seed is set. The random values during training all derive from this seed, which makes the results easy to reproduce
        self.train_eps = 400  # number of training episodes
        self.test_eps = 30  # number of test episodes
        ################################################################################
        ################################## Algorithm hyperparameters ###################################
        self.gamma = 0.90  # discount factor in reinforcement learning
        self.epsilon_start = 0.95  # initial epsilon of the ε-greedy strategy; decreasing it reduces random exploration at the start of training
        self.epsilon_end = 0.01  # final epsilon of the ε-greedy strategy; the smaller it is, the less random exploration remains at the end of training
        self.epsilon_decay = 300  # decay rate of epsilon in the ε-greedy strategy; the larger the value, the slower the decay
        self.lr = 0.1  # learning rate
        ################################################################################
        ################################# Parameters for saving results ################################
        self.result_path = curr_path + "/outputs/" + self.env_name + \
            '/' + curr_time + '/results/'  # path for saving results
        self.model_path = curr_path + "/outputs/" + self.env_name + \
            '/' + curr_time + '/models/'  # path for saving the model
        self.save_fig = True  # whether to save figures (note: renamed to save_fig here)
        ################################################################################
2. Training
def train(cfg, env, agent):
    print('Start training!')
    print(f'Environment: {cfg.env_name}, Algorithm: {cfg.algo_name}, Device: {cfg.device}')
    rewards = []  # record the reward of each episode, used to track how rewards change
    ma_rewards = []  # the raw rewards may oscillate, so a moving average is kept to show the trend
    # start episode-by-episode training
    for i_ep in range(cfg.train_eps):
        ep_reward = 0  # cumulative reward of the current episode
        state = env.reset()  # reset the environment and start a new episode
        # step through the current episode until it terminates
        while True:
            action = agent.choose_action(state)  # choose an action according to the algorithm
            next_state, reward, done, _ = env.step(action)  # one interaction with the environment
            agent.update(state, action, reward, next_state, done)  # Q-learning update
            state = next_state  # update the state
            ep_reward += reward
            if done:
                break
        rewards.append(ep_reward)
        if ma_rewards:
            ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)
        else:
            ma_rewards.append(ep_reward)
        print("Episode: {}/{}, reward: {:.1f}".format(i_ep + 1, cfg.train_eps, ep_reward))
    print('Training finished!')
    return rewards, ma_rewards
2.1 Initialize the environment and agent
def env_agent_config(cfg, seed=1):
    """Create the environment and the agent
    Args:
        cfg ([type]): [description]
        seed (int, optional): random seed. Defaults to 1.
    Returns:
        env [type]: environment
        agent: agent
    """
    env = gym.make(cfg.env_name)
    env = CliffWalkingWapper(env)  # wrap the environment with a custom wrapper
    env.seed(seed)  # set the random seed; each seed gives a fixed random outcome, only needed for exact reproducibility and can usually be removed
    n_states = env.observation_space.n  # dimension of the state space, i.e. 48 states
    n_actions = env.action_space.n  # dimension of the action space, i.e. 4 actions
    agent = QLearning(n_states, n_actions, cfg)  # create the agent with the given hyperparameters
    return env, agent
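For context, here is a minimal sketch of how task0.py might wire these pieces together; the exact entry-point code in the repository may differ, but Config, env_agent_config and train are the pieces discussed above:

if __name__ == "__main__":
    cfg = Config()                                       # hyperparameters
    env, agent = env_agent_config(cfg, seed=cfg.seed)    # build the environment and the agent
    rewards, ma_rewards = train(cfg, env, agent)         # train and collect the reward curves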
2.2 The agent chooses an action
For the line action = agent.choose_action(state) in the code above, the method is implemented as follows:
def choose_action(self, state):
    self.sample_count += 1
    self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
        math.exp(-1. * self.sample_count / self.epsilon_decay)  # epsilon decreases over time; here an exponential decay is used
    # ε-greedy strategy
    if np.random.uniform(0, 1) > self.epsilon:
        action = np.argmax(self.Q_table[str(state)])  # choose the action with the largest Q(s, a)
    else:
        action = np.random.choice(self.n_actions)  # choose a random action
    return action
The formula of the ε-greedy strategy used here:
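Written out, the decay matches the code in choose_action above (sample_count counts the calls to choose_action):

$$\epsilon = \epsilon_{\text{end}} + (\epsilon_{\text{start}} - \epsilon_{\text{end}}) \cdot \exp\left(-\frac{\text{sample\_count}}{\epsilon_{\text{decay}}}\right)$$

$$a = \begin{cases} \arg\max_{a} Q(s, a) & \text{with probability } 1 - \epsilon \\ \text{random action} & \text{with probability } \epsilon \end{cases}$$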
As training progresses, epsilon decays exponentially until it approaches epsilon_end. When the number drawn uniformly from [0, 1] is greater than epsilon, i.e. with probability 1 - epsilon, the action with the largest Q(s, a) is chosen; otherwise a random action is taken.
Now, let's try printing the Q values of the current state: print(self.Q_table[str(state)])
The output is: [ -7.45800334 -78.37958986 -7.46127197 -7.48193639]
The four values in this array are the estimated values of taking each of the four actions in the current state.
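For the values printed above, the greedy branch would therefore pick index 0, which under gym's CliffWalking action encoding corresponds to UP (an assumption worth checking against your own environment wrapper):

import numpy as np

q_values = [-7.45800334, -78.37958986, -7.46127197, -7.48193639]
print(np.argmax(q_values))  # 0: the first action has the largest (least negative) value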
2.3 The environment receives the action and returns the next state and reward
After the action is selected, we use it for one interaction with the environment:
next_state, reward, done, _ = env.step(action)
Given an action, we obtain the next state and the reward from the grid map.
- For example, executing the action UP=0 in the starting cell 36 gives next state 24 and reward -1;
- The boundaries of the map also have to be handled: performing LEFT=3 at the starting point keeps the next state at 36, with reward -1;
- If RIGHT=1 is performed, the agent falls off the cliff: the next state is reset to 36 and the reward is -100.
The concrete transition logic can be inspected in C:\Python310\Lib\site-packages\gym\envs\toy_text\cliffwalking.py (the exact path depends on your Python installation).
The done flag is used to determine whether the goal has been reached.
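As a quick sanity check of the transitions described above, the raw environment can be stepped directly. A small sketch, assuming the same classic gym API used elsewhere in this post (reset() returns the state and step() returns a 4-tuple):

import gym

env = gym.make('CliffWalking-v0')
state = env.reset()                        # start state 36 (bottom-left corner)
next_state, reward, done, _ = env.step(0)  # UP: expect next_state 24, reward -1
print(state, next_state, reward, done)

env.reset()
next_state, reward, done, _ = env.step(1)  # RIGHT from the start walks into the cliff:
print(next_state, reward, done)            # expect next_state 36, reward -100, done False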
2.4 The agent updates its policy (learning)
Now that we have the current state, the chosen action, the reward and the next state, the agent can use them in the Q-learning algorithm to update its Q-table:
agent.update(state, action, reward, next_state, done)  # Q-learning update
The method is implemented as follows:
def update(self, state, action, reward, next_state, done):
    Q_predict = self.Q_table[str(state)][action]  # read the current estimate Q(s, a)
    if done:  # check for the terminal state
        Q_target = reward  # there is no next action in the terminal state, so the target is just the reward
    else:
        Q_target = reward + self.gamma * np.max(self.Q_table[str(next_state)])
    self.Q_table[str(state)][action] += self.lr * (Q_target - Q_predict)
The formula involved here is the incremental Q-learning update described in the book's pseudocode:
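Written out, this is the standard Q-learning update implemented by the code above, with α the learning rate self.lr and γ the discount factor self.gamma:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$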
In this way, the value of the chosen action in the current state is updated, which is the policy update.
3. Processing the results
The code above completes one episode of learning. After each episode ends, we record its reward for later visualization:
rewards.append(ep_reward)
if ma_rewards:
ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)
else:
ma_rewards.append(ep_reward)
Because the episode rewards can oscillate, we use a moving average to reflect the trend: each new entry is a weighted average of the previous moving average and the new episode reward, which is then appended to the list.
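Concretely, the update implemented above is (with $r_i$ the reward of episode $i$):

$$ma_0 = r_0, \qquad ma_i = 0.9 \, ma_{i-1} + 0.1 \, r_i$$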
3.1 Saving the model
Once all episodes have been executed, the trained model is saved:
make_dir(cfg.result_path, cfg.model_path)  # create the folders for the result and model paths
agent.save(path=cfg.model_path)  # save the model
The implementation of save:
def save(self, path):
    import dill
    torch.save(
        obj=self.Q_table,
        f=path + "Qlearning_model.pkl",
        pickle_module=dill
    )
    print("Model saved successfully!")
The dill module: https://pypi.org/project/dill/
dill extends python's pickle module for serializing and de-serializing python objects to the majority of the built-in python types. Serialization is the process of converting an object to a byte stream, and the inverse of which is converting a byte stream back to a python object hierarchy.
dill provides the user the same interface as the pickle module, and also includes some additional features. In addition to pickling python objects, dill provides the ability to save the state of an interpreter session in a single command. Hence, it would be feasible to save an interpreter session, close the interpreter, ship the pickled file to another computer, open a new interpreter, unpickle the session and thus continue from the 'saved' state of the original interpreter session.
We use a .pkl file (a storage format that can save temporary variables used during a Python project, or strings, lists, dictionaries and other data that need to be extracted or cached) to save the trained model, i.e. the Q-table. The serialization module used is dill.
torch.save() saves a serialized object to disk. The function uses Python's pickle utility for serialization. Models, tensors and dictionaries are all object types that can be saved with this function.
3.2 Loading the model
def load(self, path):
    import dill
    self.Q_table = torch.load(f=path + 'Qlearning_model.pkl', pickle_module=dill)
    print("Model loaded successfully!")
Similar to saving the model, torch.load() is used to read it back, thereby loading the trained Q-table.
3.3 Testing the model
Testing works basically the same way as training; the only difference is that the Q-table is no longer updated, i.e. the following line of code is removed:
agent.update(state, action, reward, next_state, done)  # Q-learning update
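For reference, a minimal sketch of what the test loop might look like, following the structure of train() above with the update line removed (the function name test and the exact prints are assumptions, not the repository's exact code):

def test(cfg, env, agent):
    print('Start testing!')
    rewards = []  # reward of each test episode
    for i_ep in range(cfg.test_eps):
        ep_reward = 0
        state = env.reset()
        while True:
            action = agent.choose_action(state)            # act with the learned Q-table
            next_state, reward, done, _ = env.step(action)
            # no agent.update(...) here: the Q-table stays fixed during testing
            state = next_state
            ep_reward += reward
            if done:
                break
        rewards.append(ep_reward)
        print("Episode: {}/{}, reward: {:.1f}".format(i_ep + 1, cfg.test_eps, ep_reward))
    print('Testing finished!')
    return rewards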