Simple embedding-based text classifier inspired by fastText, implemented in TensorFlow

Overview

FastText in Tensorflow

This project is based on the ideas in Facebook's FastText but implemented in Tensorflow. However, it is not an exact replica of fastText.

Classification is done by embedding each word, taking the mean embedding over the full text, and classifying that mean with a linear classifier. The embedding is trained jointly with the classifier. You can also use character ngrams of length 2 or more. These ngrams are hashed and then embedded in the same way as the original words. Note that ngrams make training much slower but give only marginal improvements in performance, at least in English.
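
The forward pass described above can be sketched in plain Python. This is a minimal illustration, not the project's TensorFlow code: the function names, the toy embedding table, and the weights are all hypothetical and chosen by hand.

```python
import math

def mean_embedding(token_ids, embeddings):
    """Average the embedding vectors of the given token ids."""
    dim = len(embeddings[0])
    total = [0.0] * dim
    for tid in token_ids:
        for i, value in enumerate(embeddings[tid]):
            total[i] += value
    count = max(len(token_ids), 1)  # avoid dividing by zero on empty text
    return [value / count for value in total]

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    norm = sum(exps)
    return [e / norm for e in exps]

def classify(token_ids, embeddings, weights, biases):
    """Linear classifier on the mean embedding; returns class probabilities."""
    mean = mean_embedding(token_ids, embeddings)
    logits = [
        sum(w * x for w, x in zip(row, mean)) + b
        for row, b in zip(weights, biases)
    ]
    return softmax(logits)

# Toy values (assumptions for illustration only).
EMBEDDINGS = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, 2 dimensions
WEIGHTS = [[1.0, -1.0], [-1.0, 1.0]]               # 2 classes
BIASES = [0.0, 0.0]

probs = classify([0, 2], EMBEDDINGS, WEIGHTS, BIASES)
```

In training, both the embedding table and the classifier weights would be updated by gradient descent on the classification loss, which is what "trained with the classifier" means here.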

I may implement skipgram and cbow training later, or preloading of embedding tables.

<< Still WIP >>

You can use Horovod to distribute training across multiple GPUs, on one or multiple servers. See usage section below.

FastText Language Identification

I have added utilities to train a classifier to detect languages, as described in Fast and Accurate Language Identification using FastText.

See the Language Identification usage section below. It works in the same way as the default usage.

Implemented:

  • classification of text using word embeddings
  • char ngrams, hashed to n bins
  • training and prediction program
  • serve models on tensorflow serving
  • preprocess facebook format, or text input into tensorflow records
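
As a rough illustration of the "char ngrams, hashed to n bins" item, the sketch below extracts character ngrams from a word and maps each one to a bin id. The helper names are hypothetical, and CRC32 is an assumed choice of hash; the project may hash differently.

```python
import zlib

def char_ngrams(word, sizes=(2, 3, 4)):
    """Return all character ngrams of the requested sizes, in order."""
    grams = []
    for n in sizes:
        for start in range(len(word) - n + 1):
            grams.append(word[start:start + n])
    return grams

def hash_to_bin(gram, num_bins):
    """Map an ngram to a bin id with a stable hash (CRC32 is an assumption)."""
    return zlib.crc32(gram.encode("utf-8")) % num_bins

def ngram_ids(word, num_bins, sizes=(2, 3, 4)):
    """Bin ids for every character ngram of the word."""
    return [hash_to_bin(g, num_bins) for g in char_ngrams(word, sizes)]

ids = ngram_ids("text", num_bins=1000)
```

A stable hash is used rather than Python's built-in hash(), because the built-in is randomized per process for strings and would give different bin ids between training and prediction.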

Not Implemented:

  • separate word vector training (though can export embeddings)
  • hierarchical softmax
  • quantize models (supported by tensorflow, but I haven't tried it yet)

Usage

The following are examples of how to use the applications. Run any of the programs with the --help option for full help.

To transform input data into tensorflow Example format:

process_input.py --facebook_input=queries.txt --output_dir=. --ngrams=2,3,4

Or, using a text file with one example per line with an extra file for labels:

process_input.py --text_input=queries.txt --labels=labels.txt --output_dir=.
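
For reference, here is a minimal sketch of reading one line of the Facebook format. It assumes that "facebook format" means fastText's supervised format, where each label is a token prefixed with __label__ at the start of the line; parse_fasttext_line is a hypothetical helper, not a function from process_input.py.

```python
LABEL_PREFIX = "__label__"

def parse_fasttext_line(line):
    """Split a line into (labels, text), assuming labels lead the line."""
    labels, words = [], []
    for token in line.split():
        # Only treat prefixed tokens as labels while no text has been seen yet.
        if token.startswith(LABEL_PREFIX) and not words:
            labels.append(token[len(LABEL_PREFIX):])
        else:
            words.append(token)
    return labels, " ".join(words)

labels, text = parse_fasttext_line("__label__greeting hello there")
```

With the two-file input, the same pair would instead come from matching line numbers in the text file and the labels file.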

To train a text classifier:

classifier.py \
  --train_records=queries.tfrecords \
  --eval_records=queries.tfrecords \
  --label_file=labels.txt \
  --vocab_file=vocab.txt \
  --model_dir=model \
  --export_dir=model

To predict classifications for text, use a saved model produced by classifier.py. The --export_dir option of classifier.py stores a saved model in a numbered directory below export_dir. Pass that numbered directory to predictor.py to use the model for predictions:

predictor.py \
  --saved_model=model/12345678 \
  --text="some text to classify" \
  --signature_def=proba
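
Because the numbered directory name changes with each export, a small helper can pick the most recent one. This is a sketch under the assumption that a larger number means a newer export; latest_saved_model is a hypothetical helper and is not part of the project.

```python
import os

def latest_saved_model(export_dir):
    """Return the path of the highest-numbered subdirectory, or None."""
    numbered = [
        name for name in os.listdir(export_dir)
        if name.isdigit() and os.path.isdir(os.path.join(export_dir, name))
    ]
    if not numbered:
        return None
    # Compare as integers so that "100" sorts after "20".
    return os.path.join(export_dir, max(numbered, key=int))
```

The returned path can then be passed as the --saved_model argument.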

To export the embedding layer, run predictor.py with the embedding signature. Note that this will only be the text embedding, not the ngram embeddings.

predictor.py \
  --saved_model=model/12345678 \
  --text="some text to classify" \
  --signature_def=embedding

Use the provided script to train easily:

train_classifier.sh path-to-data-directory

Language Identification

To implement something similar to the method described in Fast and Accurate Language Identification using FastText, first download the data:

lang_dataset.sh [datadir]

You can then process the training and validation data using process_input.py and classifier.py as described above.

There is a utility script to do this for you:

train_langdetect.sh datadir

It reaches about 96% accuracy using word embeddings, and this increases to nearly 99% when adding --ngrams=2,3,4.

Distributed Training

You can run training across multiple GPUs on one or more servers. To do so, install MPI and Horovod, then add the --horovod option. Training speed scales almost linearly with the number of GPUs: with 2 GPUs on your server, it should run at close to 2x the speed.

NUM_GPUS=2
mpirun -np $NUM_GPUS python classifier.py \
  --horovod \
  --train_records=queries.tfrecords \
  --eval_records=queries.tfrecords \
  --label_file=labels.txt \
  --vocab_file=vocab.txt \
  --model_dir=model \
  --export_dir=model

The training script train_classifier.sh also supports this option.

Tensorflow Serving

Besides running a saved model with predictor.py, you can serve it with Tensorflow Serving in a client-server setup. A simple RPC client (predictor_client.py) is supplied that gets predictions from the Tensorflow Serving server.

First, install the Tensorflow Serving binaries. Instructions are in the Tensorflow Serving documentation.

Then serve the latest saved model by supplying the base export directory that you exported saved models to. This directory contains the numbered model directories:

tensorflow_model_server --port=9000 --model_base_path=model

Now you can make requests to the server using gRPC calls. An example simple client is provided in predictor_client.py:

predictor_client.py --text="Some text to classify"

Facebook Examples

<< NOT IMPLEMENTED YET >>

You can compare with Facebook's fastText by running similar examples to what's provided in their repository.

./classification_example.sh
./classification_results.sh
Owner
Alan Patterson