
[Code analysis (7)] Communication-Efficient Learning of Deep Networks from Decentralized Data

2022-04-23 13:50:00 Silent city of the sky

baseline_main.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import os
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from utils import get_dataset
from options import args_parser
from update import test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

if __name__ == '__main__':
    args = args_parser()
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'
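    # Note: args.gpu is truthy-tested, so GPU id 0 would fall through to
    # 'cpu'; pass a nonzero id (or adapt the check) to train on GPU 0.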

    # load datasets
    train_dataset, test_dataset, _ = get_dataset(args)
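    '''
        get_dataset also returns a third value, ignored here: in the
        reference repo it is the user_groups dict mapping each client id
        to its sample indices, which only the federated scripts need.
    '''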

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the MLP only after the flattened input size is known
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
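    '''
        Worked example for the MLP branch above: with CIFAR-10,
        train_dataset[0][0].shape is torch.Size([3, 32, 32]), so
        len_in = 3 * 32 * 32 = 3072; with MNIST it is
        torch.Size([1, 28, 28]), so len_in = 784.
    '''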

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # Training
    # Set optimizer and criterion
    '''
        The optimizer setup below is the same as the one in the
        update_weights method of the LocalUpdate class in update.py.
    '''
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=0.5)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)
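    '''
        For comparison, a sketch of the matching snippet inside
        LocalUpdate.update_weights in update.py (paraphrased from the
        reference repo, shown only to make the parallel explicit):

            if self.args.optimizer == 'sgd':
                optimizer = torch.optim.SGD(model.parameters(),
                                            lr=self.args.lr, momentum=0.5)
            elif self.args.optimizer == 'adam':
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=self.args.lr,
                                             weight_decay=1e-4)
    '''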

    # batch_size is fixed at 64 here, so for CIFAR-10 each batch of images
    # has shape torch.Size([64, 3, 32, 32])
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    criterion = torch.nn.NLLLoss().to(device)
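    '''
        NLLLoss expects log-probabilities, i.e. it assumes the model's
        forward pass ends in log_softmax (which the NLLLoss choice here
        implies for the models in models.py). The pair
        log_softmax + NLLLoss is equivalent to CrossEntropyLoss on raw
        logits:

            # criterion = torch.nn.CrossEntropyLoss().to(device)
    '''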

    epoch_loss = []

    '''
        The for loop below is the same as the training loop in the
        update_weights method of LocalUpdate in update.py. The only
        difference is that update.py runs for args.local_ep local epochs,
        while here it runs for args.epochs.
    '''
    # print('000000000000000000')
    # print(type(trainloader))
    # <class 'torch.utils.data.dataloader.DataLoader'>

    # print(trainloader.dataset)
    '''
        Dataset CIFAR10
        Number of datapoints: 50000
        Root location: ../data/cifar/
        Split: Train
        StandardTransform
        Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        )
    '''

    for epoch in tqdm(range(args.epochs)):
        batch_loss = []

        for batch_idx, (images, labels) in enumerate(trainloader):

            images, labels = images.to(device), labels.to(device)
            '''
                enumerate(trainloader) hands the training set over one
                batch at a time as (batch_idx, (images, labels)). This is
                the standard pattern for reading data from a DataLoader
                when training a network:

                    for batch_idx, (images, labels) in enumerate(trainloader)

                where element 0 of each yielded pair is the data and
                element 1 is the label.
            '''

            # update.py uses model.zero_grad() here instead
            optimizer.zero_grad()
            '''
                First, zero all parameter gradients (optimizer.zero_grad()).
            '''

            outputs = global_model(images)

            loss = criterion(outputs, labels)

            loss.backward()
            '''
                Backpropagation then computes the gradient of the loss
                with respect to every parameter (loss.backward()).
            '''

            optimizer.step()
            '''
                Finally, perform one gradient-descent parameter update
                (optimizer.step()).
            '''
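            '''
                For reference, with SGD and momentum 0.5 (PyTorch's
                formulation), the step applied to each parameter p with
                gradient g and velocity buffer v is:

                    v <- 0.5 * v + g
                    p <- p - lr * v
            '''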

            # print('000000000000000')
            # print(batch_idx)  # 0,1,2,.....
            # print(len(images))  # 64
            # print(type(images))  # <class 'torch.Tensor'>
            # print(enumerate(trainloader))  # <enumerate object at 0x00000293CC019F98>
            # print(images.shape)
            '''
                torch.Size([64, 3, 32, 32])
                tensor([[ [[ 0.6235,  0.7098,  0.8196,  ...,  0.2392,  0.4039, -0.2863],
                          ...
                          [ 0.5137,  0.7255,  0.7333,  ...,  0.6706,  0.6627,  0.6314]],
                          ...
                          [[ 0.9922,  0.9765,  0.9608,  ...,  0.8353,  0.8353,  0.8588],
                          ...
                          [ 0.7020,  0.7490,  0.7569,  ...,  0.5059,  0.5451,  0.6471]],
                        ]])
            '''
            # print(labels)
            # print(labels.shape)
            '''
               tensor([1, 9, 1, 8, 7, 4, 4, 0, 7, 6, 1, 3, 8, 7, 6, 3, 4, 0, 5, 2, 2, 8, 5, 4,
                        2, 0, 8, 8, 3, 9, 5, 7, 3, 1, 7, 1, 9, 2, 4, 1, 2, 9, 8, 9, 8, 1, 4, 7,
                        8, 3, 5, 5, 9, 6, 0, 7, 6, 6, 1, 3, 5, 7, 3, 5])
               torch.Size([64]) 
            '''
            # print(len(trainloader.dataset))  # 50000

            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch+1, batch_idx * len(images), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
            '''
                Progress is printed every 50 batches. With batch_size=64
                and the 50000 CIFAR-10 training samples there are
                ceil(50000 / 64) = 782 batches per epoch, so batch_idx
                runs from 0 to 781 and the print fires at
                batch_idx = 0, 50, 100, ..., 750;
                batch_idx * len(images) = batch_idx * 64 is the number of
                samples processed so far in the epoch.
            '''

            batch_loss.append(loss.item())
        # print(batch_idx)  # 781

        loss_avg = sum(batch_loss)/len(batch_loss)

        print('\nTrain loss:', loss_avg)

        # append the average loss once per epoch
        epoch_loss.append(loss_avg)

    # Plot loss
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig('../save/nn_{}_{}_{}.png'.format(args.dataset, args.model,
                                                 args.epochs))
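    # Note: plt.savefig raises FileNotFoundError if ../save/ does not
    # exist; create it beforehand if needed, e.g.
    # os.makedirs('../save', exist_ok=True)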

    # testing
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print("Test Accuracy: {:.2f}%".format(100*test_acc))


Copyright notice
This article was written by [Silent city of the sky]. Please keep a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230556365641.html