Domain Generalization with MixStyle, ICLR'21.

Overview

MixStyle

This repo contains the code of our ICLR'21 paper, "Domain Generalization with MixStyle".

The OpenReview link is https://openreview.net/forum?id=6xHJ37MVxxp.

########## Updates ############

12-04-2021: A variable self._activated is added to MixStyle to better control the computational flow. To deactivate MixStyle without modifying the model code, one can do

def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)

model.apply(deactivate_mixstyle)

Similarly, to activate MixStyle, one can do

def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)

model.apply(activate_mixstyle)

Note that MixStyle has been included in Dassl.pytorch. See the code for details.

05-03-2021: You might also be interested in our recently released survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development in domain generalization, with coverage on the history, datasets, related problems, methodologies, potential directions, and so on.

##############################

A brief introduction: The key idea of MixStyle is to probablistically mix instance-level feature statistics of training samples across source domains. MixStyle improves model robustness to domain shift by implicitly synthesizing new domains at the feature level for regularizing the training of convolutional neural networks. This idea is largely inspired by neural style transfer which has shown that feature statistics are closely related to image style and therefore arbitrary image style transfer can be achieved by switching the feature statistics between a content and a style image.

MixStyle is very easy to implement. Below we show the PyTorch code of MixStyle.

import random
import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        self._activated = status

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix

How to apply MixStyle to your CNN models? Say you are using ResNet as the CNN architecture, and want to apply MixStyle after the 1st and 2nd residual blocks, you can first instantiate the MixStyle module using

self.mixstyle = MixStyle(p=0.5, alpha=0.1)

during network construction (in __init__()), and then apply MixStyle in the forward pass like

def forward(self, x):
    x = self.conv1(x) # 1st convolution layer
    x = self.res1(x) # 1st residual block
    x = self.mixstyle(x)
    x = self.res2(x) # 2nd residual block
    x = self.mixstyle(x)
    x = self.res3(x) # 3rd residual block
    x = self.res4(x) # 4th residual block
    ...

In our paper, we have demonstrated the effectiveness of MixStyle on three tasks: image classification, person re-identification, and reinforcement learning. The source code for reproducing all experiments can be found in mixstyle-release/imcls, mixstyle-release/reid, and mixstyle-release/rl, respectively.

Takeaways on applying MixStyle to your tasks:

  • Applying MixStyle to multiple lower layers is generally better
  • Do not apply MixStyle to the last layer that is the closest to the prediction layer
  • Different tasks might favor different combinations

For more analytical studies, please read our paper at https://openreview.net/forum?id=6xHJ37MVxxp.

To cite MixStyle in your publications, please use the following bibtex entry

@inproceedings{zhou2021mixstyle,
  title={Domain Generalization with MixStyle},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  booktitle={ICLR},
  year={2021}
}
Owner
Kaiyang
Researcher in computer vision and machine learning :)
Kaiyang
Differentiable Neural Computers, Sparse Access Memory and Sparse Differentiable Neural Computers, for Pytorch

Differentiable Neural Computers and family, for Pytorch Includes: Differentiable Neural Computers (DNC) Sparse Access Memory (SAM) Sparse Differentiab

ixaxaar 302 Dec 14, 2022
This repository contains code to train and render Mixture of Volumetric Primitives (MVP) models

Mixture of Volumetric Primitives -- Training and Evaluation This repository contains code to train and render Mixture of Volumetric Primitives (MVP) m

Meta Research 125 Dec 29, 2022
Adjusting for Autocorrelated Errors in Neural Networks for Time Series

Adjusting for Autocorrelated Errors in Neural Networks for Time Series This repository is the official implementation of the paper "Adjusting for Auto

Fan-Keng Sun 51 Nov 05, 2022
Datasets, Transforms and Models specific to Computer Vision

vision Datasets, Transforms and Models specific to Computer Vision Installation First install the nightly version of OneFlow python3 -m pip install on

OneFlow 68 Dec 07, 2022
TRIQ implementation

TRIQ Implementation TF-Keras implementation of TRIQ as described in Transformer for Image Quality Assessment. Installation Clone this repository. Inst

Junyong You 115 Dec 30, 2022
The dataset and source code for our paper: "Did You Ask a Good Question? A Cross-Domain Question IntentionClassification Benchmark for Text-to-SQL"

TriageSQL The dataset and source code for our paper: "Did You Ask a Good Question? A Cross-Domain Question Intention Classification Benchmark for Text

Yusen Zhang 22 Nov 09, 2022
AutoVideo: An Automated Video Action Recognition System

AutoVideo is a system for automated video analysis. It is developed based on D3M infrastructure, which describes machine learning with generic pipeline languages. Currently, it focuses on video actio

Data Analytics Lab at Texas A&M University 267 Dec 17, 2022
[ICCV 2021] Focal Frequency Loss for Image Reconstruction and Synthesis

Focal Frequency Loss - Official PyTorch Implementation This repository provides the official PyTorch implementation for the following paper: Focal Fre

Liming Jiang 460 Jan 04, 2023
PyArmadillo: an alternative approach to linear algebra in Python

PyArmadillo is a linear algebra library for the Python language, with an emphasis on ease of use.

Terry Zhuo 58 Oct 11, 2022
This repository attempts to replicate the SqueezeNet architecture and implement the same on an image classification task.

SqueezeNet-Implementation This repository attempts to replicate the SqueezeNet architecture using TensorFlow discussed in the research paper: "Squeeze

Rohan Mathur 3 Dec 13, 2022
A toolset of Python programs for signal modeling and indentification via sparse semilinear autoregressors.

SPAAR Description A toolset of Python programs for signal modeling via sparse semilinear autoregressors. References Vides, F. (2021). Computing Semili

Fredy Vides 0 Oct 30, 2021
Official implementation of EfficientPose

EfficientPose This is the official implementation of EfficientPose. We based our work on the Keras EfficientDet implementation xuannianz/EfficientDet

2 May 17, 2022
Official repository for CVPR21 paper "Deep Stable Learning for Out-Of-Distribution Generalization".

StableNet StableNet is a deep stable learning method for out-of-distribution generalization. This is the official repo for CVPR21 paper "Deep Stable L

120 Dec 28, 2022
Pytorch Implementation for Dilated Continuous Random Field

DilatedCRF Pytorch implementation for fully-learnable DilatedCRF. If you find my work helpful, please consider our paper: @article{Mo2022dilatedcrf,

DunnoCoding_Plus 3 Nov 13, 2022
G-NIA model from "Single Node Injection Attack against Graph Neural Networks" (CIKM 2021)

Single Node Injection Attack against Graph Neural Networks This repository is our Pytorch implementation of our paper: Single Node Injection Attack ag

Shuchang Tao 18 Nov 21, 2022
Official code for "Decoupling Zero-Shot Semantic Segmentation"

Decoupling Zero-Shot Semantic Segmentation This is the official code for the arxiv. ZegFormer is the first framework that decouple the zero-shot seman

Jian Ding 108 Dec 30, 2022
Adaptive Dropblock Enhanced GenerativeAdversarial Networks for Hyperspectral Image Classification

This repo holds the codes of our paper: Adaptive Dropblock Enhanced GenerativeAdversarial Networks for Hyperspectral Image Classification, which is ac

Feng Gao 17 Dec 28, 2022
Exporter for Storage Area Network (SAN)

SAN Exporter Prometheus exporter for Storage Area Network (SAN). We all know that each SAN Storage vendor has their own glossary of terms, health/perf

vCloud 32 Dec 16, 2022
Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021

SNN_Calibration Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021 Feature Comparison of SNN calibration: Features SNN Direct Tr

Yuhang Li 60 Dec 27, 2022
Deep learning model, heat map, data prepo

deep learning model, heat map, data prepo

Pamela Dekas 1 Jan 14, 2022