End-To-End Memory Network using Tensorflow

Overview

MemN2N

Implementation of End-To-End Memory Networks with sklearn-like interface using Tensorflow. Tasks are from the bAbl dataset.

MemN2N picture

Get Started

git clone [email protected]:domluna/memn2n.git

mkdir ./memn2n/data/
cd ./memn2n/data/
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz

cd ../
python single.py

Examples

Running a single bAbI task

Running a joint model on all bAbI tasks

These files are also a good example of usage.

Requirements

  • tensorflow 1.0
  • scikit-learn 0.17.1
  • six 1.10.0

Single Task Results

For a task to pass it has to meet 95%+ testing accuracy. Measured on single tasks on the 1k data.

Pass: 1,4,12,15,20

Several other tasks have 80%+ testing accuracy.

Stochastic gradient descent optimizer was used with an annealed learning rate schedule as specified in Section 4.2 of End-To-End Memory Networks

The following params were used:

  • epochs: 100
  • hops: 3
  • embedding_size: 20
Task Training Accuracy Validation Accuracy Testing Accuracy
1 1.0 1.0 1.0
2 1.0 0.86 0.83
3 1.0 0.64 0.54
4 1.0 0.99 0.98
5 1.0 0.94 0.87
6 1.0 0.97 0.92
7 1.0 0.89 0.84
8 1.0 0.93 0.86
9 1.0 0.86 0.90
10 1.0 0.80 0.78
11 1.0 0.92 0.84
12 1.0 1.0 1.0
13 0.99 0.94 0.90
14 1.0 0.97 0.93
15 1.0 1.0 1.0
16 0.81 0.47 0.44
17 0.76 0.65 0.52
18 0.97 0.96 0.88
19 0.40 0.17 0.13
20 1.0 1.0 1.0

Joint Training Results

Pass: 1,6,9,10,12,13,15,20

Again stochastic gradient descent optimizer was used with an annealed learning rate schedule as specified in Section 4.2 of End-To-End Memory Networks

The following params were used:

  • epochs: 60
  • hops: 3
  • embedding_size: 40
Task Training Accuracy Validation Accuracy Testing Accuracy
1 1.0 0.99 0.999
2 1.0 0.84 0.849
3 0.99 0.72 0.715
4 0.96 0.86 0.851
5 1.0 0.92 0.865
6 1.0 0.97 0.964
7 0.96 0.87 0.851
8 0.99 0.89 0.898
9 0.99 0.96 0.96
10 1.0 0.96 0.928
11 1.0 0.98 0.93
12 1.0 0.98 0.982
13 0.99 0.98 0.976
14 1.0 0.81 0.877
15 1.0 1.0 0.983
16 0.64 0.45 0.44
17 0.77 0.64 0.547
18 0.85 0.71 0.586
19 0.24 0.07 0.104
20 1.0 1.0 0.996

Notes

Single task results are from 10 repeated trails of the single task model accross all 20 tasks with different random initializations. The performance of the model with the lowest validation accuracy for each task is shown in the table above.

Joint training results are from 10 repeated trails of the joint model accross all tasks. The performance of the single model whose validation accuracy passed the most tasks (>= 0.95) is shown in the table above (joint_scores_run2.csv). The scores from all 10 runs are located in the results/ directory.

Owner
Dominique Luna
magnificent stallion
Dominique Luna
Numerical-computing-is-fun - Learning numerical computing with notebooks for all ages.

As much as this series is to educate aspiring computer programmers and data scientists of all ages and all backgrounds, it is also a reminder to mysel

EKA foundation 758 Dec 25, 2022
A video scene detection algorithm is designed to detect a variety of different scenes within a video

Scene-Change-Detection - A video scene detection algorithm is designed to detect a variety of different scenes within a video. There is a very simple definition for a scene: It is a series of logical

1 Jan 04, 2022
BABEL: Bodies, Action and Behavior with English Labels [CVPR 2021]

BABEL is a large dataset with language labels describing the actions being performed in mocap sequences. BABEL labels about 43 hours of mocap sequences from AMASS [1] with action labels.

113 Dec 28, 2022
PyTorch-centric library for evaluating and enhancing the robustness of AI technologies

Responsible AI Toolbox A library that provides high-quality, PyTorch-centric tools for evaluating and enhancing both the robustness and the explainabi

24 Dec 22, 2022
PyTorch implementations of algorithms for density estimation

pytorch-flows A PyTorch implementations of Masked Autoregressive Flow and some other invertible transformations from Glow: Generative Flow with Invert

Ilya Kostrikov 546 Dec 05, 2022
Implementing yolov4 target detection and tracking based on nao robot

Implementing yolov4 target detection and tracking based on nao robot

6 Apr 19, 2022
Code for EmBERT, a transformer model for embodied, language-guided visual task completion.

Code for EmBERT, a transformer model for embodied, language-guided visual task completion.

41 Jan 03, 2023
DrQ-v2: Improved Data-Augmented Reinforcement Learning

DrQ-v2: Improved Data-Augmented RL Agent Method DrQ-v2 is a model-free off-policy algorithm for image-based continuous control. DrQ-v2 builds on DrQ,

Facebook Research 234 Jan 01, 2023
Deep Learning GPU Training System

DIGITS DIGITS (the Deep Learning GPU Training System) is a webapp for training deep learning models. The currently supported frameworks are: Caffe, To

NVIDIA Corporation 4.1k Jan 03, 2023
Pointer networks Tensorflow2

Pointer networks Tensorflow2 原文:https://arxiv.org/abs/1506.03134 仅供参考与学习,内含代码备注 环境 tensorflow==2.6.0 tqdm matplotlib numpy 《pointer networks》阅读笔记 应用场景

HUANG HAO 7 Oct 27, 2022
Optimized primitives for collective multi-GPU communication

NCCL Optimized primitives for inter-GPU communication. Introduction NCCL (pronounced "Nickel") is a stand-alone library of standard communication rout

NVIDIA Corporation 2k Jan 09, 2023
Toolbox to analyze temporal context invariance of deep neural networks

PyTCI A toolbox that estimates the integration window of a sensory response using the "Temporal Context Invariance" paradigm (TCI). The TCI method Int

4 Oct 23, 2022
Auditing Black-Box Prediction Models for Data Minimization Compliance

Data-Minimization-Auditor An auditing tool for model-instability based data minimization that is introduced in "Auditing Black-Box Prediction Models f

Bashir Rastegarpanah 2 Mar 24, 2022
Pytorch-Swin-Unet-V2 - a modified version of Swin Unet based on Swin Transfomer V2

Swin Unet V2 Swin Unet V2 is a modified version of Swin Unet arxiv based on Swin

Chenxu Peng 26 Dec 03, 2022
ObsPy: A Python Toolbox for seismology/seismological observatories.

ObsPy is an open-source project dedicated to provide a Python framework for processing seismological data. It provides parsers for common file formats

ObsPy 979 Jan 07, 2023
CONditionals for Ordinal Regression and classification in tensorflow

Condor Ordinal regression in Tensorflow Keras Tensorflow Keras implementation of CONDOR Ordinal Regression (aka ordinal classification) by Garrett Jen

9 Jul 31, 2022
MTA:SA Server Configer.

MTAConfiger MTA:SA Server Configer. Hi 👋 , I'm Alireza A Python Developer Boy 🔭 I’m currently working on my C# projects 🌱 I’m currently Learning CS

3 Jun 07, 2022
StarGAN v2-Tensorflow - Simple Tensorflow implementation of StarGAN v2

Official Tensorflow implementation Open ! - Clova AI StarGAN v2 — Un-official TensorFlow Implementation [Paper] [Pytorch] : Diverse Image Synthesis f

Junho Kim 110 Jul 02, 2022
The code repository for "PyCIL: A Python Toolbox for Class-Incremental Learning" in PyTorch.

PyCIL: A Python Toolbox for Class-Incremental Learning Introduction • Methods Reproduced • Reproduced Results • How To Use • License • Acknowledgement

Fu-Yun Wang 258 Dec 31, 2022
MonoRec: Semi-Supervised Dense Reconstruction in Dynamic Environments from a Single Moving Camera

MonoRec: Semi-Supervised Dense Reconstruction in Dynamic Environments from a Single Moving Camera

Felix Wimbauer 494 Jan 06, 2023