当前位置:网站首页>[Code analysis (7)] Communication-Efficient Learning of Deep Networks from Decentralized Data
[Code analysis (7)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 13:50:00 【Silent city of the sky】
baseline_main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from utils import get_dataset
from options import args_parser
from update import test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
# Let Intel OpenMP tolerate a second copy of its runtime being loaded.
# Presumably this works around the "libiomp5 already initialized" abort that can
# occur when torch and another library each bundle OpenMP (TODO confirm).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

if __name__ == '__main__':
    # Baseline experiment: train a single model directly on the full training
    # set (no federated averaging in this script), then evaluate it once on
    # the test set.
    args = args_parser()

    # NOTE(review): if --gpu were parsed as the int 0 it would be falsy and the
    # CPU would be used silently; confirm how options.py parses this flag.
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # Load datasets. The third return value is unused here (presumably the
    # per-user split consumed by the federated script).
    train_dataset, test_dataset, _ = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network, architecture chosen per dataset.
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Without this branch global_model stays unbound and the script
            # crashes further down with a confusing NameError.
            raise SystemExit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron. Its input width is the number of elements
        # in one sample, i.e. the product of the first sample's dimensions.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        raise SystemExit('Error: unrecognized model')

    # Set the model to train mode and send it to the device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Optimizer settings intentionally mirror LocalUpdate.update_weights in
    # update.py, so the baseline is comparable with the federated clients.
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=0.5)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)
    else:
        # Without this branch `optimizer` would be unbound inside the loop.
        raise SystemExit('Error: unrecognized optimizer')

    # Fixed batch size of 64. For CIFAR-10 an image batch has shape
    # torch.Size([64, 3, 32, 32]), and the 50000 training images give 782
    # batches per epoch (the last one partial).
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # NLLLoss expects log-probabilities; assumes the models end with a
    # log_softmax layer (TODO confirm in models.py).
    criterion = torch.nn.NLLLoss().to(device)
    epoch_loss = []

    # Same loop as LocalUpdate.update_weights in update.py, except that it runs
    # for args.epochs instead of args.local_ep.
    for epoch in tqdm(range(args.epochs)):
        batch_loss = []
        # Each iteration yields one shuffled mini-batch: (images, labels).
        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            # One gradient-descent step: clear stale gradients, run the forward
            # pass, backpropagate, then update the parameters.
            # (update.py calls model.zero_grad() here instead.)
            optimizer.zero_grad()
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if batch_idx % 50 == 0:
                # Progress is shown as samples processed so far in this epoch.
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch+1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
            batch_loss.append(loss.item())

        # Mean per-batch loss of this epoch; one entry is recorded per epoch.
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        epoch_loss.append(loss_avg)

    # Plot the per-epoch training loss.
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    # savefig does not create missing directories, so ensure the target exists.
    os.makedirs('../save', exist_ok=True)
    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
                                                  args.epochs))

    # Testing. test_inference returns (accuracy, loss). The accuracy appears to
    # be a fraction, since it is scaled by 100 for display. The loss is unused.
    test_acc, _ = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print("Test Accuracy: {:.2f}%".format(100*test_acc))
版权声明
本文为[Silent city of the sky]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/04/202204230556365641.html
边栏推荐
- Detailed explanation of redis (Basic + data type + transaction + persistence + publish and subscribe + master-slave replication + sentinel + cache penetration, breakdown and avalanche)
- Building MySQL environment under Ubuntu & getting to know SQL
- Detailed explanation and usage of with function in SQL
- Oracle and MySQL batch query all table names and table name comments under users
- Oracle creates tablespaces and modifies user default tablespaces
- Common types and basic usage of input plug-in of logstash data processing service
- RAC environment error reporting ora-00239: timeout waiting for control file enqueue troubleshooting
- 19c environment ora-01035 login error handling
- Failure to connect due to improper parameter setting of Rac environment database node. Troubleshooting
- Utilisation de GDB
猜你喜欢
![[machine learning] Note 4. KNN + cross validation](/img/a1/5afccedf509eda92a0fe5bf9b6cbe9.png)
[machine learning] Note 4. KNN + cross validation

Express②(路由)

Special window functions: rank, dense_rank, row_number

自动化的艺术

Using Baidu Intelligent Cloud face detection interface to achieve photo quality detection

LeetCode | 38. Count and Say

零拷贝技术

Search ideas and cases of large amount of Oracle redo log

SQL learning window function

聯想拯救者Y9000X 2020
随机推荐
美联储数字货币最新进展
Dynamic subset division problem
Es introduction learning notes
Express②(路由)
第十五章 软件工程新技术
Detailed explanation and usage of with function in SQL
Kettle--控件解析
Django::Did you install mysqlclient?
初探 Lambda Powertools TypeScript
leetcode--977. Squares of a Sorted Array
Analysis of unused index columns caused by implicit conversion of timestamp
Oracle modify default temporary tablespace
Building MySQL environment under Ubuntu & getting to know SQL
Get the attribute value difference between two different objects with reflection and annotation
YARN线上动态资源调优
Oracle view related
Double pointer instrument panel reading (I)
Oracle RAC database instance startup exception analysis IPC send timeout
Android 面试主题集合整理
useReducer basic usage