[Code analysis (7)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 13:50:00 【Silent city of the sky】
baseline_main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from utils import get_dataset
from options import args_parser
from update import test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
if __name__ == '__main__':
    args = args_parser()
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load datasets (the third return value is unused here, hence the _)
    train_dataset, test_dataset, _ = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
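        # Illustration: for CIFAR-10, img_size is torch.Size([3, 32, 32]),
        # so the flattened input dimension is len_in = 3 * 32 * 32 = 3072;
        # for MNIST it would be 1 * 28 * 28 = 784.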
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Training
    # Set optimizer and criterion
    '''
    The optimizer is set up in exactly the same way as in the
    update_weights method of the LocalUpdate class in update.py.
    '''
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=0.5)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)

    # batch_size=64, so for CIFAR-10 each batch has shape
    # torch.Size([64, 3, 32, 32])
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    criterion = torch.nn.NLLLoss().to(device)
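    # Note: NLLLoss expects log-probabilities as input, so the models in
    # models.py are assumed to end with F.log_softmax. The pairing
    #     loss = F.nll_loss(F.log_softmax(logits, dim=1), labels)
    # is equivalent to CrossEntropyLoss applied to raw logits.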
    epoch_loss = []
    '''
    The for loop below is the same as the training loop in the
    update_weights method of LocalUpdate in update.py. The only
    difference is that update.py iterates over args.local_ep (local
    epochs), while here it iterates over args.epochs.
    '''
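    # For comparison, a comment-only sketch of the corresponding loop in
    # LocalUpdate.update_weights (update.py), reconstructed from the notes
    # above rather than copied verbatim from the repository:
    #
    #   def update_weights(self, model, global_round):
    #       model.train()
    #       epoch_loss = []
    #       optimizer = ...  # SGD or Adam, chosen exactly as above
    #       for _ in range(self.args.local_ep):  # local_ep, not args.epochs
    #           batch_loss = []
    #           for images, labels in self.trainloader:
    #               images, labels = images.to(self.device), labels.to(self.device)
    #               model.zero_grad()  # instead of optimizer.zero_grad()
    #               loss = self.criterion(model(images), labels)
    #               loss.backward()
    #               optimizer.step()
    #               batch_loss.append(loss.item())
    #           epoch_loss.append(sum(batch_loss) / len(batch_loss))
    #       return model.state_dict(), sum(epoch_loss) / len(epoch_loss)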
    # print(type(trainloader))
    # <class 'torch.utils.data.dataloader.DataLoader'>
    # print(trainloader.dataset)
    '''
    Dataset CIFAR10
        Number of datapoints: 50000
        Root location: ../data/cifar/
        Split: Train
        StandardTransform
    Transform: Compose(
                   ToTensor()
                   Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
               )
    '''
    for epoch in tqdm(range(args.epochs)):
        batch_loss = []
        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            '''
            enumerate(trainloader) yields the training data one batch at a
            time as (batch_idx, (images, labels)). Iterating over the
            dataloader with enumerate is the standard way to feed data to
            a network during training; the basic form is
                for index, item in enumerate(dataloader)
            where item[0] is the batch of data and item[1] the labels.
            '''
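            # Plain-Python illustration of what enumerate yields
            # (not DataLoader-specific):
            #   >>> list(enumerate(['a', 'b', 'c']))
            #   [(0, 'a'), (1, 'b'), (2, 'c')]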
            # update.py uses model.zero_grad() here instead
            optimizer.zero_grad()
            '''
            First zero the gradients (optimizer.zero_grad()),
            '''
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            '''
            then compute each parameter's gradient by backpropagation
            (loss.backward()),
            '''
            optimizer.step()
            '''
            and finally take one gradient-descent step on the parameters
            (optimizer.step()).
            '''
            # print(batch_idx)               # 0, 1, 2, ...
            # print(len(images))             # 64
            # print(type(images))            # <class 'torch.Tensor'>
            # print(enumerate(trainloader))  # <enumerate object at 0x00000293CC019F98>
            # print(images.shape)
            '''
            torch.Size([64, 3, 32, 32])
            tensor([[[[ 0.6235,  0.7098,  0.8196,  ...,  0.2392,  0.4039, -0.2863],
                      ...,
                      [ 0.5137,  0.7255,  0.7333,  ...,  0.6706,  0.6627,  0.6314]],
                     ...,
                     [[ 0.9922,  0.9765,  0.9608,  ...,  0.8353,  0.8353,  0.8588],
                      ...,
                      [ 0.7020,  0.7490,  0.7569,  ...,  0.5059,  0.5451,  0.6471]]],
                    ...])
            '''
            # print(labels)
            # print(labels.shape)
            '''
            tensor([1, 9, 1, 8, 7, 4, 4, 0, 7, 6, 1, 3, 8, 7, 6, 3, 4, 0, 5, 2, 2, 8, 5, 4,
                    2, 0, 8, 8, 3, 9, 5, 7, 3, 1, 7, 1, 9, 2, 4, 1, 2, 9, 8, 9, 8, 1, 4, 7,
                    8, 3, 5, 5, 9, 6, 0, 7, 6, 6, 1, 3, 5, 7, 3, 5])
            torch.Size([64])
            '''
            # print(len(trainloader.dataset))  # 50000
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch+1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
                '''
                batch_idx * len(images) = batch_idx * 64 is the number of
                samples processed so far in this epoch. With 50000 training
                images and batch_size=64 there are ceil(50000/64) = 782
                batches, so batch_idx runs from 0 to 781 and the progress
                line above is printed at batch_idx = 0, 50, 100, ..., 750.
                '''
            batch_loss.append(loss.item())
            # print(batch_idx)  # 781 on the last batch

        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        # append once per epoch
        epoch_loss.append(loss_avg)

    # Plot loss
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
                                                 args.epochs))

    # testing
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print("Test Accuracy: {:.2f}%".format(100*test_acc))
Copyright notice
This article was written by [Silent city of the sky]. Please include a link to the original when reposting.
https://yzsam.com/2022/04/202204230556365641.html