当前位置:网站首页>【源码】使用深度学习训练一个游戏
【源码】使用深度学习训练一个游戏
2022-08-04 03:05:00 【落难Coder】
运行环境
pygame
numpy
opencv(opencv-python)
tensorflow 1.x(代码使用了 tf.placeholder、tf.Session 等 1.x 接口)
提示
- 运行报错
AttributeError: module 'tensorflow' has no attribute 'mul'
解决方案:
从 TensorFlow 1.0 起,tf.mul 已更名为 tf.multiply;在新版本上运行使用 tf.mul 的旧代码就会出现上面的错误。
将mul改为multiply。
- 运行报错
saver.save(sess, "model.ckpt")
如果保存模型时提示找不到父目录,请把保存路径改为带 "./" 前缀的相对路径:
saver.save(sess, "./model.ckpt")
训练
import pygame
import random
from pygame.locals import *
import numpy as np
from collections import deque
import tensorflow as tf
import cv2
# Colours as RGB triples.
BLACK, WHITE = (0, 0, 0), (255, 255, 255)

# Sizes in pixels, each stored as [width, height].
SCREEN_SIZE = [320, 400]
BAR_SIZE = [50, 5]
BALL_SIZE = [15, 15]

# One-hot action encodings, one slot per network output.
MOVE_STAY, MOVE_LEFT, MOVE_RIGHT = (
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
)
class Game(object):
    """Minimal paddle-and-ball game rendered with pygame.

    The ball bounces around the window and a horizontal bar sits on the
    bottom edge.  Each call to :meth:`step` moves the bar according to the
    requested action, advances the ball by one tick, redraws the scene and
    reports a reward together with the rendered pixels.
    """

    def __init__(self):
        pygame.init()
        self.clock = pygame.time.Clock()
        self.screen = pygame.display.set_mode(SCREEN_SIZE)
        pygame.display.set_caption('Simple Game')

        # Centre the ball.  Integer arithmetic is used so that no float
        # coordinates are handed to pygame.Rect.
        self.ball_pos_x = (SCREEN_SIZE[0] - BALL_SIZE[0]) // 2
        self.ball_pos_y = (SCREEN_SIZE[1] - BALL_SIZE[1]) // 2
        self.ball_dir_x = -1  # -1 = left, 1 = right
        self.ball_dir_y = -1  # -1 = up, 1 = down
        self.ball_pos = pygame.Rect(
            self.ball_pos_x, self.ball_pos_y, BALL_SIZE[0], BALL_SIZE[1])

        # The bar starts horizontally centred, flush with the bottom edge.
        self.bar_pos_x = SCREEN_SIZE[0] // 2 - BAR_SIZE[0] // 2
        self.bar_pos = pygame.Rect(
            self.bar_pos_x, SCREEN_SIZE[1] - BAR_SIZE[1],
            BAR_SIZE[0], BAR_SIZE[1])

    def step(self, action):
        """Advance the game by one tick.

        Args:
            action: one of MOVE_STAY, MOVE_LEFT or MOVE_RIGHT (a one-hot
                sequence of length 3).  Any other value leaves the bar
                where it is.

        Returns:
            A ``(reward, screen_image)`` tuple.  ``reward`` is 1 when the
            ball is at bar level and horizontally overlaps the bar, -1 when
            it is at bar level without overlapping, and 0 otherwise.
            ``screen_image`` is the array returned by
            ``pygame.surfarray.array3d`` for the display surface.
        """
        # Let pygame service its internal event queue; without this the
        # window is reported as unresponsive (or stays blank) on some
        # platforms because nothing else in the program handles events.
        pygame.event.pump()

        # Normalise to a plain list so tuples and arrays compare correctly
        # against the MOVE_* constants.
        action = list(action)
        if action == MOVE_LEFT:
            self.bar_pos_x = self.bar_pos_x - 2
        elif action == MOVE_RIGHT:
            self.bar_pos_x = self.bar_pos_x + 2

        # Keep the bar inside the window.
        self.bar_pos_x = max(
            0, min(self.bar_pos_x, SCREEN_SIZE[0] - BAR_SIZE[0]))

        self.screen.fill(BLACK)
        self.bar_pos.left = self.bar_pos_x
        pygame.draw.rect(self.screen, WHITE, self.bar_pos)

        # Move the ball: 2 px horizontally and 3 px vertically per tick.
        self.ball_pos.left += self.ball_dir_x * 2
        self.ball_pos.bottom += self.ball_dir_y * 3
        pygame.draw.rect(self.screen, WHITE, self.ball_pos)

        # Bounce off the top edge / bar row and off the side walls.
        if (self.ball_pos.top <= 0
                or self.ball_pos.bottom >= SCREEN_SIZE[1] - BAR_SIZE[1] + 1):
            self.ball_dir_y = self.ball_dir_y * -1
        if self.ball_pos.left <= 0 or self.ball_pos.right >= SCREEN_SIZE[0]:
            self.ball_dir_x = self.ball_dir_x * -1

        reward = 0
        if self.bar_pos.top <= self.ball_pos.bottom:
            # The ball has reached the bar row: it is either a hit or a
            # miss.  Treating every non-overlap as a miss also covers the
            # case where the two edges merely touch, which previously fell
            # through with a reward of 0.
            overlaps = (self.bar_pos.left < self.ball_pos.right
                        and self.bar_pos.right > self.ball_pos.left)
            reward = 1 if overlaps else -1

        # Capture the frame that was just drawn, then flip it to the screen.
        screen_image = pygame.surfarray.array3d(pygame.display.get_surface())
        pygame.display.update()
        return reward, screen_image
# Factor applied to the best predicted value of the next state when the
# training target is built.  Despite the name this is not the optimizer's
# step size.
LEARNING_RATE = 0.99

# Start value and floor of the exploration rate epsilon.
INITIAL_EPSILON = 1.0
FINAL_EPSILON = 0.05

# EXPLORE: number of steps over which epsilon is annealed down to its floor.
# OBSERVE: number of steps collected before any training update is run.
EXPLORE = 500000
OBSERVE = 50000

# Capacity of the replay memory and size of each sampled minibatch.
REPLAY_MEMORY = 500000
BATCH = 100

# Number of output units, one per action:
# MOVE_STAY [1, 0, 0], MOVE_LEFT [0, 1, 0], MOVE_RIGHT [0, 0, 1].
output = 3

# Stack of four preprocessed game frames per sample.
input_image = tf.placeholder(tf.float32, shape=[None, 80, 100, 4])
# One-hot action per sample.
action = tf.placeholder(tf.float32, shape=[None, output])
# Convolutional network definition.
# Reference: http://blog.topspeedsnail.com/archives/10451
def convolutional_neural_network(input_image):
    """Build the value network and return its output tensor.

    Three VALID-padded convolutions (8x8 stride 4, 4x4 stride 2, 3x3
    stride 1) are followed by a 784-unit fully connected layer and a linear
    output layer with ``output`` units.

    Args:
        input_image: float tensor expected to be [batch, 80, 100, 4]; the
            flatten size used below (3456 = 6 * 9 * 64) is only valid for
            that spatial size.

    Returns:
        A [batch, output] tensor with one value per action.
    """
    def weight(shape):
        # Small random values break the symmetry between units.  With
        # all-zero weights every ReLU outputs 0 and every gradient except
        # the one for the output bias is 0, so nothing can be learned.
        return tf.Variable(tf.truncated_normal(shape, stddev=0.01))

    def bias(shape):
        # A small positive bias keeps the ReLU units active at the start.
        return tf.Variable(tf.constant(0.01, shape=shape))

    weights = {
        'w_conv1': weight([8, 8, 4, 32]),
        'w_conv2': weight([4, 4, 32, 64]),
        'w_conv3': weight([3, 3, 64, 64]),
        'w_fc4': weight([3456, 784]),
        'w_out': weight([784, output]),
    }
    biases = {
        'b_conv1': bias([32]),
        'b_conv2': bias([64]),
        'b_conv3': bias([64]),
        'b_fc4': bias([784]),
        'b_out': bias([output]),
    }

    conv1 = tf.nn.relu(
        tf.nn.conv2d(input_image, weights['w_conv1'],
                     strides=[1, 4, 4, 1], padding="VALID")
        + biases['b_conv1'])
    conv2 = tf.nn.relu(
        tf.nn.conv2d(conv1, weights['w_conv2'],
                     strides=[1, 2, 2, 1], padding="VALID")
        + biases['b_conv2'])
    conv3 = tf.nn.relu(
        tf.nn.conv2d(conv2, weights['w_conv3'],
                     strides=[1, 1, 1, 1], padding="VALID")
        + biases['b_conv3'])

    conv3_flat = tf.reshape(conv3, [-1, 3456])
    fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights['w_fc4']) + biases['b_fc4'])
    output_layer = tf.matmul(fc4, weights['w_out']) + biases['b_out']
    return output_layer
# Introduction to deep reinforcement learning:
# https://www.nervanasys.com/demystifying-deep-reinforcement-learning/
def _preprocess_frame(frame):
    """Resize a raw screen capture to an 80x100 array and binarise it."""
    gray = cv2.cvtColor(cv2.resize(frame, (100, 80)), cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    return binary


def train_neural_network(input_image):
    """Train the network by playing the game; runs until interrupted.

    Every step picks an action (random with probability ``epsilon``,
    otherwise the action with the highest predicted value), plays it,
    stores the transition in a bounded replay memory and, once more than
    OBSERVE steps have been collected, runs one optimizer update on a
    random minibatch.  A checkpoint is written every 10000 steps.

    Args:
        input_image: placeholder through which stacks of four preprocessed
            frames are fed to the network.
    """
    predict_action = convolutional_neural_network(input_image)

    argmax = tf.placeholder("float", [None, output])
    gt = tf.placeholder("float", [None])

    # Value predicted for the action that was actually taken.
    action = tf.reduce_sum(tf.multiply(predict_action, argmax), axis=1)
    cost = tf.reduce_mean(tf.square(action - gt))
    optimizer = tf.train.AdamOptimizer(1e-6).minimize(cost)

    game = Game()
    # maxlen makes the deque discard its oldest entry automatically.
    D = deque(maxlen=REPLAY_MEMORY)

    # Seed the frame stack with four copies of the first frame.
    _, image = game.step(MOVE_STAY)
    image = _preprocess_frame(image)
    input_image_data = np.stack((image, image, image, image), axis=2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        n = 0
        epsilon = INITIAL_EPSILON
        while True:
            action_t = predict_action.eval(
                feed_dict={input_image: [input_image_data]})[0]

            argmax_t = np.zeros([output], dtype=int)
            # Compare against the annealed epsilon.  Testing against
            # INITIAL_EPSILON (1.0) would make every action random, so the
            # learned policy would never be used.
            if random.random() <= epsilon:
                maxIndex = random.randrange(output)
            else:
                maxIndex = np.argmax(action_t)
            argmax_t[maxIndex] = 1
            if epsilon > FINAL_EPSILON:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

            # NOTE: on macOS the window stays blank unless the pygame event
            # queue is serviced; handle pygame events here if the window
            # does not render.

            reward, image = game.step(list(argmax_t))
            image = np.reshape(_preprocess_frame(image), (80, 100, 1))
            # Newest frame first, followed by the three most recent ones.
            input_image_data1 = np.append(
                image, input_image_data[:, :, 0:3], axis=2)

            D.append((input_image_data, argmax_t, reward, input_image_data1))

            if n > OBSERVE:
                minibatch = random.sample(D, BATCH)
                input_image_data_batch = [d[0] for d in minibatch]
                argmax_batch = [d[1] for d in minibatch]
                reward_batch = [d[2] for d in minibatch]
                input_image_data1_batch = [d[3] for d in minibatch]

                out_batch = predict_action.eval(
                    feed_dict={input_image: input_image_data1_batch})
                # Target = reward + LEARNING_RATE * best predicted value of
                # the next state.
                gt_batch = [
                    r + LEARNING_RATE * np.max(q)
                    for r, q in zip(reward_batch, out_batch)
                ]

                optimizer.run(feed_dict={
                    gt: gt_batch,
                    argmax: argmax_batch,
                    input_image: input_image_data_batch})

            input_image_data = input_image_data1
            n = n + 1

            if n % 10000 == 0:
                saver.save(sess, './game.cpk', global_step=n)  # checkpoint

            print(n, "epsilon:", epsilon, " ", "action:", maxIndex, " ", "reward:", reward)
# Start training only when the file is executed as a script, so importing
# the module does not launch the endless training loop.
if __name__ == "__main__":
    train_neural_network(input_image)
运行示例

边栏推荐
- new Date converts strings into date formats Compatible with IE, how ie8 converts strings into date formats through new Date, how to replace strings in js, and explain the replace() method in detail
- 2千兆光+6千兆电导轨式网管型工业级以太网交换机支持X-Ring冗余环网一键环网交换机
- Exclude_reserved_words 排除关键字
- 单片机C语言->的用法,和意思
- Mini program + new retail, play the new way of playing in the industry!
- How to read the resources files in the directory path?
- tkmapper的crud示例:
- 复制带随机指针的链表
- y86.第四章 Prometheus大厂监控体系及实战 -- prometheus存储(十七)
- ingress 待完善
猜你喜欢

MCU C language -> usage, and meaning

Why use Selenium for automated testing

STM8S105K4T6------Serial port sending and receiving

Development of Taurus. MVC WebAPI introductory tutorial 1: download environment configuration and operation framework (including series directory).

new Date converts strings into date formats Compatible with IE, how ie8 converts strings into date formats through new Date, how to replace strings in js, and explain the replace() method in detail

Polygon zkEVM网络节点

MySQL查询优化与调优

怎样提高网络数据安全性

STM8S105k4t6c--------------点亮LED

逻辑漏洞----其他类型
随机推荐
esp8266-01s刷固件步骤
MySQL 查询练习(1)
Good bosses, please ask the flink CDC oracle to Doris, found that the CPU is unusual, a run down
Y86. Chapter iv Prometheus giant monitoring system and the actual combat, Prometheus storage (17)
new Date converts strings into date formats Compatible with IE, how ie8 converts strings into date formats through new Date, how to replace strings in js, and explain the replace() method in detail
基于Qt的目录统计QDirStat
基地址:环境变量
uni-app 从零开始-基础模版(一)
怎样提高网络数据安全性
STM8S项目创建(STVD创建)---使用 COSMIC 创建 C 语言项目
出海季,互联网出海锦囊之本地化
从图文展示到以云为核,第五代验证码独有的策略情报能力
2022支付宝C2C现金红包PHP源码DEMO/兼容苹果/安卓浏览器和扫码形式
sql注入一般流程(附例题)
【学习笔记之菜Dog学C】动态内存管理
自定义通用分页标签02
仿牛客论坛项目梳理
KingbaseES数据库启动失败,报“内存段超过可用内存”
Homemade bluetooth mobile app to control stm8/stm32/C51 onboard LED
Innovation and Integration | Huaqiu Empowerment Helps OpenHarmony Ecological Hardware Development and Landing