PyTorch trainer and model for Sequence Classification


After cloning the repository, convert your training data to a .csv file with two columns: Text and Label.

In the example below, we assume that the training data has 3 labels and is stored in a file named train_data.csv.
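Before training, it can help to check that the file really has this layout. The sketch below uses only the standard library csv module; check_training_csv is a hypothetical helper (not part of this repository), and it assumes labels are stored as integer class ids from 0 to num_labels - 1, which the README does not state explicitly.

```python
import csv
from collections import Counter


def check_training_csv(path, num_labels):
    """Hypothetical helper: verify the Text/Label layout and count examples per label.

    Assumes labels are integer class ids in the range [0, num_labels).
    """
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        # The README requires exactly these two columns
        if reader.fieldnames is None or set(reader.fieldnames) != {"Text", "Label"}:
            raise ValueError(f"expected columns Text and Label, got {reader.fieldnames}")
        counts = Counter()
        for row_number, row in enumerate(reader, start=2):  # row 1 is the header
            label = int(row["Label"])
            if not 0 <= label < num_labels:
                raise ValueError(f"row {row_number}: label {label} is out of range")
            counts[label] += 1
    return counts
```

For example, check_training_csv('train_data.csv', num_labels=3) returns a Counter with the number of examples per label, which also makes class imbalance easy to spot.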

Example Usage

Import dependencies

import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig

from EarlyStopping import *
from modelling import *
from utils import *

Specify arguments

args.pretrained_path is the pretrained language model to load: either a model name on the Hugging Face Hub (such as bert-base-uncased) or a path to a local directory.

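class args:
    fold = 0                               # fold used for validation in this run
    pretrained_path = 'bert-base-uncased'  # Hub model name or local directory
    max_length = 400                       # maximum sequence length in tokens
    train_batch_size = 16
    val_batch_size = 64
    epochs = 5
    learning_rate = 1e-5
    accumulation_steps = 2                 # gradient accumulation steps
    num_splits = 5                         # number of cross-validation folds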
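With accumulation_steps = 2, gradient accumulation normally means that gradients from two consecutive batches of 16 are summed before each optimizer update, giving an effective batch size of 32. The hypothetical helper training_step_counts below counts batches and optimizer updates under that usual scheme; whether this repository's fit method handles a final, incomplete accumulation group the same way is an assumption.

```python
import math


def training_step_counts(num_examples, batch_size, accumulation_steps, epochs):
    """Return (batches per epoch, optimizer updates per epoch, total updates).

    Assumes the last partial batch is kept (as with a DataLoader using
    drop_last=False) and that the optimizer steps once every
    `accumulation_steps` batches, plus once for any leftover batches at the
    end of an epoch.
    """
    batches_per_epoch = math.ceil(num_examples / batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / accumulation_steps)
    return batches_per_epoch, updates_per_epoch, updates_per_epoch * epochs
```

For example, with 1,000 training rows, training_step_counts(1000, 16, 2, 5) returns (63, 32, 160).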

Create train and validation data

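In this example, the model is trained with k-fold cross-validation: create_k_folds splits the training data into args.num_splits folds and records each row's fold in a kfold column. Fold args.fold is used for validation, and the remaining folds are used for training.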

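df = pd.read_csv('./train_data.csv')
df = create_k_folds(df, args.num_splits)  # adds a 'kfold' column with values 0 .. num_splits - 1

# Train on every fold except args.fold; validate on args.fold
df_train = df[df['kfold'] != args.fold].reset_index(drop = True)
df_valid = df[df['kfold'] == args.fold].reset_index(drop = True)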
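The key property of the split is that every row lands in exactly one fold, and the training set for a given fold is everything except that fold. The sketch below shows the idea with the standard library; assign_k_folds and split_fold are hypothetical and may differ from the repository's create_k_folds (for example, the real function may stratify by label). This sketch is plain, unstratified k-fold.

```python
import random


def assign_k_folds(num_rows, num_splits, seed=42):
    """Hypothetical stand-in for create_k_folds: return a fold id for every row.

    Rows are shuffled with a fixed seed, then dealt round-robin into
    `num_splits` folds, so fold sizes differ by at most one.
    """
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    folds = [0] * num_rows
    for position, row_index in enumerate(indices):
        folds[row_index] = position % num_splits
    return folds


def split_fold(rows, folds, fold):
    """Validation = rows assigned to `fold`; training = all other rows."""
    train = [row for row, f in zip(rows, folds) if f != fold]
    valid = [row for row, f in zip(rows, folds) if f == fold]
    return train, valid
```

Using a fixed seed keeps the fold assignment identical across the separate runs for args.fold = 0, 1, ..., which is what makes the per-fold results comparable.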

Load the language model and its tokenizer

config = AutoConfig.from_pretrained(args.pretrained_path)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_path)
model_transformer = AutoModel.from_pretrained(args.pretrained_path)

Prepare train and validation dataloaders

features = []
for i in range(len(df_train)):
    features.append(prepare_features(tokenizer, df_train.iloc[i, :].to_dict(), args.max_length))
    
train_dataset = CreateDataset(features)
train_dataloader = create_dataloader(train_dataset, args.train_batch_size, 'train')

features = []
for i in range(len(df_valid)):
    features.append(prepare_features(tokenizer, df_valid.iloc[i, :].to_dict(), args.max_length))
    
val_dataset = CreateDataset(features)
val_dataloader = create_dataloader(val_dataset, args.val_batch_size, 'val')
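prepare_features is not shown in this README, but with max_length = 400 each example presumably becomes a fixed-length list of token ids plus an attention mask. The hypothetical helper below illustrates that padding and truncation step on plain lists. It is simplified: it ignores special tokens such as [CLS] and [SEP], which a Hugging Face tokenizer handles itself when called as tokenizer(text, padding='max_length', truncation=True, max_length=args.max_length).

```python
def pad_and_truncate(token_ids, max_length, pad_id=0):
    """Hypothetical helper: force a token-id sequence to exactly `max_length` entries.

    Returns (input_ids, attention_mask). The mask is 1 for real tokens and 0
    for padding, which is how BERT-style models ignore padded positions.
    """
    ids = list(token_ids[:max_length])  # truncate long sequences
    mask = [1] * len(ids)
    padding = max_length - len(ids)
    ids.extend([pad_id] * padding)      # pad short sequences on the right
    mask.extend([0] * padding)
    return ids, mask
```

For example, pad_and_truncate([5, 6, 7], 5) returns ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0]).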

Use EarlyStopping and customize the score function

NOTE: A custom score function must take two parameters, the logits and the ground-truth labels, and return a single score.

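def accuracy(logits, labels):
    # Move the tensors to the CPU as NumPy arrays
    logits = logits.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    # The predicted class is the index of the largest logit; softmax would not change
    # the argmax, and dividing by the sum of the logits is wrong (it flips the
    # prediction whenever that sum is negative)
    pred_classes = np.argmax(logits, axis = -1)
    pred_classes = pred_classes.reshape(labels.shape)
    
    return np.sum(pred_classes == labels) / labels.shape[0]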

es = EarlyStopping(mode = 'max', patience = 3, monitor = 'val_acc', out_path = 'model.bin')
es.monitor_score_function = accuracy
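Any function with the same (logits, labels) signature should be usable as the monitored score. As one sketch, assuming EarlyStopping only relies on that signature and a "higher is better" score (mode = 'max'), here is a macro-averaged F1 written without NumPy; macro_f1 and _to_list are hypothetical names. It accepts torch tensors as well as plain lists.

```python
def _to_list(x):
    """Convert a torch tensor (detached and moved to the CPU) or any sequence to a list."""
    if hasattr(x, "detach"):
        x = x.detach().cpu()
    return x.tolist() if hasattr(x, "tolist") else list(x)


def macro_f1(logits, labels):
    """Macro-averaged F1 with the same (logits, labels) signature as accuracy."""
    rows = _to_list(logits)
    # Labels may have shape (N,) or (N, 1); flatten them to a list of ints
    gold = [int(y[0]) if isinstance(y, (list, tuple)) else int(y) for y in _to_list(labels)]
    # Predicted class = index of the largest logit in each row (first one on ties)
    preds = [max(range(len(row)), key=row.__getitem__) for row in rows]
    # Average over classes that appear in either the labels or the predictions
    classes = sorted(set(gold) | set(preds))
    f1_scores = []
    for c in classes:
        tp = sum(1 for p, y in zip(preds, gold) if p == c and y == c)
        fp = sum(1 for p, y in zip(preds, gold) if p == c and y != c)
        fn = sum(1 for p, y in zip(preds, gold) if p != c and y == c)
        # F1 = 2TP / (2TP + FP + FN); c occurs in preds or gold, so this is never 0 / 0
        f1_scores.append(2 * tp / (2 * tp + fp + fn))
    return sum(f1_scores) / len(f1_scores)
```

To use it, assign it the same way as accuracy: es.monitor_score_function = macro_f1. Macro-F1 weights every class equally, which is often more informative than accuracy when the 3 labels are imbalanced.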

Create and train the model

Calling the fit method starts training.

model = Model(config, model_transformer, num_labels = 3)
model.to('cuda')
num_train_steps = int(len(train_dataset) / args.train_batch_size * args.epochs)
model.fit(args.epochs, args.learning_rate, num_train_steps, args.accumulation_steps, 
          train_dataloader, val_dataloader, es)
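During fit, the es object is presumably checked after each validation pass: with mode = 'max' and patience = 3, training stops once three consecutive evaluations fail to beat the best score, and the best model is kept at out_path. The class below is a hypothetical sketch of that standard patience logic, not the repository's EarlyStopping; the delta parameter (minimum improvement) is an added assumption.

```python
class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping for scores to maximize or minimize."""

    def __init__(self, mode="max", patience=3, delta=0.0):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.mode = mode
        self.patience = patience
        self.delta = delta          # minimum change that counts as an improvement
        self.best = None
        self.counter = 0            # evaluations since the last improvement
        self.should_stop = False

    def step(self, score):
        """Record one validation score; return True if it is a new best (save a checkpoint)."""
        improved = (
            self.best is None
            or (self.mode == "max" and score > self.best + self.delta)
            or (self.mode == "min" and score < self.best - self.delta)
        )
        if improved:
            self.best = score
            self.counter = 0
            return True
        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return False
```

For example, feeding the scores 0.5, 0.6, 0.58, 0.59, 0.6 sets should_stop to True on the fifth call, because the last three scores never exceed the best score of 0.6.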

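NOTE: To complete the cross-validation training process, run the code above again with args.fold set to 1, 2, ..., args.num_splits - 1. Because out_path is fixed to 'model.bin', each run will likely overwrite the previous fold's checkpoint unless you give each fold its own path (for example f'model_fold{args.fold}.bin').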
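Rather than editing args.fold by hand, the steps above can be wrapped in a function and looped over. The sketch below assumes a hypothetical train_one_fold(fold) that you would write yourself (set args.fold, rebuild the dataloaders and model, give EarlyStopping a per-fold out_path, call fit) and that returns the fold's best validation score; this README does not show the repository providing such a function.

```python
from statistics import mean, stdev


def run_cross_validation(train_one_fold, num_splits):
    """Call train_one_fold(fold) for every fold and summarise the returned scores.

    train_one_fold is a hypothetical, user-written function that trains on all
    folds except `fold`, validates on `fold`, and returns the best validation score.
    """
    scores = [train_one_fold(fold) for fold in range(num_splits)]
    spread = stdev(scores) if len(scores) > 1 else 0.0  # sample standard deviation
    return scores, mean(scores), spread
```

Reporting the mean together with the spread across folds gives a better sense of how stable the model is than any single fold's score.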

Owner: NhanTieu