TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards

Overview


TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when a model has massive linear layers (e.g., ViT, BERT, and GPT) or a very large number of classes (millions). Its API follows the same design as PyTorch's.
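To make the idea of slicing a linear layer concrete, here is a minimal pure-Python sketch of the two usual ways to shard a weight matrix. It is an illustration of the general technique, not TorchShard's implementation: the helper names (`matvec`, `split_rows`, `split_cols`) are hypothetical, the bias is omitted, and no claim is made about how these two splits map onto TorchShard's `dim` argument.

```python
# Pure-Python illustration of sharding a linear layer y = W x (bias omitted).
# W has shape (out_features, in_features), as in torch.nn.Linear.
# All helper names are hypothetical and exist only for this sketch.

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def split_rows(W, parts):
    """Split W along its first dimension (output features) into equal shards."""
    step = len(W) // parts
    return [W[i * step:(i + 1) * step] for i in range(parts)]

def split_cols(W, parts):
    """Split W along its second dimension (input features) into equal shards."""
    step = len(W[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in W] for i in range(parts)]

W = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
x = [1, 0, 2, 1]
full = matvec(W, x)

# Output-feature shards: each worker produces a slice of y; slices are concatenated.
by_rows = [y for shard in split_rows(W, 2) for y in matvec(shard, x)]

# Input-feature shards: each worker sees a slice of x and produces a partial y;
# partial results are summed elementwise (an all-reduce in a real setup).
x_parts = [x[0:2], x[2:4]]
partials = [matvec(shard, xp) for shard, xp in zip(split_cols(W, 2), x_parts)]
by_cols = [sum(vals) for vals in zip(*partials)]

assert by_rows == full and by_cols == full
```

In both cases each worker stores only half of `W`, which is where the memory saving comes from. The difference is the communication pattern: concatenation of output slices in the first case, a sum of partial outputs in the second.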

Installation

pip install torchshard

More options in INSTALL.md.

Usage

import torch
import torchshard as ts

ts.init_process_group(group_size=2)                       # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),               
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),    # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),       # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),       # parallel in column dimension
).cuda()

x = m(x)                                                  # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)      # parallel loss function
loss.backward()                                           # backward

torch.save(
  ts.collect_state_dict(m, m.state_dict()), 'm.pt')       # save model state
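The parallel loss in the snippet above operates on logits whose class dimension is split across processes. One common way to compute cross-entropy in that situation is to combine a per-shard maximum and a per-shard sum of exponentials. The sketch below shows that technique in pure Python. It is an assumption about the general approach, not a description of how `ts.nn.functional.parallel_cross_entropy` is implemented, and the function names are hypothetical.

```python
import math

def cross_entropy(logits, target):
    """Reference: -log softmax(logits)[target] computed on the full logit vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return lse - logits[target]

def sharded_cross_entropy(shards, target):
    """Same loss when the class dimension is split across workers.

    Each worker only holds its own slice of the logits; the values exchanged
    are a max, a sum of exponentials, and the target logit.
    """
    # Step 1: global max for numerical stability (all-reduce MAX in a real setup).
    global_max = max(max(s) for s in shards)
    # Step 2: each worker sums exp over its slice; sums are added (all-reduce SUM).
    sum_exp = sum(sum(math.exp(v - global_max) for v in s) for s in shards)
    # Step 3: only the worker that owns the target class contributes its logit.
    offset = 0
    target_logit = 0.0
    for s in shards:
        if offset <= target < offset + len(s):
            target_logit = s[target - offset]
        offset += len(s)
    return global_max + math.log(sum_exp) - target_logit

logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]
shards = [logits[0:3], logits[3:6]]   # two workers, three classes each
loss_full = cross_entropy(logits, target=4)
loss_sharded = sharded_cross_entropy(shards, target=4)
assert math.isclose(loss_full, loss_sharded)
```

The point of this design is that no worker ever needs the full logit vector: only a few scalars per sample cross process boundaries, which matters when the number of classes reaches the millions.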

Performance

The following figure shows ResNet-50 training on 8 NVIDIA TITAN-XP (12196 MiB) GPUs as the number of classes scales from 1000 to 1 million. The input size is 224 x 224 and the batch size is 256. Training uses 8-way data parallelism and 8-way model parallelism.
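A rough calculation shows why the classifier becomes the bottleneck at 1 million classes. The sketch below assumes float32 parameters, the standard 2048-dimensional ResNet-50 feature vector, and an even split of the classifier across the 8 GPUs. These assumptions are not stated in the benchmark description, and activations, optimizer state, and the backbone are ignored.

```python
# Back-of-the-envelope size of a ResNet-50 classifier with 1 million classes.
# Assumptions (not from the benchmark itself): float32 parameters, the standard
# 2048-dimensional ResNet-50 feature vector, and an even split over 8 GPUs.
FEATURES = 2048
CLASSES = 1_000_000
BYTES_PER_PARAM = 4          # float32
GPUS = 8
GPU_MEMORY_MIB = 12196       # TITAN-XP figure quoted in the text

MIB = 1024 ** 2
weight_mib = FEATURES * CLASSES * BYTES_PER_PARAM / MIB
shard_mib = weight_mib / GPUS

print(f"full classifier weight: {weight_mib:.1f} MiB")
print(f"per-GPU shard:          {shard_mib:.1f} MiB")
print(f"weight + gradient, unsharded: {2 * weight_mib:.1f} MiB "
      f"(GPU has {GPU_MEMORY_MIB} MiB)")
```

Under these assumptions the unsharded weight alone is about 7.6 GiB, and weight plus gradient already exceeds a single 12196 MiB card, whereas each 8-way shard is under 1 GiB.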

The following figure shows minGPT training on 8 NVIDIA TITAN-XP (12196 MiB) GPUs as the number of parameters scales from 10 million to 808 million. The input size is 32 x 32 and the batch size is 16. Training uses 1-way data parallelism and 8-way model parallelism.

Contributing

TorchShard welcomes your expertise and enthusiasm!

If you are interested in torchshard, you are welcome to help:

  • polish code and develop new features
  • develop high-quality tutorials, projects, and advanced materials

Direct pull requests are welcome. Contact: kaiyuyue [at] umd.edu.

Citing TorchShard

If you find TorchShard helpful in your research and would like to cite it, please use the following BibTeX entry.

@misc{torchshard2021,
  author =       {Kaiyu Yue},
  title =        {TorchShard},
  howpublished = {\url{https://github.com/KaiyuYue/torchshard}},
  year =         {2021}
}
Comments
  • Future planning for this project.

    Hello Kaiyu, I love this awesome project. The API design is elegant and simple and the software is lightweight and user-friendly. My understanding is that this project has realized a series of PyTorch wrappers for tensor slicing.

    1. I am curious about the future planning of this project.
    2. Is there some overlap in functionality between torchshard and the N-D parallelism proposed in ColossalAI?
    3. How compatible is it with ZeRO? According to the am+zero example, the memory footprint changes only a little after combining torchshard with ZeRO.
    opened by feifeibear 2
  • Which one is faster?

    Thanks for contributing this great lib. I have one question: which is faster (in speed), dim=0 or dim=1? The documentation seems to contain only accuracy results.

    opened by NOBLES5E 2
  • 8-GPU test example raises an error.

    When I run the unit tests with two GPU devices, they pass. The command is:
    CUDA_VISIBLE_DEVICES=0,1 python3 -m unittest discover -v -s tests

    But when I run the unit tests with eight GPU devices, they raise an ncclSystemError. The command is:
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m unittest discover -v -s tests
    The error is:
    RuntimeError: NCCL error in ../torch/lib/c10d/ProcessGroupNCCL.cpp:825, unhandled system error, NCCL version 2.7.8 ncclSystemError: System call (socket, malloc, munmap, etc) failed.

    Is it necessary for the unit tests to pass on eight GPU devices?

    opened by JiaquanYe 1
  • Error?

    Hi, thanks for the excellent work! When I install it from pip and run

    import torchshard as ts
    ts.init_process_group(group_size=2) 
    

    an AttributeError occurs:

    AttributeError: module 'torchshard' has no attribute 'init_process_group'
    
    opened by WangWenhao0716 1
  • Multi-node setting?

    https://github.com/KaiyuYue/torchshard/blob/89e21def180bf6063ceb2e312a61631173abc7e7/projects/minGPT/main.py#L150

    I have noticed that group_size is set to world_size in the examples, but as I understand it, group_size can be set to other values.

    https://github.com/KaiyuYue/torchshard/blob/main/torchshard/distributed/core.py#L18

    I have also found that get_world_size() returns the number of all processes.

    These two findings confuse me in a multi-node setting, say 2 nodes with 2 processes each.

    If group_size is 2, then there are 2 distinct groups besides the default group (w/ overlap). However, calling get_world_size() without specifying a group can split a layer into 4 parts, where 2 parts are expected in our case.

    Correct me if I am wrong.

    Good Issue 
    opened by GeneZC 1
  • Is it possible to collect state dict in cpu?

    When I finish one epoch of training, the main_worker function calls ts.collect_state_dict(model, state_dict). Because of limited GPU resources, this call raises an out-of-memory error on my machine. I found that it gathers the state_dict on the GPU; is there any way to gather it on the CPU?

    Good Issue 
    opened by JiaquanYe 2
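The multi-node question above turns on the difference between the size of a model-parallel group and the total number of processes. The sketch below illustrates that difference for the 2-node, 2-processes-per-node case from the question. It uses hypothetical helpers (`build_groups`, `shard_size`), assumes ranks are grouped contiguously, and uses a made-up layer width of 32; it makes no claim about how TorchShard forms its groups.

```python
def build_groups(world_size, group_size):
    """Partition ranks 0..world_size-1 into contiguous groups of group_size.

    Hypothetical helper for illustration; contiguous grouping is an assumption.
    """
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [list(range(start, start + group_size))
            for start in range(0, world_size, group_size)]

def shard_size(features, num_shards):
    """Number of features each process holds when a layer is split evenly."""
    return features // num_shards

world_size, group_size = 4, 2      # 2 nodes x 2 processes, as in the question
groups = build_groups(world_size, group_size)

per_group = shard_size(32, group_size)   # split across the model-parallel group
per_world = shard_size(32, world_size)   # split across every process

print("groups:", groups)
print("features per process if split by group size:", per_group)
print("features per process if split by world size:", per_world)
```

With a group size of 2, each process should hold half of the layer. Dividing by the world size instead would give each process a quarter, which is the mismatch the question describes.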
Releases(v0.1)
Owner
Kaiyu Yue