[Code Walkthrough (2)] Communication-Efficient Learning of Deep Networks from Decentralized Data
2022-04-23 05:58:00 【缄默的天空之城】
utils.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
from torchvision import datasets, transforms
from sampling import mnist_iid, mnist_noniid, mnist_noniid_unequal
from sampling import cifar_iid, cifar_noniid
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    '''
    args carries the parsed command-line options (dataset, iid, unequal, num_users, ...).
    Returns the training set and the test set.
    Also returns the user group, a dict of {key: value} pairs:
        key:   the user index
        value: the sample indices assigned to that user
    '''
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        '''
        Compose chains several preprocessing steps into a single transform.
        transforms.ToTensor():
            a tensor is laid out as C*H*W, while a numpy image is H*W*C;
            ToTensor() converts a PIL image (H*W*C, values in [0, 255])
            into a torch.Tensor (C*H*W, values in [0.0, 1.0]).
        transforms.Normalize(mean, std) normalizes the tensor image per channel:
            input[channel] = (input[channel] - mean[channel]) / std[channel]
        With mean = std = (0.5, 0.5, 0.5), the [0, 1] range is mapped to [-1, 1]:
            the minimum 0 becomes (0 - 0.5) / 0.5 = -1,
            the maximum 1 becomes (1 - 0.5) / 0.5 = 1.
        Standardizing the inputs this way generally helps the model converge.
        '''
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)
        '''
        Signature reminder:
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=None)
        1. root:      directory the CIFAR-10 data is loaded from (relative path)
        2. train:     if True, load the training split; if False, load the test split
        3. download:  whether to download the CIFAR-10 dataset automatically
        4. transform: preprocessing applied to each sample; None means no preprocessing
        '''
        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Cifar
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Cifar
            '''
            Sample non-IID data for each user from the CIFAR-10 dataset.
            '''
            if args.unequal:
                # Choose unequal splits for every user
                '''
                Give every user an unequally sized share of the data.
                '''
                raise NotImplementedError()
                # unequal non-IID splits are not implemented for CIFAR-10
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)
                '''
                The data is non-IID across users, but every user receives an
                equally sized partition.
                '''
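                # Background (paraphrasing the FedAvg paper, not code from this
                # repo): the paper builds its non-IID split by sorting the data
                # by label, cutting it into equal-sized shards, and handing each
                # client only a couple of shards, so every client sees just a
                # few classes. cifar_noniid in sampling.py is assumed to follow
                # this shard-based scheme.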
    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        # note: writing `args.dataset == 'mnist' or 'fmnist'` would always be
        # True (the non-empty string 'fmnist' is truthy), so both names are
        # compared explicitly
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        # (0.1307,) and (0.3081,) are the mean and std of the MNIST training set;
        # note that datasets.MNIST is used here even when args.dataset == 'fmnist'
        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)
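        # Sketch only (not part of the original utils.py): the normalization
        # constants above can be reproduced from the raw training images,
        # roughly like this:
        #     raw = datasets.MNIST(data_dir, train=True, download=True,
        #                          transform=transforms.ToTensor())
        #     data = torch.stack([img for img, _ in raw])
        #     data.mean().item(), data.std().item()   # ~0.1307, ~0.3081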
        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
                '''
                The data is non-IID across users and, in addition, the users
                receive unequally sized shares.
                '''
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)
    return train_dataset, test_dataset, user_groups
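# Usage sketch (not part of the original utils.py; in the repo, args comes from
# an argparse parser, so the Namespace below only mimics the fields that
# get_dataset actually reads):
#     from argparse import Namespace
#     args = Namespace(dataset='mnist', iid=1, unequal=0, num_users=100)
#     train_dataset, test_dataset, user_groups = get_dataset(args)
#     len(train_dataset)      # 60000 MNIST training images
#     user_groups[0]          # sample indices assigned to user 0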
def average_weights(w):
    """
    Returns the average of the weights.
    w is a list of model state_dicts, one per participating client.
    """
    w_avg = copy.deepcopy(w[0])
    '''
    copy.deepcopy makes an independent copy, e.g.:
        test = torch.randn(4, 4)
        print(test)
        tensor([[ 1.8693, -0.3236, -0.3465,  0.9226],
                [ 0.0369, -0.5064,  1.1233, -0.7962],
                [-0.5229,  1.0592,  0.4072, -1.2498],
                [ 0.2530, -0.4465, -0.8152, -0.9206]])
        w = copy.deepcopy(test[0])
        print(w)
        tensor([ 1.8693, -0.3236, -0.3465,  0.9226])
    '''
    # this function receives local_weights as a Python list of client
    # state_dicts (len(w) == 10 in the author's debug run, i.e. 10 clients
    # selected in that round)
    for key in w_avg.keys():
        for i in range(1, len(w)):
            # range(1, len(w)) skips index 0, which is already stored in w_avg
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
        '''
        Sum this parameter over all clients and divide by the number of
        clients, i.e. an unweighted element-wise average.
        '''
    return w_avg
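# Behaviour sketch (illustrative only, not part of the original file): averaging
# the state_dicts of two identically shaped models yields the element-wise mean.
#     import torch.nn as nn
#     m1, m2 = nn.Linear(2, 2), nn.Linear(2, 2)
#     w_avg = average_weights([m1.state_dict(), m2.state_dict()])
#     torch.allclose(w_avg['weight'], (m1.weight + m2.weight) / 2)   # True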
def exp_details(args):
    print('\nExperimental details:')
    print(f'    Model     : {args.model}')
    print(f'    Optimizer : {args.optimizer}')
    print(f'    Learning  : {args.lr}')
    print(f'    Global Rounds   : {args.epochs}\n')
    print('    Federated parameters:')
    if args.iid:
        print('    IID')
    else:
        print('    Non-IID')
    print(f'    Fraction of users  : {args.frac}')
    print(f'    Local Batch size   : {args.local_bs}')
    print(f'    Local Epochs       : {args.local_ep}\n')
    '''
    epoch: one epoch means passing every training sample through the network
    once, i.e. one full forward and backward pass over the whole dataset.
    Because a single epoch is usually too large to process at once, it is
    split into several smaller batches.
    '''
    return
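To see where these helpers fit, here is a minimal sketch of one FedAvg communication round built on top of get_dataset and average_weights. It is illustrative only: fedavg_round and local_train are names made up for this sketch, while the repo's actual training loop lives in its federated training script and performs local training through its own update class.

import copy
import numpy as np

def fedavg_round(global_model, train_dataset, user_groups, args, local_train):
    """One FedAvg round: sample a fraction of the clients, train locally, average."""
    m = max(int(args.frac * args.num_users), 1)            # clients this round
    selected = np.random.choice(args.num_users, m, replace=False)
    local_weights = []
    for uid in selected:
        # local_train is a stand-in: train a copy of the global model on
        # user uid's sample indices and return the updated state_dict
        w = local_train(copy.deepcopy(global_model), train_dataset,
                        list(user_groups[uid]), args)
        local_weights.append(w)
    # plain (unweighted) FedAvg, matching average_weights above
    global_model.load_state_dict(average_weights(local_weights))
    return global_model

Note that average_weights treats every client equally; weighting each client by its number of samples, as in the paper's full FedAvg formula, would additionally require passing len(user_groups[uid]) into the averaging step.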
Copyright notice
This article was written by [缄默的天空之城]. If you repost it, please include a link to the original. Thanks.
https://blog.csdn.net/weixin_42139772/article/details/122549593