[Code analysis (2)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 13:47:00 【Silent city of the sky】
utils.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    '''
    args: the parsed options; this function reads args.dataset, args.iid,
    args.unequal and args.num_users
    Returns the training set and the test set.
    Also returns the user groups:
    dict type {key: value}
    key: the user's index
    value: the data assigned to that user
    '''
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        '''
        Compose is typically used to chain several transform steps together.
        transforms.ToTensor():
        a tensor is laid out as CHW, a numpy image as HWC.
        Converts a PIL image or numpy array (H*W*C) in the range [0, 255]
        to a torch.Tensor (C*H*W) in the range [0.0, 1.0],
        i.e. it rescales pixel values from 0-255 to 0-1.
        transforms.Normalize(mean, std) normalizes a tensor image with a mean
        and a standard deviation. For each channel it computes:
        input[channel] = (input[channel] - mean[channel]) / std[channel]
        Here mean and std are both (0.5, 0.5, 0.5), so the minimum 0 becomes
        (0 - 0.5) / 0.5 = -1 and the maximum 1 becomes (1 - 0.5) / 0.5 = 1:
        the range [0, 1] is mapped to [-1, 1].
        Normalize() only shifts and rescales each channel; it does not turn
        the data into a standard normal distribution. Centering the inputs
        around zero generally makes the model easier to train.
        '''
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)
        '''
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)
        1. root: the directory the CIFAR-10 data is loaded from
        2. train: True loads the training set, False loads the test set
        3. download: whether to download the CIFAR dataset automatically
        4. transform: the preprocessing applied to the data; None means no preprocessing
        '''
        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Cifar
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Cifar
            if args.unequal:
                # Choose unequal splits for every user
                '''
                Each user would receive a different amount of data.
                For the CIFAR dataset, the unequal non-IID split is not implemented.
                '''
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)
                '''
                The data is non-IID across users, but every user receives the same amount.
                '''
    elif args.dataset in ('mnist', 'fmnist'):
        # Membership test: matches exactly 'mnist' or 'fmnist'
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        # Note: both 'mnist' and 'fmnist' load datasets.MNIST here; only data_dir
        # differs. torchvision also provides datasets.FashionMNIST.
        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)
        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
                '''
                The data is non-IID across users, and each user receives a different amount.
                '''
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)
    else:
        # Fail early instead of reaching the return with undefined variables
        raise ValueError(f'Unsupported dataset: {args.dataset}')
    return train_dataset, test_dataset, user_groups
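To make the shape of the user_groups return value concrete, here is a minimal sketch of an IID split that uses only the standard library. The name iid_partition and its seed parameter are hypothetical and are not taken from sampling.py; the real cifar_iid and mnist_iid may well be implemented differently.

```python
import random


def iid_partition(num_items, num_users, seed=0):
    """Hypothetical helper: split sample indices evenly and at random.

    Returns a dict {user index: set of sample indices}, which is the
    shape get_dataset describes for user_groups.
    """
    rng = random.Random(seed)
    indices = list(range(num_items))
    rng.shuffle(indices)
    per_user = num_items // num_users  # any leftover samples are dropped
    return {
        user: set(indices[user * per_user:(user + 1) * per_user])
        for user in range(num_users)
    }


user_groups = iid_partition(num_items=100, num_users=10)
print(len(user_groups), len(user_groups[0]))
```

Every user receives the same number of samples and no sample is shared between users, which is what "IID with equal splits" amounts to in this sketch.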
def average_weights(w):
    """
    Returns the average of the weights.
    w should be a list of dicts (the local_weights list), one per user,
    each mapping parameter names to tensors.
    """
    w_avg = copy.deepcopy(w[0])
    '''
    deepcopy returns an independent copy, so accumulating into w_avg
    does not modify w[0]. For example, with a tensor:
    test = torch.randn(4, 4)
    print(test)
    tensor([[ 1.8693, -0.3236, -0.3465, 0.9226],
    [ 0.0369, -0.5064, 1.1233, -0.7962],
    [-0.5229, 1.0592, 0.4072, -1.2498],
    [ 0.2530, -0.4465, -0.8152, -0.9206]])
    w = copy.deepcopy(test[0])
    print(w)
    tensor([ 1.8693, -0.3236, -0.3465, 0.9226])
    '''
    # print('++++++++')
    # print(w)
    # print('=====')
    # print(w_avg)
    # print('++++++++++++++++++')
    # print(len(w)) == 10
    # This function accepts the list-typed local_weights
    for key in w_avg.keys():
        for i in range(1, len(w)):
            # range(1, 10): 1, 2, 3, 4, 5, 6, 7, 8, 9
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
        '''
        For each key, the sum over all local models is divided by len(w),
        the number of local models. w is a list of dicts, so w[i][key] is
        the tensor stored under key in the i-th local model.
        '''
    return w_avg
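The same averaging loop can be tried without PyTorch. The sketch below swaps tensors for plain floats so it runs with the standard library alone; average_dicts and the parameter names fc.weight and fc.bias are made up for illustration.

```python
import copy


def average_dicts(w):
    """Hypothetical stand-in for average_weights, using floats, not tensors."""
    w_avg = copy.deepcopy(w[0])  # independent copy, so w[0] stays untouched
    for key in w_avg:
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = w_avg[key] / len(w)
    return w_avg


# Three "local models", each a dict from parameter name to value
local_weights = [
    {"fc.weight": 1.0, "fc.bias": 0.0},
    {"fc.weight": 2.0, "fc.bias": 3.0},
    {"fc.weight": 3.0, "fc.bias": 6.0},
]
avg = average_dicts(local_weights)
print(avg)  # {'fc.weight': 2.0, 'fc.bias': 3.0}
```

Note that this is an unweighted mean: every local model counts equally, which matches a sample-size-weighted average only when all users hold the same amount of data.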
def exp_details(args):
    print('\nExperimental details:')
    print(f' Model : {args.model}')
    print(f' Optimizer : {args.optimizer}')
    print(f' Learning : {args.lr}')
    print(f' Global Rounds : {args.epochs}\n')
    print(' Federated parameters:')
    if args.iid:
        print(' IID')
    else:
        print(' Non-IID')
    print(f' Fraction of users : {args.frac}')
    print(f' Local Batch size : {args.local_bs}')
    print(f' Local Epochs : {args.local_ep}\n')
    '''
    epoch: one epoch is one full pass of all the training data through the
    network, forward computation and back propagation included. Because a
    whole epoch is usually too large to process at once, it is divided
    into smaller batches.
    '''
    return
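The relation between epochs and batches can be put into numbers. The sketch below counts how many mini-batch steps one user performs in one global round; the function name and the sample values are illustrative, and it assumes the last, smaller batch is kept rather than dropped.

```python
import math


def local_updates_per_round(num_samples, local_bs, local_ep):
    """Hypothetical helper: mini-batch steps one user runs per global round."""
    batches_per_epoch = math.ceil(num_samples / local_bs)  # keep the partial batch
    return batches_per_epoch * local_ep


# Illustrative values: 600 local samples, batch size 10, 10 local epochs
print(local_updates_per_round(600, 10, 10))  # 600
```

With 600 samples and a batch size of 10, one epoch is 60 batches, so 10 local epochs give 600 steps.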
Copyright notice
This article was written by [Silent city of the sky]. Please include the original link when reposting. Thank you.
https://yzsam.com/2022/04/202204230556365836.html