[Code analysis (2)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 13:47:00 【Silent city of the sky】
utils.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
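The `sampling` helpers imported above (`mnist_iid`, `cifar_iid` and the non-IID variants) are not shown in this post. To make the `user_groups` dict returned by `get_dataset` more concrete, here is a minimal IID split sketch. It assumes, for illustration only, that each user holds a set of indices into the shared training set; `iid_split` is a hypothetical name and is not the repository's `mnist_iid` or `cifar_iid`.

```python
import random


def iid_split(num_samples, num_users, seed=0):
    """Hypothetical IID split (not the repository's implementation).

    Shuffles all sample indices once and deals out equal, disjoint slices,
    so every user gets the same number of samples drawn uniformly at random.
    Returns {user_index: set_of_sample_indices}. The num_samples % num_users
    leftover samples are not assigned to anyone.
    """
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    num_items = num_samples // num_users  # samples per user
    return {user: set(indices[user * num_items:(user + 1) * num_items])
            for user in range(num_users)}


user_groups = iid_split(num_samples=60000, num_users=100)
print(len(user_groups), len(user_groups[0]))  # 100 600
```

Shuffling once and slicing is equivalent to drawing each user's samples without replacement. A non-IID split would instead correlate a user's indices with their labels, for example by sorting indices by label before dealing out slices.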
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    # args: the parsed command-line options; this function reads
    #   args.dataset, args.iid, args.unequal and args.num_users.
    # Returns the training set, the test set and the user groups,
    # a dict {key: value} where
    #   key:   a user's index
    #   value: the data assigned to that user
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # transforms.Compose chains several preprocessing steps together.
        #
        # transforms.ToTensor():
        #   Image tensors in PyTorch are CHW; NumPy arrays and PIL images are HWC.
        #   It converts a PIL image or ndarray (H*W*C) with values in
        #   [0, 255] to a torch.FloatTensor (C*H*W) with values in [0.0, 1.0],
        #   i.e. it rescales 0-255 intensities to 0-1.
        #
        # transforms.Normalize(mean, std) normalizes a tensor image with the
        # given per-channel mean and standard deviation:
        #   input[channel] = (input[channel] - mean[channel]) / std[channel]
        # Here mean = std = (0.5, 0.5, 0.5), so the minimum 0 becomes
        # (0 - 0.5) / 0.5 = -1 and the maximum 1 becomes (1 - 0.5) / 0.5 = 1:
        # the range [0, 1] is mapped to [-1, 1].
        # Normalizing with a dataset's real per-channel mean and std instead
        # gives each channel zero mean and unit variance, which usually
        # makes the model easier to train.
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)
        # Example call:
        #   trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
        #                                           download=False, transform=None)
        #   1. root:      relative directory the CIFAR-10 data is loaded from
        #   2. train:     True loads the training set, False the test set
        #   3. download:  whether to download CIFAR-10 automatically
        #   4. transform: preprocessing applied to each image; None means none

        # Sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample non-IID user data from CIFAR
            if args.unequal:
                # Choose unequal splits for every user.
                # No unequal non-IID split is implemented for CIFAR.
                raise NotImplementedError()
            else:
                # Choose equal splits for every user: the data is non-IID
                # across users, but every user receives the same amount
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset in ('mnist', 'fmnist'):
        # Note: `args.dataset == 'mnist' or 'fmnist'` would always be true,
        # because the non-empty string 'fmnist' is truthy on its own.
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            # FMNIST needs its own class; datasets.MNIST would load MNIST here
            dataset_cls = datasets.FashionMNIST

        # (0.1307,) and (0.3081,) are the MNIST pixel mean and std,
        # reused for FMNIST as well
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)

        # Sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample non-IID user data from MNIST
            if args.unequal:
                # Choose unequal splits for every user: the data is non-IID
                # across users and each user receives a different amount
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    else:
        raise ValueError(f'Unsupported dataset: {args.dataset}')

    return train_dataset, test_dataset, user_groups
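The per-channel arithmetic described in the comments above can be checked without torchvision. This dependency-free sketch mimics only the value scaling of `ToTensor` (divide by 255) and the `Normalize` formula; `to_tensor` and `normalize` are hypothetical helpers working on plain Python lists that are already channel-first, not the torchvision transforms themselves.

```python
def to_tensor(channels_uint8):
    """Mimic ToTensor's value scaling: uint8 [0, 255] -> float [0.0, 1.0].
    (The real transform also reorders HWC to CHW; here the input is
    already given channel-first.)"""
    return [[v / 255.0 for v in channel] for channel in channels_uint8]


def normalize(channels, mean, std):
    """Per-channel formula of transforms.Normalize:
    output[c] = (input[c] - mean[c]) / std[c]."""
    return [[(v - m) / s for v in channel]
            for channel, m, s in zip(channels, mean, std)]


# A tiny 3-channel "image" with two pixels per channel: darkest and brightest
image = [[0, 255], [0, 255], [0, 255]]
out = normalize(to_tensor(image), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(out)  # [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]

# MNIST branch: a single channel with the MNIST mean and std
print(normalize(to_tensor([[0, 255]]), (0.1307,), (0.3081,)))
```

The MNIST values show that `Normalize` does not bound the output to [-1, 1] in general: with mean 0.1307 and std 0.3081 the pixel range [0, 1] becomes roughly [-0.42, 2.82].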
def average_weights(w):
    """
    Returns the average of the weights.
    w is a list of model state_dicts (the local_weights of the users that
    took part in this round); each maps a parameter name to a tensor.
    """
    # deepcopy gives w_avg its own tensors, so the in-place additions below
    # do not overwrite the first user's weights w[0]. For example:
    #   test = torch.randn(4, 4)
    #   print(test)
    #   tensor([[ 1.8693, -0.3236, -0.3465,  0.9226],
    #           [ 0.0369, -0.5064,  1.1233, -0.7962],
    #           [-0.5229,  1.0592,  0.4072, -1.2498],
    #           [ 0.2530, -0.4465, -0.8152, -0.9206]])
    #   w = copy.deepcopy(test[0])
    #   print(w)
    #   tensor([ 1.8693, -0.3236, -0.3465,  0.9226])
    w_avg = copy.deepcopy(w[0])

    # len(w) is the number of participating users (10 in the author's run)
    for key in w_avg.keys():
        # Start from 1 because w_avg already holds w[0]: range(1, 10) is 1..9
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        # Sum over all users divided by the number of users
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
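To see the averaging, and why `copy.deepcopy` matters, without a trained model, the sketch below reruns the same loop on plain dicts whose values are lists of floats standing in for tensors. `weighted_average_weights` is a hypothetical addition for comparison: the FedAvg paper weights each client's model by its share of the data, n_k / n, whereas `average_weights` above takes a plain mean; the two coincide when every user holds the same number of samples.

```python
import copy


def average_weights(w):
    """Unweighted element-wise mean of a list of 'state_dicts'.
    Here each state_dict is a plain dict mapping a parameter name to a
    list of floats (a stand-in for a torch tensor)."""
    w_avg = copy.deepcopy(w[0])           # never modify the caller's w[0]
    for key in w_avg:
        for i in range(1, len(w)):
            for j, value in enumerate(w[i][key]):
                w_avg[key][j] += value    # in-place, like tensor +=
        w_avg[key] = [v / len(w) for v in w_avg[key]]
    return w_avg


def weighted_average_weights(w, n_samples):
    """Hypothetical FedAvg-style mean: client k is weighted by n_k / n."""
    total = sum(n_samples)
    return {key: [sum(n_k / total * w_k[key][j]
                      for w_k, n_k in zip(w, n_samples))
                  for j in range(len(w[0][key]))]
            for key in w[0]}


clients = [{'fc.weight': [1.0, 2.0]},
           {'fc.weight': [3.0, 4.0]},
           {'fc.weight': [5.0, 6.0]}]
print(average_weights(clients))   # {'fc.weight': [3.0, 4.0]}
print(clients[0])                 # {'fc.weight': [1.0, 2.0]} (unchanged)
print(weighted_average_weights(clients, [1, 1, 2]))
```

Replacing `copy.deepcopy(w[0])` with `w[0]` in this sketch would make the in-place additions write into the first client's dict, so `clients[0]` would no longer hold that client's weights after the call.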
def exp_details(args):
    print('\nExperimental details:')
    print(f' Model : {args.model}')
    print(f' Optimizer : {args.optimizer}')
    print(f' Learning : {args.lr}')
    print(f' Global Rounds : {args.epochs}\n')
    print(' Federated parameters:')
    if args.iid:
        print(' IID')
    else:
        print(' Non-IID')
    print(f' Fraction of users : {args.frac}')
    print(f' Local Batch size : {args.local_bs}')
    print(f' Local Epochs : {args.local_ep}\n')
    # Epoch: one epoch means passing all of the training data through the
    # network once, with a forward pass and backpropagation. Because a whole
    # epoch is usually too large to process at once, it is split into
    # smaller batches.
    return
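The epoch note above can be made concrete by counting batches. Both helpers below are hypothetical: `batches_per_epoch` follows the same rounding a PyTorch `DataLoader` uses for its length (a final partial batch counts unless `drop_last=True`), and `local_steps` assumes each selected user runs `local_ep` epochs over its own data with batch size `local_bs`, the two values `exp_details` prints.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of mini-batches needed to see every sample once (one epoch)."""
    if drop_last:
        return num_samples // batch_size   # the partial last batch is skipped
    return math.ceil(num_samples / batch_size)


def local_steps(num_samples, local_bs, local_ep):
    """Gradient steps one user performs per global round, assuming it runs
    local_ep epochs over its own num_samples with batch size local_bs."""
    return local_ep * batches_per_epoch(num_samples, local_bs)


print(batches_per_epoch(600, 10))                  # 60
print(batches_per_epoch(50000, 64))                # 782 (last batch has 16 samples)
print(local_steps(600, local_bs=10, local_ep=10))  # 600
```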
Copyright notice
This article was written by [Silent city of the sky]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230556365836.html