当前位置:网站首页>pytorch实现GAN入门案例
pytorch实现GAN入门案例
2022-08-09 05:18:00 【王大队长】
最近在知乎上看到一个不错的GAN的入门案例,于是稍微修改了一下后分享出来!
我们都知道GAN主要用来生成,相比于生成图片,我们这次选择更为简单的生成一个一维函数来大致了解GAN的流程及代码实现。
我们的原始数据为y = 2x^2 + 1,我们让GAN来生成与之接近的分布!

代码:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(1)
np.random.seed(1)
# 学习率
LR_G = 0.0001
LR_D = 0.0001
BATCH_SIZE = 64
N_IDEAS = 5 # 输入的噪声维度,可以自己设定(经过神经网络后会把维度调整)
ART_COMPONETS = 15 # 噪声输入后的输出维度
PAINT_POINTS = np.stack([np.linspace(-1,1,ART_COMPONETS) for _ in range(BATCH_SIZE)],0) # 我们原始数据的x坐标,-1~1均匀分布
def artist_work():
a = np.ones((BATCH_SIZE,1)) * 2
paints = a * np.power(PAINT_POINTS,2) + (a-1) # y = 2x^2 + 1
paints = torch.from_numpy(paints).float()
return paints
# 网络结构
G = nn.Sequential(
nn.Linear(N_IDEAS,128),
nn.ReLU(),
nn.Linear(128,ART_COMPONETS)
)
D = nn.Sequential(
nn.Linear(ART_COMPONETS,128),
nn.ReLU(),
nn.Linear(128,1),
nn.Sigmoid()
)
#优化器与损失函数
optimizer_G = torch.optim.Adam(G.parameters(),lr=LR_G)
optimizer_D = torch.optim.Adam(D.parameters(),lr=LR_D)
Criterion = torch.nn.BCELoss()
# 开始训练
plt.ion()
G_losses = [] #储存了损失方便自己画图可视化
D_losses = []
for step in range(10000):
artist_painting = artist_work()
G_idea = torch.randn(BATCH_SIZE,N_IDEAS)
G_paintings = G(G_idea)
pro_atrist0 = D(artist_painting)
pro_atrist1 = D(G_paintings)
G_loss = -1/torch.mean(torch.log(1.-pro_atrist1))
G_losses.append(G_loss.item())
D_loss = Criterion(pro_atrist0, torch.ones_like( pro_atrist0))+Criterion(pro_atrist1, torch.zeros_like(pro_atrist1))
D_losses.append(D_loss.item())
optimizer_G.zero_grad()
G_loss.backward(retain_graph=True) #因为D的反向传播需要用到G,所以设置为True
optimizer_D.zero_grad()
D_loss.backward( )
optimizer_G.step()
optimizer_D.step()
if step % 200 == 0: # plotting
plt.cla()
plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='original data')
plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % pro_atrist0.data.numpy().mean(), fontdict={'size': 13})
# plt.text(-.5, 2, 'G_loss= %.2f ' % G_loss.data.numpy(), fontdict={'size': 13})
plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)
print('训练结束')
plt.ioff()
plt.show()结果:
可以看到在网络结构很简单的情况下还是可以取得一个不错的结果!

参考资料:
边栏推荐
- [Developers must see] [push kit] Collection of typical problems of push service service 2
- Eureka-Server------单节和集群的搭建
- 神经网络预测应力应变-单轴实验
- STM32的Cube学习笔记(ADC)
- 如何让Win11两个屏幕任务栏都显示时间?
- [UNR #6 A] Noodle-based road (shortest path)
- 匿名共享内存 ashmem
- Lock wait timeout exceeded; try restarting transaction 更新数据量范围太大,导致锁表惨案
- 【LeetCode】287. 寻找重复数
- Openresty执行lua脚本
猜你喜欢

明明加了唯一索引,为什么还是产生重复数据?

详谈归并排序时间复杂度过程推导----软考

【ManageEngine】网络性能监控工具
![[Harmony OS] [ArkUI] ets development graphics and animation drawing](/img/36/f4c91f794b1321f11a24505d1617fb.png)
[Harmony OS] [ArkUI] ets development graphics and animation drawing

Quantitative Genetics Heritability Calculation 2: Half Siblings and Full Siblings

Docker部署MySQL

C Advanced-C Language File Operation

STM32系列单片机使用心得

aur安装报错一个或多个文件没有通过有效性检查!

力扣202-快乐数——哈希集合
随机推荐
Transaction rolled back because it has been marked as rollback-only
HAL库的使用之Cube配置编码器输入捕获模式
面向6G的欠采样相移键控可见光调制方案
绕过反调试fuck-debugger
unity urp 实现遮挡显示角色轮廓
Openresty执行lua脚本
Storage System Architecture Evolution
Nacos源码安装
【Harmony OS】【ARK UI】Custom popup
STM32系列单片机使用心得
3.3V控制输出5V的方法
feof它可不简单。
滑动窗口篇
matlab simulink 温度控制时延系统 模糊pid和smith控制
The development trend of software testing
初识二叉树
【LeetCode】761.特殊的二进制序列
2022牛客多校联赛第七场 题解
dsafasfdasfasf
使用Redis zset做消息队列