SmallInitEmb - LayerNorm(SmallInit(Embedding)) in a Transformer to improve convergence

I find that when training a Transformer, the embedding matrix moves slowly, so it is hard for the model to escape the initial noisy embedding.

(initial embedding)
[[-0.0073  0.0062 -0.0261 ...  0.0086  0.0107 -0.008 ] ... ]
(after 1 step, the directions of the embedding vectors barely change, because each number changes by only ~LR = ~4e-4)
[[-0.0069  0.0066 -0.0265 ...  0.009   0.0111 -0.0084] ... ]
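To see why the numbers above barely move, here is a stdlib-only sketch. It assumes the optimizer is Adam (an assumption: the text only says each number changes by ~LR). On its first step, Adam's bias-corrected update is lr * g / (|g| + eps), so every entry moves by almost exactly lr whatever the gradient's size. The width of 768, the N(0, 0.02) initialization and the random gradient are hypothetical stand-ins.

```python
import math
import random


def adam_first_step(w, g, lr=4e-4, eps=1e-8):
    """One Adam step from zero state (assumed optimizer).

    On step 1, m_hat = g and v_hat = g**2 after bias correction, so the
    update is lr * g / (|g| + eps), i.e. about +-lr per entry.
    """
    return [wi - lr * gi / (abs(gi) + eps) for wi, gi in zip(w, g)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


rng = random.Random(0)
dim = 768                                         # hypothetical embedding size
w0 = [rng.gauss(0.0, 0.02) for _ in range(dim)]   # hypothetical normal init, std 0.02
grad = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # random stand-in gradient
w1 = adam_first_step(w0, grad)

max_change = max(abs(a - b) for a, b in zip(w0, w1))
print(f"max |change| per entry: {max_change:.1e}")  # about lr = 4e-4
print(f"cosine(before, after): {cosine(w0, w1):.5f}")  # very close to 1: direction barely moves
```

The per-entry change is fixed at about lr, while the vector's entries are tens of times larger, so its direction is almost unchanged after one step.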

So I propose initializing the embedding matrix to tiny values and putting another LayerNorm right after it (before all the SA and FFN layers):

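import torch.nn as nn

# in the model's weight-init function (applied to every submodule):
if isinstance(module, nn.Embedding):
    nn.init.uniform_(module.weight, a=-1e-4, b=1e-4) # SmallInit(Emb)
...
# in Block.forward; lnPre is the extra nn.LayerNorm, applied only in the first block
if self.config.USE_SMALL_EMB and self.layer_id == 0:
    x = self.lnPre(x) # LN(SmallInit(Emb))
x = x + self.att(self.ln1(x)) # self-attention (SA) sublayer
x = x + self.ffn(self.ln2(x)) # feed-forward (FFN) sublayer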
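For readers who want the fragments above in context, here is a minimal PyTorch sketch assuming a GPT-style causal language model. The class names `Block` and `TinyLM`, the `use_small_emb` flag (standing in for `config.USE_SMALL_EMB`), the sizes, and the use of `nn.MultiheadAttention` and a GELU MLP as the SA and FFN layers are all hypothetical choices, not the original model.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """PreLN block; block 0 also owns lnPre, the LN applied to SmallInit(Emb)."""

    def __init__(self, n_embd, n_head, layer_id, use_small_emb=True):
        super().__init__()
        self.layer_id = layer_id
        self.use_small_emb = use_small_emb
        if use_small_emb and layer_id == 0:
            self.lnPre = nn.LayerNorm(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # stand-in SA layer (hypothetical choice)
        self.att = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, x, causal_mask):
        if self.use_small_emb and self.layer_id == 0:
            x = self.lnPre(x)  # LN(SmallInit(Emb))
        h = self.ln1(x)
        x = x + self.att(h, h, h, attn_mask=causal_mask, need_weights=False)[0]
        x = x + self.ffn(self.ln2(x))
        return x


class TinyLM(nn.Module):
    def __init__(self, vocab_size, n_embd=64, n_head=4, n_layer=2, use_small_emb=True):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, i, use_small_emb) for i in range(n_layer)]
        )
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        if use_small_emb:
            self.apply(self._init_weights)  # visits every submodule

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Embedding):
            nn.init.uniform_(module.weight, a=-1e-4, b=1e-4)  # SmallInit(Emb)

    def forward(self, idx):
        T = idx.size(1)
        # True above the diagonal: a position may not attend to future tokens
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1)
        x = self.emb(idx)
        for block in self.blocks:
            x = block(x, mask)
        return self.head(self.ln_out(x))


model = TinyLM(vocab_size=100)
logits = model(torch.randint(0, 100, (2, 8)))
print(logits.shape)  # torch.Size([2, 8, 100])
```

Only the first block carries `lnPre`; the later blocks see an already-normalized residual stream, exactly as in the `layer_id == 0` check.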

With LN(SmallInit(Emb)), you get improved convergence (especially for BPE models), because the model can quickly move away from the tiny initial embedding: small changes after 1 step -> significant changes in direction -> significant changes after LayerNorm.
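The chain "small changes -> significant changes in direction -> significant changes after LayerNorm" can be checked numerically with a stdlib-only sketch. It assumes an Adam-like first step in which every entry moves by exactly lr against the gradient sign, and a LayerNorm without its learnable gain and bias (eps = 1e-5, PyTorch's default). The width, the N(0, 0.02) baseline init and the random gradient are hypothetical.

```python
import math
import random


def layer_norm(v, eps=1e-5):
    """LayerNorm without the learnable affine part (eps = PyTorch's default)."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def ln_cosine_after_one_step(w0, grad, lr=4e-4):
    """Cosine similarity of LN outputs before and after one +-lr step per entry."""
    w1 = [w - lr * math.copysign(1.0, g) for w, g in zip(w0, grad)]  # Adam-like step (assumed)
    return cosine(layer_norm(w0), layer_norm(w1))


rng = random.Random(0)
dim = 768                                               # hypothetical embedding size
grad = [rng.gauss(0.0, 1.0) for _ in range(dim)]        # random stand-in gradient
baseline = [rng.gauss(0.0, 0.02) for _ in range(dim)]   # hypothetical normal init
small = [rng.uniform(-1e-4, 1e-4) for _ in range(dim)]  # SmallInit(Emb)

print(f"baseline init:  {ln_cosine_after_one_step(baseline, grad):.4f}")  # near 1
print(f"SmallInit(Emb): {ln_cosine_after_one_step(small, grad):.4f}")     # far below 1
```

With SmallInit the same ~LR step is several times larger than the entries themselves, so the vector (and therefore what LayerNorm passes on) points somewhere new after a single update.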

Loss curve comparison: https://wandb.ai/blinkdl/SmallEmbTest

(the gap between LayerNorm(SmallEmb) and the baseline persists after more training)

Moreover, with SmallInit(Emb) you can train PostLN models directly, without warmup:

import torch.nn as nn

# weight init, same as above
if isinstance(module, nn.Embedding):
    nn.init.uniform_(module.weight, a=-1e-4, b=1e-4) # SmallInit(Emb)
...
# in Block.forward (PostLN)
x = self.ln1(x) # this plays the same role as the lnPre in the above PreLN code
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
(note: you need another LN after the final FFN, i.e. after the last block)
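The note about the final LN can be checked with a stdlib-only sketch. It runs the PostLN ordering from the code (ln1 -> +att -> ln2 -> +ffn per block, then one final LN) next to the textbook PostLN form x = LN(x + sublayer(x)) applied to LN(embedding), and confirms both perform the same operations. The tanh sublayers are hypothetical stand-ins for attention and FFN, and the LayerNorms have no learnable parameters; with learnable LNs the equivalence holds after reassigning each LN's gain and bias to its counterpart.

```python
import math


def layer_norm(v, eps=1e-5):
    """Parameter-free LayerNorm."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def make_sublayer(scale, shift):
    """Hypothetical deterministic stand-in for an attention or FFN sublayer."""
    return lambda v: [math.tanh(scale * x + shift) for x in v]


# (att, ffn) pairs for two hypothetical blocks
blocks = [(make_sublayer(0.7, 0.1), make_sublayer(1.3, -0.2)),
          (make_sublayer(0.5, 0.3), make_sublayer(0.9, 0.05))]


def code_order(x):
    """Order used in the code: ln1 -> +att -> ln2 -> +ffn, then a final LN."""
    for att, ffn in blocks:
        x = layer_norm(x)
        x = add(x, att(x))
        x = layer_norm(x)
        x = add(x, ffn(x))
    return layer_norm(x)  # the extra LN after the final FFN


def textbook_post_ln(x):
    """Standard PostLN, x = LN(x + sublayer(x)), fed with LN(embedding)."""
    x = layer_norm(x)  # plays the role of lnPre
    for att, ffn in blocks:
        x = layer_norm(add(x, att(x)))
        x = layer_norm(add(x, ffn(x)))
    return x


emb = [1e-4 * math.sin(i) for i in range(16)]  # tiny, SmallInit-scale embedding row
a, b = code_order(emb), textbook_post_ln(emb)
print(max(abs(p - q) for p, q in zip(a, b)))  # 0.0: the same operations in the same order
```

Without the final LN, `code_order` would return the last block's output before its closing normalization, so it would no longer match PostLN.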

Owner: PENG Bo (http://zhihu.com/people/bopengbopeng)