[Code analysis (7)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 13:50:00 【Silent city of the sky】
baseline_main.py
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from utils import get_dataset
from options import args_parser
from update import test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

if __name__ == '__main__':
    args = args_parser()
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load datasets
    train_dataset, test_dataset, _ = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Training
    # Set optimizer and criterion
    '''
    The optimizer is set up the same way as in
    LocalUpdate.update_weights() in update.py.
    '''
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=0.5)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)

    # batch_size is fixed at 64 here,
    # so a CIFAR-10 image batch has shape torch.Size([64, 3, 32, 32])
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    criterion = torch.nn.NLLLoss().to(device)
    epoch_loss = []
    '''
    The for loop below is the same as the one in
    LocalUpdate.update_weights() in update.py.
    The only difference: update.py loops over args.local_ep,
    while this script loops over args.epochs.
    '''
    # print('000000000000000000')
    # print(type(trainloader))
    # <class 'torch.utils.data.dataloader.DataLoader'>
    # print(trainloader.dataset)
    '''
    Dataset CIFAR10
        Number of datapoints: 50000
        Root location: ../data/cifar/
        Split: Train
        StandardTransform
    Transform: Compose(
                   ToTensor()
                   Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
               )
    '''
    for epoch in tqdm(range(args.epochs)):
        batch_loss = []

        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            '''
            batch_idx: index of the current batch
            (images, labels): the data and labels of that batch
            enumerate(trainloader) takes the training set out one batch
            at a time for training. Iterating over a DataLoader with
            enumerate is the usual way to read data when training a
            neural network; the basic form is
            for index, item in enumerate(dataloader['train']),
            where item[0] is the data and item[1] is the label.
            '''
            # update.py calls model.zero_grad() at this point instead
            optimizer.zero_grad()
            '''
            First, reset the gradients to zero (optimizer.zero_grad()).
            '''
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            '''
            Backpropagation then computes the gradient of every parameter (loss.backward()).
            '''
            optimizer.step()
            '''
            Finally, one gradient-descent step updates the parameters (optimizer.step()).
            '''
            # print('000000000000000')
            # print(batch_idx)  # 0,1,2,.....
            # print(len(images))  # 64
            # print(type(images))  # <class 'torch.Tensor'>
            # print(enumerate(trainloader))  # <enumerate object at 0x00000293CC019F98>
            # print(images.shape)
            '''
            torch.Size([64, 3, 32, 32])
            tensor([[[[ 0.6235,  0.7098,  0.8196,  ...,  0.2392,  0.4039, -0.2863],
                      ...
                      [ 0.5137,  0.7255,  0.7333,  ...,  0.6706,  0.6627,  0.6314]],
                     ...
                     [[ 0.9922,  0.9765,  0.9608,  ...,  0.8353,  0.8353,  0.8588],
                      ...
                      [ 0.7020,  0.7490,  0.7569,  ...,  0.5059,  0.5451,  0.6471]],
                    ]])
            '''
            # print(labels)
            # print(labels.shape)
            '''
            tensor([1, 9, 1, 8, 7, 4, 4, 0, 7, 6, 1, 3, 8, 7, 6, 3, 4, 0, 5, 2, 2, 8, 5, 4,
                    2, 0, 8, 8, 3, 9, 5, 7, 3, 1, 7, 1, 9, 2, 4, 1, 2, 9, 8, 9, 8, 1, 4, 7,
                    8, 3, 5, 5, 9, 6, 0, 7, 6, 6, 1, 3, 5, 7, 3, 5])
            torch.Size([64])
            '''
            # print(len(trainloader.dataset))  # 50000
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch+1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
            '''
            A progress line is printed every 50 batches, i.e. for
            batch_idx = 0, 50, 100, ..., 700, 750.
            The sample count shown is batch_idx * len(images),
            which equals batch_idx * 64 for a full batch.
            '''
            batch_loss.append(loss.item())
        # print(batch_idx)  # 781 (index of the last batch)
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        # appended once per epoch
        epoch_loss.append(loss_avg)

    # Plot loss
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
                                                 args.epochs))

    # testing
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print("Test Accuracy: {:.2f}%".format(100*test_acc))
```
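In the MLP branch of baseline_main.py, `len_in` is the product of the dimensions of a single input image, which is the length of the flattened vector the MLP receives. The sketch below reproduces only that arithmetic with the standard library. The per-image shapes (channels, height, width) for CIFAR-10 and MNIST are assumed inputs, and `flattened_input_size` is a hypothetical helper name, not a function from the repository.

```python
import math


def flattened_input_size(img_size):
    """Number of input features an MLP needs for one image.

    Mirrors the loop in baseline_main.py:
        len_in = 1
        for x in img_size:
            len_in *= x
    """
    len_in = 1
    for x in img_size:
        len_in *= x
    return len_in


# assumed (channels, height, width) of one image tensor after ToTensor()
cifar_shape = (3, 32, 32)
mnist_shape = (1, 28, 28)

print(flattened_input_size(cifar_shape))  # 3072
print(flattened_input_size(mnist_shape))  # 784
# math.prod (Python 3.8+, newer than the script's 3.6) gives the same product
print(math.prod(cifar_shape))             # 3072
```

So for CIFAR-10 the MLP would be built as `MLP(dim_in=3072, dim_hidden=64, dim_out=args.num_classes)`.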
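The script uses `torch.nn.NLLLoss` as its criterion. NLLLoss does not apply a softmax itself. For each sample it takes the model output at the target class and negates it, and with the default `reduction='mean'` it averages the result over the batch. It therefore expects log-probabilities, such as those produced by a final `F.log_softmax` layer, and in that case it equals cross-entropy on the raw scores. Whether every model in models.py ends in `log_softmax` is not shown in this post, so treat that as an assumption. The stdlib-only sketch below uses hypothetical helpers (`log_softmax`, `nll_loss`, `cross_entropy`) that imitate the math rather than call PyTorch.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of one row of raw class scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood over a batch (like NLLLoss, reduction='mean')."""
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)


def cross_entropy(scores, targets):
    """Cross-entropy from raw scores: mean of -log(softmax(row)[target])."""
    total = 0.0
    for row, t in zip(scores, targets):
        exps = [math.exp(s) for s in row]
        total += -math.log(exps[t] / sum(exps))
    return total / len(scores)


# made-up scores: two samples, three classes
scores = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]

a = nll_loss([log_softmax(r) for r in scores], targets)
b = cross_entropy(scores, targets)
print(math.isclose(a, b))  # True
```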
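The progress-line arithmetic can be checked without a GPU. With the 50000-image CIFAR-10 training set and `batch_size=64`, the DataLoader yields ceil(50000 / 64) = 782 batches per epoch. That is why the last `batch_idx` is 781, as in the commented-out print. The last batch holds only 50000 - 781 * 64 = 16 images, because `DataLoader` keeps a partial batch unless `drop_last=True`. The sketch below regenerates the lines printed every 50 batches. `progress_lines` is a hypothetical helper, and the loss field is omitted because nothing is trained.

```python
import math


def progress_lines(num_samples, batch_size, epoch=1, every=50):
    """Yield the 'Train Epoch' prefixes the training loop prints.

    Assumes DataLoader(drop_last=False), so a final partial batch is kept.
    The loss value is left out because no model is trained here.
    """
    num_batches = math.ceil(num_samples / batch_size)  # len(trainloader)
    for batch_idx in range(num_batches):
        # len(images): a full batch, except possibly the last one
        n = min(batch_size, num_samples - batch_idx * batch_size)
        if batch_idx % every == 0:
            yield 'Train Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * n, num_samples,
                100. * batch_idx / num_batches)


lines = list(progress_lines(50000, 64))
print(len(lines))   # 16 (batch_idx = 0, 50, ..., 750)
print(lines[0])     # Train Epoch: 1 [0/50000 (0%)]
print(lines[-1])    # Train Epoch: 1 [48000/50000 (96%)]
```

Because batch 781 is not a multiple of 50, the partial batch is never printed. For that reason every printed count equals `batch_idx * 64`.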
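`loss_avg = sum(batch_loss)/len(batch_loss)` is a plain mean of the per-batch losses. Each `loss.item()` is already a mean over its own batch, so the 16-image final batch counts as much as a full 64-image batch. As a result, `loss_avg` can differ slightly from the mean loss over all samples. The sketch below compares the two. The batch losses and sizes are made-up numbers, and both function names are hypothetical.

```python
def epoch_loss_unweighted(batch_losses):
    """What baseline_main.py computes: sum(batch_loss) / len(batch_loss)."""
    return sum(batch_losses) / len(batch_losses)


def epoch_loss_per_sample(batch_losses, batch_sizes):
    """Mean over samples: weight each batch-mean loss by its batch size."""
    total = sum(l * n for l, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# made-up losses for three batches; the last batch is partial
losses = [1.0, 1.0, 2.0]
sizes = [64, 64, 16]
print(epoch_loss_unweighted(losses))         # 4/3, about 1.3333
print(epoch_loss_per_sample(losses, sizes))  # 160/144, about 1.1111
```

With 782 batches per epoch, only one batch is partial, so in practice the gap is small. The unweighted form is simply what the script reports.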
Copyright notice
This article was written by [Silent city of the sky]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230556365641.html