CNNs for Sentence Classification in PyTorch

Overview

Introduction

This is the implementation of Kim's Convolutional Neural Networks for Sentence Classification paper in PyTorch.

  1. Kim's implementation of the model in Theano: https://github.com/yoonkim/CNN_sentence
  2. Denny Britz has an implementation in Tensorflow: https://github.com/dennybritz/cnn-text-classification-tf
  3. Alexander Rakhlin's implementation in Keras; https://github.com/alexander-rakhlin/CNN-for-Sentence-Classification-in-Keras

Requirement

  • python 3
  • pytorch > 0.1
  • torchtext > 0.1
  • numpy

Result

I just tried two dataset, MR and SST.

Dataset Class Size Best Result Kim's Paper Result
MR 2 77.5%(CNN-rand-static) 76.1%(CNN-rand-nostatic)
SST 5 37.2%(CNN-rand-static) 45.0%(CNN-rand-nostatic)

I haven't adjusted the hyper-parameters for SST seriously.

Usage

./main.py -h

or

python3 main.py -h

You will get:

CNN text classificer

optional arguments:
  -h, --help            show this help message and exit
  -batch-size N         batch size for training [default: 50]
  -lr LR                initial learning rate [default: 0.01]
  -epochs N             number of epochs for train [default: 10]
  -dropout              the probability for dropout [default: 0.5]
  -max_norm MAX_NORM    l2 constraint of parameters
  -cpu                  disable the gpu
  -device DEVICE        device to use for iterate data
  -embed-dim EMBED_DIM
  -static               fix the embedding
  -kernel-sizes KERNEL_SIZES
                        Comma-separated kernel size to use for convolution
  -kernel-num KERNEL_NUM
                        number of each kind of kernel
  -class-num CLASS_NUM  number of class
  -shuffle              shuffle the data every epoch
  -num-workers NUM_WORKERS
                        how many subprocesses to use for data loading
                        [default: 0]
  -log-interval LOG_INTERVAL
                        how many batches to wait before logging training
                        status
  -test-interval TEST_INTERVAL
                        how many epochs to wait before testing
  -save-interval SAVE_INTERVAL
                        how many epochs to wait before saving
  -predict PREDICT      predict the sentence given
  -snapshot SNAPSHOT    filename of model snapshot [default: None]
  -save-dir SAVE_DIR    where to save the checkpoint

Train

./main.py

You will get:

Batch[100] - loss: 0.655424  acc: 59.3750%
Evaluation - loss: 0.672396  acc: 57.6923%(615/1066) 

Test

If you has construct you test set, you make testing like:

/main.py -test -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt

The snapshot option means where your model load from. If you don't assign it, the model will start from scratch.

Predict

  • Example1

     ./main.py -predict="Hello my dear , I love you so much ." \
               -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt" 
    

    You will get:

     Loading model from [./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt]...
     
     [Text]  Hello my dear , I love you so much .
     [Label] positive
    
  • Example2

     ./main.py -predict="You just make me so sad and I have to leave you ."\
               -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt" 
    

    You will get:

     Loading model from [./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt]...
     
     [Text]  You just make me so sad and I have to leave you .
     [Label] negative
    

Your text must be separated by space, even punctuation.And, your text should longer then the max kernel size.

Reference

Owner
Shawn Ng
Now, I focus on the Natural Language Processing, such as QA
Shawn Ng
A PyTorch implementation of "SimGNN: A Neural Network Approach to Fast Graph Similarity Computation" (WSDM 2019).

SimGNN ⠀⠀⠀ A PyTorch implementation of SimGNN: A Neural Network Approach to Fast Graph Similarity Computation (WSDM 2019). Abstract Graph similarity s

Benedek Rozemberczki 534 Dec 25, 2022
Computer Vision Paper Reviews with Key Summary of paper, End to End Code Practice and Jupyter Notebook converted papers

Computer-Vision-Paper-Reviews Computer Vision Paper Reviews with Key Summary along Papers & Codes. Jonathan Choi 2021 The repository provides 100+ Pap

Jonathan Choi 2 Mar 17, 2022
PyTorch code for Composing Partial Differential Equations with Physics-Aware Neural Networks

FInite volume Neural Network (FINN) This repository contains the PyTorch code for models, training, and testing, and Python code for data generation t

Cognitive Modeling 20 Dec 18, 2022
Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth [Paper]

Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth [Paper] Downloads [Downloads] Trained ckpt files for NYU Depth V2 and

98 Jan 01, 2023
Official implementation of Rethinking Graph Neural Architecture Search from Message-passing (CVPR2021)

Rethinking Graph Neural Architecture Search from Message-passing Intro The GNAS can automatically learn better architecture with the optimal depth of

Shaofei Cai 48 Sep 30, 2022
A simple Python configuration file operator.

A simple Python configuration file operator This project provides a common way to read configurations using config42. Installation It is possible to i

Scott Lau 2 Nov 08, 2021
Tool which allow you to detect and translate text.

Text detection and recognition This repository contains tool which allow to detect region with text and translate it one by one. Description Two pretr

Damian Panek 176 Nov 28, 2022
Baseline model for "GraspNet-1Billion: A Large-Scale Benchmark for General Object Grasping" (CVPR 2020)

GraspNet Baseline Baseline model for "GraspNet-1Billion: A Large-Scale Benchmark for General Object Grasping" (CVPR 2020). [paper] [dataset] [API] [do

GraspNet 209 Dec 29, 2022
[ICCV 2021] Encoder-decoder with Multi-level Attention for 3D Human Shape and Pose Estimation

MAED: Encoder-decoder with Multi-level Attention for 3D Human Shape and Pose Estimation Getting Started Our codes are implemented and tested with pyth

ZiNiU WaN 176 Dec 15, 2022
DCSL - Generalizable Crowd Counting via Diverse Context Style Learning

DCSL Generalizable Crowd Counting via Diverse Context Style Learning Requirement

3 Jun 13, 2022
DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference

DeeBERT This is the code base for the paper DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference. Code in this repository is also available

Castorini 132 Nov 14, 2022
Using machine learning to predict undergrad college admissions.

College-Prediction Project- Overview: Many have tried, many have failed. Few trailblazers are ambitious enought to chase acceptance into the top 15 un

John H Klinges 1 Jan 05, 2022
Shōgun

The SHOGUN machine learning toolbox Unified and efficient Machine Learning since 1999. Latest release: Cite Shogun: Develop branch build status: Donat

Shōgun ML 2.9k Jan 04, 2023
Stochastic Scene-Aware Motion Prediction

Stochastic Scene-Aware Motion Prediction [Project Page] [Paper] Description This repository contains the training code for MotionNet and GoalNet of SA

Mohamed Hassan 31 Dec 09, 2022
Federated learning on graph, especially on graph neural networks (GNNs), knowledge graph, and private GNN.

Federated learning on graph, especially on graph neural networks (GNNs), knowledge graph, and private GNN.

keven 198 Dec 20, 2022
Official repo for SemanticGAN https://nv-tlabs.github.io/semanticGAN/

SemanticGAN This is the official code for: Semantic Segmentation with Generative Models: Semi-Supervised Learning and Strong Out-of-Domain Generalizat

151 Dec 28, 2022
This repository contains the entire code for our work "Two-Timescale End-to-End Learning for Channel Acquisition and Hybrid Precoding"

Two-Timescale-DNN Two-Timescale End-to-End Learning for Channel Acquisition and Hybrid Precoding This repository contains the entire code for our work

QiyuHu 3 Mar 07, 2022
Anime Face Detector using mmdet and mmpose

Anime Face Detector This is an anime face detector using mmdetection and mmpose. (To avoid copyright issues, I use generated images by the TADNE model

198 Jan 07, 2023
K-FACE Analysis Project on Pytorch

Installation Setup with Conda # create a new environment conda create --name insightKface python=3.7 # or over conda activate insightKface #install t

Jung Jun Uk 7 Nov 10, 2022
InsTrim: Lightweight Instrumentation for Coverage-guided Fuzzing

InsTrim The paper: InsTrim: Lightweight Instrumentation for Coverage-guided Fuzzing Build Prerequisite llvm-8.0-dev clang-8.0 cmake = 3.2 Make git cl

75 Dec 23, 2022