[Code analysis (2)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 13:47:00 【Silent city of the sky】
utils.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    '''
    args: the parsed command-line options; this function reads args.dataset,
    args.iid, args.unequal and args.num_users.
    Returns the training set and the test set, plus the user groups:
    a dict of the form {key: value}, where
        key: the user's index
        value: the data assigned to that user
    '''
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        '''
        transforms.Compose chains several preprocessing steps into one.
        transforms.ToTensor():
            A torch image tensor is laid out as CHW, a NumPy/PIL image as HWC.
            ToTensor converts a PIL image or numpy.ndarray (H*W*C) with values in
            the range [0, 255] into a torch.Tensor (C*H*W) with values in the
            range [0.0, 1.0], i.e. it rescales 8-bit pixel values from 0-255 to 0-1.
        transforms.Normalize(mean, std):
            Normalizes a tensor image with a per-channel mean and standard deviation:
            input[channel] = (input[channel] - mean[channel]) / std[channel]
            Here mean and std are both (0.5, 0.5, 0.5), so the 0-1 range becomes
            -1 to 1: the minimum 0 becomes (0 - 0.5) / 0.5 = -1 and the maximum 1
            becomes (1 - 0.5) / 0.5 = 1.
            Centering the inputs like this tends to make the model converge more
            easily. The result is only approximately zero-mean and unit-variance
            when mean and std are the dataset's real statistics (as with the
            MNIST values used below); with 0.5/0.5 it is simply a rescaling to [-1, 1].
        '''
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)
        '''
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)
        1. root: directory the CIFAR-10 data is loaded from (a relative path here)
        2. train: True loads the training set; False loads the test set
        3. download: whether to download CIFAR-10 automatically if it is not already under root
        4. transform: preprocessing applied to each image; None means no preprocessing
        '''
        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Cifar
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Cifar
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
                '''
                For CIFAR, the code for non-IID splits with unequal sizes per user
                is not implemented.
                '''
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)
                '''
                The data is non-IID across users, but every user gets an equal share.
                '''
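    elif args.dataset in ('mnist', 'fmnist'):
        # Bug fix: the original test `args.dataset == 'mnist' or 'fmnist'` is
        # always true, because the non-empty string 'fmnist' is truthy.
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
            dataset_cls = datasets.MNIST
        else:
            data_dir = '../data/fmnist/'
            # The original loaded datasets.MNIST for 'fmnist' as well;
            # torchvision's FashionMNIST takes the same arguments.
            dataset_cls = datasets.FashionMNIST
        # (0.1307,) and (0.3081,) are the usual MNIST pixel mean and std,
        # kept for both datasets as in the original code
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = dataset_cls(data_dir, train=True, download=True,
                                    transform=apply_transform)
        test_dataset = dataset_cls(data_dir, train=False, download=True,
                                   transform=apply_transform)
        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
                '''
                The data is non-IID across users, and the users' shares are of unequal size.
                '''
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)
    else:
        raise ValueError(f'Unsupported dataset: {args.dataset}')
    return train_dataset, test_dataset, user_groups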
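To check the Normalize arithmetic described in the notes inside get_dataset, here is a small pure-Python sketch that applies the same per-channel formula to single pixel values. It deliberately avoids torch/torchvision, so `to_tensor_value` and `normalize` are hypothetical helpers that only mimic ToTensor's 0-255 rescaling and Normalize's formula; they are not the torchvision API.

```python
def to_tensor_value(pixel):
    """Mimic ToTensor's rescaling of an 8-bit pixel value from [0, 255] to [0.0, 1.0]."""
    return pixel / 255.0


def normalize(values, mean, std):
    """Apply input[c] = (input[c] - mean[c]) / std[c] to one value per channel."""
    return [(x - m) / s for x, m, s in zip(values, mean, std)]


# CIFAR-10 settings: mean = std = 0.5 for each of the three RGB channels
cifar_mean = cifar_std = (0.5, 0.5, 0.5)
black = [to_tensor_value(0)] * 3
white = [to_tensor_value(255)] * 3
print(normalize(black, cifar_mean, cifar_std))  # [-1.0, -1.0, -1.0]
print(normalize(white, cifar_mean, cifar_std))  # [1.0, 1.0, 1.0]

# MNIST settings: one grey channel with mean 0.1307 and std 0.3081
low = normalize([to_tensor_value(0)], (0.1307,), (0.3081,))[0]
high = normalize([to_tensor_value(255)], (0.1307,), (0.3081,))[0]
print(round(low, 4), round(high, 4))  # -0.4242 2.8215
```

With mean = std = 0.5 the output range is exactly [-1, 1]. With the MNIST constants it is roughly [-0.42, 2.82], so Normalize centers and scales the data but does not force it into [-1, 1].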
def average_weights(w):
    """
    Returns the average of the weights.
    w is a list of state_dicts, one per user.
    """
    w_avg = copy.deepcopy(w[0])
    '''
    copy.deepcopy makes an independent copy. This matters because the loop below
    adds into w_avg in place (+=); without the copy it would overwrite the first
    user's weights w[0].
    Example of deepcopy:
    test = torch.randn(4, 4)
    print(test)
    tensor([[ 1.8693, -0.3236, -0.3465, 0.9226],
            [ 0.0369, -0.5064, 1.1233, -0.7962],
            [-0.5229, 1.0592, 0.4072, -1.2498],
            [ 0.2530, -0.4465, -0.8152, -0.9206]])
    w = copy.deepcopy(test[0])
    print(w)
    tensor([ 1.8693, -0.3236, -0.3465, 0.9226])
    '''
    # print('++++++++')
    # print(w)
    # print('=====')
    # print(w_avg)
    # print('++++++++++++++++++')
    # print(len(w))  # printed 10 in the author's run
    # w is the list local_weights: one state_dict per participating user
    for key in w_avg.keys():
        for i in range(1, len(w)):
            # range(1, 10): 1, 2, 3, 4, 5, 6, 7, 8, 9
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
        '''
        Divide the accumulated sum by len(w), the number of users, to get the
        element-wise mean of this parameter. Because w is a list of state_dicts,
        w[i][key] is the tensor for parameter key in user i's model.
        '''
    return w_avg
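The averaging loop can be tried without PyTorch. The sketch below is a hypothetical stand-in for average_weights that uses plain dicts of float lists in place of state_dicts of tensors, an assumption made only to keep it dependency-free. The in-place `+=` on list elements plays the role of the in-place tensor addition, which is why the deepcopy of `w[0]` matters.

```python
import copy


def average_weights_sketch(w):
    """Element-wise mean of a list of state_dict-like dicts.

    Hypothetical stand-in for average_weights: each dict maps a parameter
    name to a list of floats instead of a torch tensor.
    """
    # Deep-copy the first user's weights: the loop below adds into w_avg in
    # place, so without the copy it would also modify w[0].
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            # In-place, element-wise accumulation (like `tensor += tensor`)
            for j, value in enumerate(w[i][key]):
                w_avg[key][j] += value
        # Rebinding to a new list mirrors torch.div returning a new tensor
        w_avg[key] = [total / len(w) for total in w_avg[key]]
    return w_avg


clients = [
    {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
    {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]},
]
print(average_weights_sketch(clients))  # {'fc.weight': [2.0, 3.0], 'fc.bias': [0.5]}
print(clients[0])                       # {'fc.weight': [1.0, 2.0], 'fc.bias': [0.0]}
```

Replacing `copy.deepcopy(w[0])` with `w[0]` would silently overwrite `clients[0]`. Note also that this is a plain, unweighted mean. The FedAvg algorithm in the paper weights each user's model by its share of the training samples, so the two coincide only when all selected users hold the same number of samples, as in the equal-size splits.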
def exp_details(args):
    print('\nExperimental details:')
    print(f' Model : {args.model}')
    print(f' Optimizer : {args.optimizer}')
    print(f' Learning : {args.lr}')
    print(f' Global Rounds : {args.epochs}\n')
    print(' Federated parameters:')
    if args.iid:
        print(' IID')
    else:
        print(' Non-IID')
    print(f' Fraction of users : {args.frac}')
    print(f' Local Batch size : {args.local_bs}')
    print(f' Local Epochs : {args.local_ep}\n')
    '''
    epoch: one epoch means passing all of the training data through the network
    once, with a forward pass and back-propagation. Because an epoch is usually
    too large to process in one go, it is split into smaller batches.
    Here args.epochs counts global communication rounds, while args.local_ep is
    the number of epochs each user trains on its own data per round, in batches
    of args.local_bs samples.
    '''
    return
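To make the epoch/batch relationship concrete, the sketch below counts how many mini-batch updates one user performs per global round. `local_steps_per_round` and `users_per_round` are hypothetical helpers. The rounding rules they use are assumptions about how the training script consumes these arguments, not something utils.py shows: the last partial batch is kept (PyTorch DataLoader's default, `drop_last=False`), and `max(int(frac * num_users), 1)` users are selected per round.

```python
import math


def local_steps_per_round(num_samples, local_bs, local_ep):
    """Mini-batch updates one user performs in one global round.

    Assumes the last, partial batch is kept, which is the default
    (drop_last=False) of torch.utils.data.DataLoader.
    """
    batches_per_epoch = math.ceil(num_samples / local_bs)
    return local_ep * batches_per_epoch


def users_per_round(frac, num_users):
    """Users selected per global round (assumed rule: a fraction, at least one)."""
    return max(int(frac * num_users), 1)


# Hypothetical settings, not the author's
print(users_per_round(0.1, 100))           # 10
print(local_steps_per_round(600, 10, 10))  # 600
print(local_steps_per_round(500, 64, 1))   # 8 (7 full batches of 64 + 1 batch of 52)
```

Under these assumptions, frac = 0.1 with 100 users yields 10 models per round, which is the kind of setting that would produce the `len(w) == 10` noted in average_weights.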
Copyright notice
This article was written by [Silent city of the sky]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230556365836.html