[code analysis (7)] communication efficient learning of deep networks from decentralized data
2022-04-23 13:50:00 【Silent city of the sky】
baseline_main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from utils import get_dataset
from options import args_parser
from update import test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
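# (Works around the "OMP: Error #15 ... libiomp5 already initialized" crash
# that can occur when PyTorch and matplotlib each load their own OpenMP runtime)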
if __name__ == '__main__':
args = args_parser()
    if args.gpu:
        # int(...) guards against args.gpu being parsed as a string by argparse
        torch.cuda.set_device(int(args.gpu))
    device = 'cuda' if args.gpu else 'cpu'
# load datasets
train_dataset, test_dataset, _ = get_dataset(args)
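    # get_dataset() returns (train_dataset, test_dataset, user_groups); the third
    # value (presumably the user-to-sample grouping used by the federated scripts)
    # is discarded here, since this baseline trains centrally on the full dataset.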
# BUILD MODEL
if args.model == 'cnn':
        # Convolutional neural network
if args.dataset == 'mnist':
global_model = CNNMnist(args=args)
elif args.dataset == 'fmnist':
global_model = CNNFashion_Mnist(args=args)
elif args.dataset == 'cifar':
global_model = CNNCifar(args=args)
elif args.model == 'mlp':
        # Multi-layer perceptron
img_size = train_dataset[0][0].shape
len_in = 1
for x in img_size:
len_in *= x
global_model = MLP(dim_in=len_in, dim_hidden=64,
dim_out=args.num_classes)
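        # Example: for CIFAR-10 each sample has shape torch.Size([3, 32, 32]),
        # so len_in = 3 * 32 * 32 = 3072 input units for the MLP.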
else:
exit('Error: unrecognized model')
# Set the model to train and send it to device.
global_model.to(device)
global_model.train()
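    # Note: .train() only switches layers such as dropout and batch-norm into
    # training mode; the actual optimization happens in the loop below.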
print(global_model)
# Training
# Set optimizer and criterion
    '''
    The optimizer setup below mirrors the one inside the update_weights()
    method of the LocalUpdate class in update.py.
    '''
if args.optimizer == 'sgd':
optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
momentum=0.5)
elif args.optimizer == 'adam':
optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
weight_decay=1e-4)
    # batch_size is hard-coded to 64, so each CIFAR-10 batch of images
    # has shape torch.Size([64, 3, 32, 32])
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
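    # With the 50,000 CIFAR-10 training images shown below, batch_size=64 gives
    # ceil(50000 / 64) = 782 batches per epoch (batch_idx runs from 0 to 781).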
criterion = torch.nn.NLLLoss().to(device)
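    # NLLLoss expects log-probabilities, so the models are expected to end their
    # forward pass with F.log_softmax (as the models in models.py presumably do);
    # this combination is equivalent to CrossEntropyLoss on raw logits.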
epoch_loss = []
    '''
    The for loop below is the same as the training loop in the update_weights()
    method of the LocalUpdate class in update.py. The only difference is that
    update.py iterates over args.local_ep, whereas this script uses args.epochs.
    (A sketch of update_weights() is given after this listing.)
    '''
# print('000000000000000000')
# print(type(trainloader))
# <class 'torch.utils.data.dataloader.DataLoader'>
# print(trainloader.dataset)
'''
Dataset CIFAR10
Number of datapoints: 50000
Root location: ../data/cifar/
Split: Train
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
)
'''
for epoch in tqdm(range(args.epochs)):
batch_loss = []
for batch_idx, (images, labels) in enumerate(trainloader):
images, labels = images.to(device), labels.to(device)
            '''
            enumerate(trainloader) yields the training data one batch at a
            time as (batch_idx, (images, labels)). Using enumerate over a
            DataLoader is the standard way to read data when training a neural
            network; the basic form is
            for index, item in enumerate(dataloader['train']),
            where item[0] is the data and item[1] is the label.
            '''
            # update.py calls model.zero_grad() here instead
            optimizer.zero_grad()
            '''
            Step 1: zero out the gradients (optimizer.zero_grad())
            '''
outputs = global_model(images)
loss = criterion(outputs, labels)
loss.backward()
            '''
            Step 2: backpropagation computes the gradient of every parameter
            (loss.backward())
            '''
optimizer.step()
            '''
            Step 3: finally, one gradient-descent step updates the parameters
            (optimizer.step())
            '''
# print('000000000000000')
# print(batch_idx) # 0,1,2,.....
# print(len(images)) # 64
# print(type(images)) # <class 'torch.Tensor'>
# print(enumerate(trainloader)) # <enumerate object at 0x00000293CC019F98>
# print(images.shape)
'''
torch.Size([64, 3, 32, 32])
tensor([[ [[ 0.6235, 0.7098, 0.8196, ..., 0.2392, 0.4039, -0.2863],
...
[ 0.5137, 0.7255, 0.7333, ..., 0.6706, 0.6627, 0.6314]],
...
[[ 0.9922, 0.9765, 0.9608, ..., 0.8353, 0.8353, 0.8588],
...
[ 0.7020, 0.7490, 0.7569, ..., 0.5059, 0.5451, 0.6471]],
]])
'''
# print(labels)
# print(labels.shape)
'''
tensor([1, 9, 1, 8, 7, 4, 4, 0, 7, 6, 1, 3, 8, 7, 6, 3, 4, 0, 5, 2, 2, 8, 5, 4,
2, 0, 8, 8, 3, 9, 5, 7, 3, 1, 7, 1, 9, 2, 4, 1, 2, 9, 8, 9, 8, 1, 4, 7,
8, 3, 5, 5, 9, 6, 0, 7, 6, 6, 1, 3, 5, 7, 3, 5])
torch.Size([64])
'''
# print(len(trainloader.dataset)) # 50000
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch+1, batch_idx * len(images), len(trainloader.dataset),
100. * batch_idx / len(trainloader), loss.item()))
            '''
            batch_idx * len(images) = batch_idx * 64 is the number of samples
            processed so far in this epoch. Since logging happens every 50
            batches, batch_idx takes the values 0, 50, 100, 150, ..., 750.
            '''
batch_loss.append(loss.item())
# print(batch_idx) # 781
loss_avg = sum(batch_loss)/len(batch_loss)
print('\nTrain loss:', loss_avg)
        # append the average loss once per epoch
epoch_loss.append(loss_avg)
# Plot loss
plt.figure()
plt.plot(range(len(epoch_loss)), epoch_loss)
plt.xlabel('epochs')
plt.ylabel('Train loss')
plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
args.epochs))
# testing
test_acc, test_loss = test_inference(args, global_model, test_dataset)
print('Test on', len(test_dataset), 'samples')
print("Test Accuracy: {:.2f}%".format(100*test_acc))
Copyright notice
This article was written by [Silent city of the sky]. When reposting, please include a link to the original. Thanks.
https://yzsam.com/2022/04/202204230556365641.html