A library to inspect intermediate layers of PyTorch models.


Why?

It's often the case that we want to inspect intermediate layers of a model without modifying its code. For example, we may want to visualize the attention matrices of language models, take the values from an intermediate layer and feed them to another layer, or apply a loss function to intermediate layers.

Install

$ pip install surgeon-pytorch


Usage

Inspect

Given a PyTorch model, we can list all of its layers using get_layers:

import torch
import torch.nn as nn

from surgeon_pytorch import Inspect, get_layers

class SomeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        y = self.layer3(x2)
        return y


model = SomeModel()
print(get_layers(model)) # ['layer1', 'layer2', 'layer3']
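The layer names come from the attribute names under which submodules are registered. PyTorch exposes these through `nn.Module.named_modules()`. The sketch below assumes that `get_layers` returns names in a similar dotted form, but this is not confirmed. The helper `list_layer_names` is hypothetical and not part of surgeon-pytorch. To keep the block short, it uses an `nn.Sequential` stand-in with the same layer names as `SomeModel`.

```python
from collections import OrderedDict
from typing import List

import torch.nn as nn


def list_layer_names(model: nn.Module) -> List[str]:
    """Hypothetical helper: dotted names of every submodule, in registration order."""
    # named_modules() yields ("", model) for the root module first; skip it.
    return [name for name, _ in model.named_modules() if name]


# Stand-in for SomeModel: same three layers registered under the same names.
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Linear(5, 3)),
    ("layer2", nn.Linear(3, 2)),
    ("layer3", nn.Linear(2, 1)),
]))
print(list_layer_names(model))  # ['layer1', 'layer2', 'layer3']

# Nested containers produce dotted paths such as "1.0".
nested = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
print(list_layer_names(nested))  # ['0', '1', '1.0', '1.1']
```

Container modules (here `'1'`) appear alongside the leaf layers. Whether `get_layers` lists them too is not stated in this README.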

Then we can wrap our model with Inspect. On every forward call, the wrapped model also returns the output of the given layer as a second return value:

model_wrapped = Inspect(model, layer='layer2')
x = torch.rand(1, 5)
y, x2 = model_wrapped(x)
print(x2) # tensor([[-0.2726,  0.0910]], grad_fn=<AddmmBackward0>)
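Capturing a layer's output without editing `forward` is typically done with a PyTorch forward hook. The sketch below shows that mechanism for a single layer, assuming a hook-based approach similar in spirit to Inspect. The helper `run_and_capture` is hypothetical and is not surgeon-pytorch's implementation.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def run_and_capture(model: nn.Module, layer_name: str, x: torch.Tensor):
    """Hypothetical helper: return (model(x), output of the submodule `layer_name`)."""
    layer = dict(model.named_modules())[layer_name]
    captured = {}

    def hook(module, inputs, output):
        # PyTorch calls this right after `layer` finishes its forward pass.
        captured["out"] = output

    handle = layer.register_forward_hook(hook)
    try:
        y = model(x)
    finally:
        handle.remove()  # leave the original model without leftover hooks
    return y, captured["out"]


model = nn.Sequential(OrderedDict([
    ("layer1", nn.Linear(5, 3)),
    ("layer2", nn.Linear(3, 2)),
    ("layer3", nn.Linear(2, 1)),
]))
x = torch.rand(1, 5)
y, x2 = run_and_capture(model, "layer2", x)
print(x2.shape)  # torch.Size([1, 2])
```

This sketch removes the hook after each call so the model is left untouched. A persistent wrapper could instead register the hook once and reuse it for every forward call.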

We can also provide a list of layers:

model_wrapped = Inspect(model, layer=['layer1', 'layer2'])
x = torch.rand(1, 5)
y, [x1, x2] = model_wrapped(x)
print(x1) # tensor([[ 0.1739,  0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print(x2) # tensor([[-0.2238,  0.0107]], grad_fn=<AddmmBackward0>)
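Extending the hook idea to several layers means registering one hook per name and collecting the outputs in the requested order. This is a minimal sketch under the same assumptions as above. `run_and_capture_many` is a hypothetical helper, not the library's API.

```python
from collections import OrderedDict
from typing import List

import torch
import torch.nn as nn


def run_and_capture_many(model: nn.Module, layer_names: List[str], x: torch.Tensor):
    """Hypothetical helper: return (model(x), [output of each named layer, in the order given])."""
    modules = dict(model.named_modules())
    captured = {}
    handles = []
    for name in layer_names:
        # Bind `name` as a default argument; a plain closure would only see the last loop value.
        def hook(module, inputs, output, name=name):
            captured[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    try:
        y = model(x)
    finally:
        for handle in handles:
            handle.remove()
    # Order follows the request, not the order in which the layers ran.
    return y, [captured[name] for name in layer_names]


model = nn.Sequential(OrderedDict([
    ("layer1", nn.Linear(5, 3)),
    ("layer2", nn.Linear(3, 2)),
    ("layer3", nn.Linear(2, 1)),
]))
x = torch.rand(1, 5)
y, [x1, x2] = run_and_capture_many(model, ["layer1", "layer2"], x)
```

The default-argument binding is the key detail. Without it, every hook would write under the last name in `layer_names`.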

Or a dictionary to get named outputs:

model_wrapped = Inspect(model, layer={'x1': 'layer1', 'x2': 'layer2'})
x = torch.rand(1, 5)
y, layers = model_wrapped(x)
print(layers)
"""
{
    'x1': tensor([[ 0.3707,  0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
    'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""

TODO

  • add an extract function to get an intermediate block
Comments
  • Use one backbone with different heads

    Is it possible to save the results from the backbone and apply them to the heads of all the other models? My goal is to save time by not repeating the backbone. Instead of running the 3 complete models (left), I want to run the backbone only once and switch only the heads for the 3 models (right). That way the backbone of the yolov5 model is not executed every time.

    Thank you for the help!

    opened by brunopatricio2012
  • Support for DataParallel?

    Hi, I noticed that the current version does not support parallel models (at least those created using torch.nn.DataParallel) since the forward hook does not differentiate between the different copies of the model and a model wrapped with Inspect will just return the intermediate features of the last copy of the parallelized model to run.

    Are you planning on fixing this issue/supporting this use case?

    opened by zimmerrol
Latest release: 0.0.4
Owner: archinet.ai (AI Research Group)