SmallInitEmb - LayerNorm(SmallInit(Embedding)) in a Transformer to improve convergence

Overview

SmallInitEmb

LayerNorm(SmallInit(Embedding)) in a Transformer

I find that when training a transformer, the embedding matrix moves slowly, hence it's difficult for the model to jump out of the initial noisy embedding.

(initial embedding)
[[-0.0073  0.0062 -0.0261 ...  0.0086  0.0107 -0.008 ] ... ]
 (after 1 step, the directions of the embedding vectors are not moved much because the numbers change by ~LR = ~4e-4)
[[-0.0069  0.0066 -0.0265 ...  0.009   0.0111 -0.0084] ... ]

So I propose initializing the embedding matrix to tiny values, and put another LayerNorm after it (before all the SA & FFN layers):

if isinstance(module, (nn.Embedding)):
    nn.init.uniform_(module.weight, a=-1e-4, b=1e-4) # SmallInit(Emb)
...
if self.config.USE_SMALL_EMB and self.layer_id == 0:
    x = self.lnPre(x) # LN(SmallInit(Emb))
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))

And then you get improved convergence (especially for BPE models) because the model can quickly jump out of the tiny initial embedding (small changes after 1 step -> significant changes of directions -> significant changes after LayerNorm).

Loss curve comparison: https://wandb.ai/blinkdl/SmallEmbTest

(the gap between LayerNorm(SmallEmb)) and baseline persists after more training)

Moreover, you can directly train PostLN models without warmup with SmallInit(Emb)

if isinstance(module, (nn.Embedding)):
    nn.init.uniform_(module.weight, a=-1e-4, b=1e-4) # SmallInit(Emb)
...
x = self.ln1(x) # this plays the same role as the lnPre in the above PreLN code
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
(note you shall have another LN after the final ffn)
Owner
PENG Bo
http://zhihu.com/people/bopengbopeng
PENG Bo
The implementation of "Optimizing Shoulder to Shoulder: A Coordinated Sub-Band Fusion Model for Real-Time Full-Band Speech Enhancement"

SF-Net for fullband SE This is the repo of the manuscript "Optimizing Shoulder to Shoulder: A Coordinated Sub-Band Fusion Model for Real-Time Full-Ban

Guochen Yu 36 Dec 02, 2022
QT Py Media Knob using rotary encoder & neopixel ring

QTPy-Knob QT Py USB Media Knob using rotary encoder & neopixel ring The QTPy-Knob features: Media knob for volume up/down/mute with "qtpy-knob.py" Cir

Tod E. Kurt 56 Dec 30, 2022
Generative Art Using Neural Visual Grammars and Dual Encoders

Generative Art Using Neural Visual Grammars and Dual Encoders Arnheim 1 The original algorithm from the paper Generative Art Using Neural Visual Gramm

DeepMind 231 Jan 05, 2023
ESTDepth: Multi-view Depth Estimation using Epipolar Spatio-Temporal Networks (CVPR 2021)

ESTDepth: Multi-view Depth Estimation using Epipolar Spatio-Temporal Networks (CVPR 2021) Project Page | Video | Paper | Data We present a novel metho

65 Nov 28, 2022
Alternatives to Deep Neural Networks for Function Approximations in Finance

Alternatives to Deep Neural Networks for Function Approximations in Finance Code companion repo Overview This is a repository of Python code to go wit

15 Dec 17, 2022
Python-experiments - A Repository which contains python scripts to automate things and make your life easier with python

Python Experiments A Repository which contains python scripts to automate things

Vivek Kumar Singh 11 Sep 25, 2022
Pytorch implementation of Hinton's Dynamic Routing Between Capsules

pytorch-capsule A Pytorch implementation of Hinton's "Dynamic Routing Between Capsules". https://arxiv.org/pdf/1710.09829.pdf Thanks to @naturomics fo

Tim Omernick 625 Oct 27, 2022
Explainability for Vision Transformers (in PyTorch)

Explainability for Vision Transformers (in PyTorch) This repository implements methods for explainability in Vision Transformers

Jacob Gildenblat 442 Jan 04, 2023
Pretrained Pytorch face detection (MTCNN) and recognition (InceptionResnet) models

Face Recognition Using Pytorch Python 3.7 3.6 3.5 Status This is a repository for Inception Resnet (V1) models in pytorch, pretrained on VGGFace2 and

Tim Esler 3.3k Jan 04, 2023
Learn about quantum computing and algorithm on quantum computing

quantum_computing this repo contains everything i learn about quantum computing and algorithm on quantum computing what is aquantum computing quantum

arfy slowy 8 Dec 25, 2022
Segmentation and Identification of Vertebrae in CT Scans using CNN, k-means Clustering and k-NN

Segmentation and Identification of Vertebrae in CT Scans using CNN, k-means Clustering and k-NN If you use this code for your research, please cite ou

41 Dec 08, 2022
A NSFW content filter.

Project_Nfilter A NSFW content filter. With a motive of minimizing the spreads and leakage of NSFW contents on internet and access to others devices ,

1 Jan 20, 2022
This is the pytorch code for the paper Curious Representation Learning for Embodied Intelligence.

Curious Representation Learning for Embodied Intelligence This is the pytorch code for the paper Curious Representation Learning for Embodied Intellig

19 Oct 19, 2022
[CVPR 2020] Interpreting the Latent Space of GANs for Semantic Face Editing

InterFaceGAN - Interpreting the Latent Space of GANs for Semantic Face Editing Figure: High-quality facial attributes editing results with InterFaceGA

GenForce: May Generative Force Be with You 1.3k Dec 29, 2022
TensorFlow implementation for Bayesian Modeling and Uncertainty Quantification for Learning to Optimize: What, Why, and How

Bayesian Modeling and Uncertainty Quantification for Learning to Optimize: What, Why, and How TensorFlow implementation for Bayesian Modeling and Unce

Shen Lab at Texas A&M University 8 Sep 02, 2022
Official repository of the paper 'Essentials for Class Incremental Learning'

Essentials for Class Incremental Learning Official repository of the paper 'Essentials for Class Incremental Learning' This Pytorch repository contain

33 Nov 27, 2022
Learning based AI for playing multi-round Koi-Koi hanafuda card games. Have fun.

Koi-Koi AI Learning based AI for playing multi-round Koi-Koi hanafuda card games. Platform Python PyTorch PySimpleGUI (for the interface playing vs AI

Sanghai Guan 10 Nov 20, 2022
Code for CVPR 2018 paper --- Texture Mapping for 3D Reconstruction with RGB-D Sensor

G2LTex This repository contains the implementation of "Texture Mapping for 3D Reconstruction with RGB-D Sensor (CVPR2018)" based on mvs-texturing. Due

Fu Yanping(付燕平) 129 Dec 30, 2022
Modelisation on galaxy evolution using PEGASE-HR

model_galaxy Modelisation on galaxy evolution using PEGASE-HR This is a labwork done in internship at IAP directed by Damien Le Borgne (https://github

Adrien Anthore 1 Jan 14, 2022