Overview

A library to inspect intermediate layers of PyTorch models.

Why?

It's often the case that we want to inspect intermediate layers of a model without modifying its code, e.g. to visualize the attention matrices of language models, to get values from an intermediate layer to feed to another layer, or to apply a loss function to intermediate layers.

Install

$ pip install surgeon-pytorch


Usage

Inspect

Given a PyTorch model, we can list all of its layers using get_layers:

import torch
import torch.nn as nn

from surgeon_pytorch import Inspect, get_layers

class SomeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        y = self.layer3(x2)
        return y


model = SomeModel()
print(get_layers(model)) # ['layer1', 'layer2', 'layer3']
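If you are curious where layer names like these come from, here is a minimal sketch that uses only PyTorch's built-in named_modules(). It is an assumption that get_layers returns something comparable; list_layer_names is a hypothetical helper, not part of surgeon-pytorch.

```python
import torch.nn as nn

class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))

def list_layer_names(model):
    # Hypothetical helper: named_modules() yields ('', model) for the root
    # module first, so skip the empty name and keep the submodule names
    return [name for name, _ in model.named_modules() if name]

names = list_layer_names(SomeModel())
print(names)  # ['layer1', 'layer2', 'layer3']
```

For nested models, named_modules() also yields container modules and dotted child names (for example block.0), in registration order.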

Then we can wrap the model with Inspect; on every forward call, the wrapped model also returns the outputs of the requested layer(s) as a second return value:

model_wrapped = Inspect(model, layer='layer2')
x = torch.rand(1, 5)
y, x2 = model_wrapped(x)
print(x2) # tensor([[-0.2726,  0.0910]], grad_fn=<AddmmBackward0>)
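For intuition, capturing a single layer's output can be sketched with a plain PyTorch forward hook. This is not necessarily how Inspect is implemented (although the DataParallel comment further down mentions a forward hook); the Sequential model below is an assumption that simply mirrors SomeModel's three layers.

```python
import torch
import torch.nn as nn
from collections import OrderedDict

# Same layers as SomeModel, written as a named Sequential for brevity
model = nn.Sequential(OrderedDict([
    ('layer1', nn.Linear(5, 3)),
    ('layer2', nn.Linear(3, 2)),
    ('layer3', nn.Linear(2, 1)),
]))

captured = {}

def save_output(module, inputs, output):
    # Forward hooks receive (module, inputs, output); keep the output tensor
    captured['layer2'] = output

layer2 = dict(model.named_modules())['layer2']
handle = layer2.register_forward_hook(save_output)
x = torch.rand(1, 5)
y = model(x)
handle.remove()  # detach the hook so later forward calls are unaffected

x2 = captured['layer2']
print(x2.shape)  # torch.Size([1, 2])
```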

We can also provide a list of layers:

model_wrapped = Inspect(model, layer=['layer1', 'layer2'])
x = torch.rand(1, 5)
y, [x1, x2] = model_wrapped(x)
print(x1) # tensor([[ 0.1739,  0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print(x2) # tensor([[-0.2238,  0.0107]], grad_fn=<AddmmBackward0>)
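The printed outputs carry a grad_fn, so captured intermediate tensors stay in the autograd graph. That is what makes the "apply a loss function to intermediate layers" use case possible. A minimal sketch in plain PyTorch with a forward hook, assuming a hypothetical auxiliary target aux_target and an arbitrary weight of 0.1:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

torch.manual_seed(0)
model = nn.Sequential(OrderedDict([
    ('layer1', nn.Linear(5, 3)),
    ('layer2', nn.Linear(3, 2)),
    ('layer3', nn.Linear(2, 1)),
]))

captured = []
# The hook returns None (list.append does), so the layer output is unchanged
handle = model.layer2.register_forward_hook(
    lambda module, inputs, output: captured.append(output)
)

x = torch.rand(4, 5)
target = torch.rand(4, 1)
aux_target = torch.rand(4, 2)  # hypothetical target for the intermediate layer

y = model(x)
x2 = captured[-1]
# Main loss on the final output plus a weighted loss on the intermediate output
loss = F.mse_loss(y, target) + 0.1 * F.mse_loss(x2, aux_target)
loss.backward()
handle.remove()

print(model.layer1.weight.grad is not None)  # True
```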

Or a dictionary to get named outputs:

model_wrapped = Inspect(model, layer={'x1': 'layer1', 'x2': 'layer2'})
x = torch.rand(1, 5)
y, layers = model_wrapped(x)
print(layers)
"""
{
    'x1': tensor([[ 0.3707,  0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
    'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""
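The three calling conventions above (a single name, a list, a dictionary) can be mimicked with forward hooks in plain PyTorch. The sketch below is a hypothetical helper, run_with_captures, that only reproduces the return shapes shown in this README; it is not surgeon-pytorch's implementation.

```python
import torch
import torch.nn as nn
from collections import OrderedDict

def run_with_captures(model, x, layer):
    """Hypothetical helper: run model(x) and return (y, captured), where
    captured mirrors `layer`: str -> tensor, list -> list, dict -> dict."""
    if isinstance(layer, str):
        names = [layer]
    elif isinstance(layer, dict):
        names = list(layer.values())
    else:
        names = list(layer)

    modules = dict(model.named_modules())
    outputs = {}
    handles = []
    for name in names:
        # The default argument binds the current name inside each hook
        def hook(module, inputs, output, name=name):
            outputs[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    try:
        y = model(x)
    finally:
        # Always remove the hooks, even if the forward pass raises
        for h in handles:
            h.remove()

    if isinstance(layer, str):
        return y, outputs[layer]
    if isinstance(layer, dict):
        return y, {key: outputs[name] for key, name in layer.items()}
    return y, [outputs[name] for name in names]

model = nn.Sequential(OrderedDict([
    ('layer1', nn.Linear(5, 3)),
    ('layer2', nn.Linear(3, 2)),
    ('layer3', nn.Linear(2, 1)),
]))
x = torch.rand(1, 5)
y, layers = run_with_captures(model, x, {'x1': 'layer1', 'x2': 'layer2'})
print(sorted(layers))  # ['x1', 'x2']
```

Removing the hooks in a finally block keeps the model unchanged between calls.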

TODO

  • add extract function to get intermediate block
Comments
  • Use one backbone with different heads

    Is it possible to save the results from the backbone and apply them to the heads of all the other models? My goal is to save time by not repeating the backbone. Instead of running 3 complete models, I would run the backbone only once and switch only the heads for the 3 models, so the backbone isn't executed again every time in the YOLOv5 model.

    Thank you for the help!

    question 
    opened by brunopatricio2012 4
  • Support for DataParallel?

    Hi, I noticed that the current version does not support parallel models (at least those created with torch.nn.DataParallel): the forward hook does not differentiate between the different copies of the model, so a model wrapped with Inspect just returns the intermediate features of whichever copy of the parallelized model ran last.

    Are you planning on fixing this issue/supporting this use case?

    opened by zimmerrol 1
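The pattern asked about in the first comment (run a backbone once, then reuse its features in several heads) can be sketched in plain PyTorch. This is a generic illustration with hypothetical layer sizes, not YOLOv5 code and not a surgeon-pytorch feature. It is only valid when all heads were trained on the same backbone weights; if each model was trained separately, its backbone differs and another model's features would not match its head.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical shared backbone and three task-specific heads
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
heads = nn.ModuleList([nn.Linear(16, 4), nn.Linear(16, 2), nn.Linear(16, 1)])

x = torch.rand(3, 8)
with torch.no_grad():
    features = backbone(x)                         # backbone runs once
    outputs = [head(features) for head in heads]  # every head reuses the same features

print([tuple(o.shape) for o in outputs])  # [(3, 4), (3, 2), (3, 1)]
```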
Releases: 0.0.4
Owner: archinet.ai (AI Research Group)