Fast, general, and tested differentiable structured prediction in PyTorch

Overview

Torch-Struct: Structured Prediction Library

A library of tested GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

A tutorial paper describes the methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])

# Compute marginals
show(dist.marginals[0])

# Compute argmax
show(dist.argmax.detach()[0])

# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])

# Padding/Masking built into library.
dist = DependencyCRF(vals.log(), lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))
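The marginals plotted above and `log_partition` are tied together by a standard identity: each marginal is the derivative of the log-partition function with respect to the corresponding log-potential, so a differentiable partition function is enough to recover marginals. The sketch below checks that identity numerically on a tiny linear chain in plain Python. It is not the library's code: it enumerates every labeling, and the `phi[n][a][b]` layout (label `a` at position `n` followed by label `b` at `n + 1`) is our own assumption, not the `LinearChainCRF` tensor layout.

```python
import itertools
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def score(phi, z):
    # phi[n][a][b]: log-potential of label a at position n followed by label b at n + 1
    return sum(phi[n][z[n]][z[n + 1]] for n in range(len(phi)))


def labelings(phi):
    # Every label sequence for a chain with len(phi) + 1 positions.
    num_labels = len(phi[0])
    return itertools.product(range(num_labels), repeat=len(phi) + 1)


def log_partition(phi):
    return logsumexp([score(phi, z) for z in labelings(phi)])


def edge_marginals(phi):
    # mu[n][a][b] = P(z_n = a, z_{n+1} = b), by brute-force enumeration
    log_z = log_partition(phi)
    num_labels = len(phi[0])
    mu = [[[0.0] * num_labels for _ in range(num_labels)] for _ in phi]
    for z in labelings(phi):
        p = math.exp(score(phi, z) - log_z)
        for n in range(len(phi)):
            mu[n][z[n]][z[n + 1]] += p
    return mu


def numeric_gradient(phi, eps=1e-5):
    # Central finite differences of log Z with respect to every entry of phi
    grad = [[[0.0] * len(row) for row in edge] for edge in phi]
    for n, edge in enumerate(phi):
        for a, row in enumerate(edge):
            for b, original in enumerate(row):
                row[b] = original + eps
                up = log_partition(phi)
                row[b] = original - eps
                down = log_partition(phi)
                row[b] = original
                grad[n][a][b] = (up - down) / (2 * eps)
    return grad


random.seed(0)
N, C = 4, 3  # 4 positions, 3 labels
phi = [[[random.gauss(0, 1) for _ in range(C)] for _ in range(C)] for _ in range(N - 1)]
mu = edge_marginals(phi)
grad = numeric_gradient(phi)
max_err = max(
    abs(mu[n][a][b] - grad[n][a][b])
    for n in range(N - 1) for a in range(C) for b in range(C)
)
print(f"max |marginal - dlogZ/dphi| = {max_err:.1e}")
```

Enumeration is exponential in the chain length, so this is only useful as a test oracle on tiny inputs.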

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max
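To pin down what several of these quantities mean, here is a brute-force reference for a tiny chain in plain Python: partition (log Z), k-max with scores, log_probs, and entropy. It is a hypothetical test oracle under our own `phi[n][a][b]` layout, not the library's API. It also illustrates why a top-k query can return fewer than k structures: a short input may simply have fewer than k valid structures.

```python
import itertools
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def enumerate_chain(phi):
    """Yield (score, labels) for every labeling of a chain with len(phi) + 1 positions.

    phi[n][a][b] is the log-potential of label a at position n followed by label b
    at position n + 1 (an assumed layout, chosen for readability).
    """
    num_labels = len(phi[0])
    for z in itertools.product(range(num_labels), repeat=len(phi) + 1):
        yield sum(phi[n][z[n]][z[n + 1]] for n in range(len(phi))), z


def reference_stats(phi, k):
    scored = sorted(enumerate_chain(phi), reverse=True)  # highest score first
    log_z = logsumexp([s for s, _ in scored])
    top_k = scored[:k]  # fewer than k items if fewer than k labelings exist
    log_probs = [s - log_z for s, _ in top_k]
    entropy = -sum(math.exp(s - log_z) * (s - log_z) for s, _ in scored)
    return log_z, top_k, log_probs, entropy


# Two positions, two labels: only 4 labelings exist, so asking for k = 5 returns 4.
phi = [[[1.0, 0.0], [0.0, 2.0]]]
log_z, top_k, log_probs, entropy = reference_stats(phi, k=5)
print(len(top_k), top_k[0])  # 4 (2.0, (1, 1))
```

Every log_prob is at most 0 and the entropy lies between 0 and log(4) here; a negative training loss (negative log-probability) would therefore indicate that the gold structure is scored outside the set the partition function sums over.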

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree structured parameterizations TreeLSTM / SpanLSTM

Low-level API:

Everything is implemented through semiring dynamic programming:

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
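The semiring idea can be sketched in a few lines: one forward recursion over a chain, parameterized by the semiring's plus and times, yields log Z under the log semiring and the best (Viterbi) score under the max semiring. This plain-Python sketch uses hypothetical names (`Semiring`, `LOG`, `MAX`, `chain_forward`) and our own `phi[n][a][b]` layout; it illustrates the technique, not the library's batched implementation.

```python
import itertools
import math
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Semiring:
    plus: Callable[[float, float], float]
    times: Callable[[float, float], float]
    zero: float  # identity of plus
    one: float   # identity of times


def log_add(a, b):
    # log(exp(a) + exp(b)) without overflow; -inf is the identity
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


LOG = Semiring(plus=log_add, times=lambda a, b: a + b, zero=-math.inf, one=0.0)
MAX = Semiring(plus=max, times=lambda a, b: a + b, zero=-math.inf, one=0.0)


def chain_forward(phi, sr):
    """Semiring-sum over all labelings of the semiring-product of edge potentials.

    phi[n][a][b]: log-potential of label a at position n followed by label b at n + 1.
    """
    num_labels = len(phi[0])
    alpha = [sr.one] * num_labels  # any label may start the chain
    for edge in phi:
        new_alpha = []
        for b in range(num_labels):
            acc = sr.zero
            for a in range(num_labels):
                acc = sr.plus(acc, sr.times(alpha[a], edge[a][b]))
            new_alpha.append(acc)
        alpha = new_alpha
    total = sr.zero
    for value in alpha:
        total = sr.plus(total, value)
    return total


phi = [[[1.0, 0.0], [0.0, 2.0]], [[0.5, -1.0], [0.0, 0.3]]]
scores = [
    sum(phi[n][z[n]][z[n + 1]] for n in range(len(phi)))
    for z in itertools.product(range(2), repeat=len(phi) + 1)
]
print(chain_forward(phi, LOG), math.log(sum(math.exp(s) for s in scores)))
print(chain_forward(phi, MAX), max(scores))
```

Only the semiring changes between the two calls; the recursion is shared, which is the design point of the semiring abstraction.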

Examples

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
Comments
  • add tests for CKY

    This PR fixes several bugs in k-best parsing with dist.topk() and includes a simple test for the function.

    I made incremental changes so that existing modules relying on the CKY will not be affected.

    opened by zhaoyanpeng 8
  • 1st order cky implementation

    Hi,

    I'd like to contribute this implementation of a first-order CKY-style CRF with anchored rule potentials: $\phi[i,j,k,A,B,C] := \phi(A_{i,j} \rightarrow B_{i,k}, C_{k+1,j})$.

    I also added code to the _Struct class that allows calculating marginals even if the input tensors don't require a gradient (i.e., after model.eval()).

    Please let me know if you'd like to see any changes.

    Thanks, Tom

    opened by teffland 6
  • Mini-batch setting with Semi Markov CRF

    I run into unstable learning when using a batch size > 1 with the semi-Markov CRF (the loss goes to a very large negative number), even when I explicitly provide "lengths". I think the bug comes from the masking. The model trains well with batch size 1.

    opened by urchade 5
  • Release on PyPI?

    Is there any interest in releasing pytorch-struct (and genbmm) on the official Python Package Index?

    I ran into this because I distribute my constituency parser on PyPI, and I just recently pushed a new version that depends on pytorch-struct: https://pypi.org/project/benepar/0.2.0a0/

    It turns out that packages on PyPI aren't allowed to depend on packages only hosted on github, so users of my parser can't just pip install benepar and have it work right away.

    opened by nikitakit 5
  • up sweep and down sweep

    I'm interested in the parallel scan algorithm for the linear-chain CRF.

    I read the related paper in the tutorial and found that the algorithm has two steps, an up-sweep and a down-sweep, for computing all prefix sums.

    I think in this case we use that algorithm to obtain Z(x) for all the different lengths in a batch. But I couldn't find the down-sweep code in the repo. Can you point me to it?

    opened by allanj 5
  • [Bug] Implementation of Eisner's algorithm does not restrict the root number to 1

    Hey, I found that your implementation of Eisner's algorithm admits an arbitrary number of roots, which is a severe bug since a dependency tree usually has only one root token.

    In your DepTree.dp() method, you convert the input so that the root token is the first token in the sentence. Suppose the root $x_0$ attaches to word $x_i$: $I_{0,0} + C_{1,i} = I_{0,i}$ and $I_{0,i} + C_{i,j} = C_{0,j}$ for some $j < L$, where $L$ is the length of the sentence. The complete span $C_{0,j}$ can then still attach a new word $x_k$ for $j < k \le L$, which makes multiple root attachments possible.

    Fortunately, I made some changes to your code to restrict the number of roots to 1.

    def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
        # Relies on the module-level index constants (A, B, I, C, L, R) and the
        # Chart helper that torch_struct/deptree.py already uses.
        if arc_scores_in.dim() not in (3, 4):
            raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")

        labeled = arc_scores_in.dim() == 4
        semiring = self.semiring
        # arc_scores_in = _convert(arc_scores_in)
        arc_scores_in, batch, N, lengths = self._check_potentials(
            arc_scores_in, lengths
        )
        arc_scores_in.requires_grad_(True)
        arc_scores = semiring.sum(arc_scores_in) if labeled else arc_scores_in
        alpha = [
            [
                [
                    Chart((batch, N, N), arc_scores, semiring, cache=cache)
                    for _ in range(2)
                ]
                for _ in range(2)
            ]
            for _ in range(2)
        ]

        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)

        for k in range(1, N):
            f = torch.arange(N - k), torch.arange(k, N)
            ACL = alpha[A][C][L][: N - k, :k]
            ACR = alpha[A][C][R][: N - k, :k]
            BCL = alpha[B][C][L][k:, N - k :]
            BCR = alpha[B][C][R][k:, N - k :]
            x = semiring.dot(ACR, BCL)
            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
            alpha[A][I][L][: N - k, k] = arcs_l
            alpha[B][I][L][k:N, N - k - 1] = arcs_l
            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
            alpha[A][I][R][: N - k, k] = arcs_r
            alpha[B][I][R][k:N, N - k - 1] = arcs_r
            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
            new = semiring.dot(ACL, BIL)
            alpha[A][C][L][: N - k, k] = new
            alpha[B][C][L][k:N, N - k - 1] = new
            new = semiring.dot(AIR, BCR)
            alpha[A][C][R][: N - k, k] = new
            alpha[B][C][R][k:N, N - k - 1] = new

        # Root handled after the loop: root attachment scores are read from the
        # diagonal of arc_scores and combined with one word's complete spans.
        root_incomplete_span = semiring.times(
            alpha[A][C][L][0, :], arc_scores[:, :, torch.arange(N), torch.arange(N)]
        )
        root = [Chart((batch,), arc_scores, semiring, cache=cache) for _ in range(N)]
        for k in range(N):
            AIR = root_incomplete_span[:, :, : k + 1]
            BCR = alpha[B][C][R][k, N - (k + 1) :]
            root[k] = semiring.dot(AIR, BCR)
        v = torch.stack([root[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
        return v, [arc_scores_in], alpha

    Basically, I no longer treat the first token as the root. I handle the root token right after the for-loop, so you may need to adjust the length variable (length = length - 1, since the root is no longer part of the sentence). I tested the modified code and found it bug-free.

    opened by sustcsonglin 4
  • Inference for the HMM model

    Hello! I was playing with the HMM distribution and got some results that I don't really understand. More precisely, I set the following parameters:

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torch_struct
    t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()
    e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()
    i = torch.tensor(np.array([0.99, 0.01])).log()
    x = torch.randint(0, 2, size=(1, 8))
    

    and I was expecting the model to stay in the hidden state 0 regardless of the observed data x – it starts in state 0 and the transition matrix makes it very likely to maintain it. But when plotting the argmax, it appears that the model jumps from one state to the other:

    def show_chain(chain):
        plt.imshow(chain.detach().sum(-1).transpose(0, 1))
    
    dist = torch_struct.HMM(t, e, i, x)
    show_chain(dist.argmax[0])
    

    I must be missing something obvious; but shouldn't dist.argmax correspond to argmax_z p(z | x, Θ)? Thank you!

    opened by danoneata 4
  • DependencyCRF partition function broken

    Getting the following in-place operation error when using the DependencyCRF:

    import torch
    from torch_struct import DependencyCRF
    B, N = 3, 50
    phi = torch.randn(B, N, N)
    DependencyCRF(phi).partition
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/deptree.py in _check_potentials(self, arc_scores, lengths)
        121         arc_scores = semiring.convert(arc_scores)
        122         for b in range(batch):
    --> 123             semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
        124             semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])
        125 
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/semirings/semirings.py in zero_(xs)
        124     @staticmethod
        125     def zero_(xs):
    --> 126         return xs.fill_(-1e5)
        127 
        128     @staticmethod
    
    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    
    opened by teffland 3
  • [Question] How to compute a marginal probability over a (contiguous) set of nodes?

    Hi.

    Thank you for the great library. I have one question that I hope you could help with.

    How can I compute a marginal probability over a (contiguous) set of nodes? Right now, I am using your LinearChain-CRF for NER. In addition to the best sequence itself, I also need the model's confidence in its predicted labeling over a segment of the input. For example, what is the probability that a span of tokens constitutes a person name?

    I read your example and see how you get the marginal probability for each individual node, but I am not quite sure how to compute the marginal probability over a subset of nodes. Any hint would be great.

    Thank you.

    opened by kimdev95 3
  • Get the score of dist.topk()

    The topk() function returns the top k predictions from the distribution. How can I easily get the score of each prediction?

    Also, when sentences are short and k is large, how can I tell how many of the predictions are valid? For the example in DependencyCRF, when the sentence length is 2 and k is 5, I think only the top 3 predictions are valid.

    opened by wangxinyu0922 3
  • Labeled projective dependency CRF

    This is work in progress and isn't ready to merge yet.

    This seems to work for partition, but argmax and marginals don't return what I expect. Both return tensors of shape (B, N, N), whereas I'd expect (B, N, N, L) tensors. Any advice?

    opened by kmkurn 3
  • [Question] How to apply pytorch-struct for 2 dimensional data?

    I could find examples of pytorch-struct usage for 1D sequence data like text or video frames, but I'm trying to parse table structure in PDF documents.

    Could you provide some hints where to start?

    opened by YuriyPryyma 4
  • end_class support for Autoregressive

    end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49

    opened by urchade 1
  • Update examples to use newer torchtext APIs

    opened by erip 2
  • Unstable learning with SemiMarkov CRF

    Hi,

    First, thank you for fixing #110 (@da03); the SemiCRF works better now, and I was able to get good results on span extraction tasks. However, I still see learning instability: the loss (negative log-probability) becomes negative after several steps, and the accuracy starts to drop. The same problem occurs with batch_size = 1. Below are the learning curves (f1_score and log loss).

    (Maybe the bug comes from the masking of spans (length, length + span_with), i.e. spans that end past the sentence length, but I am not sure.)

    Edit: I created a test, and the masking seems to be correct. Maybe the problem is in the log_prob computation or the to_parts function?

    [Figure: learning curves, train_loss and score]

    opened by urchade 0
  • fix bug: missing assignment of spans from sentCFG in documentation

    I noticed a small bug in the documentation and example of SentCFG. dist.argmax returns (terms, rules, init, spans), but the example in the documentation only assigns (term, rules, init), which causes a dimension mismatch, so the example breaks when run. This fix resolves the issue.

    opened by jdegange 0
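For the first-order CKY PR above, here is a minimal plain-Python sketch of the inside algorithm with anchored rule potentials φ[i,j,k,A,B,C], read as A over words i..j splitting into B over i..k and C over k+1..j (0-indexed, inclusive). Two details are our own assumptions, not taken from the PR: width-one spans get a hypothetical terminal potential `term[i][A]`, and the root span may carry any label. A brute-force enumeration of every labeled binary tree serves as the check.

```python
import itertools
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def inside(term, rule, n, nt):
    # beta[(i, j, A)]: log-sum over all labeled binary trees on words i..j rooted in A
    beta = {(i, i, a): term[i][a] for i in range(n) for a in range(nt)}
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            j = i + width - 1
            for a in range(nt):
                beta[(i, j, a)] = logsumexp([
                    rule[(i, j, k, a, b, c)] + beta[(i, k, b)] + beta[(k + 1, j, c)]
                    for k in range(i, j) for b in range(nt) for c in range(nt)
                ])
    return logsumexp([beta[(0, n - 1, a)] for a in range(nt)])


def all_tree_scores(term, rule, i, j, a, nt):
    # Scores of every labeled binary tree over words i..j rooted in label a
    if i == j:
        return [term[i][a]]
    out = []
    for k in range(i, j):
        for b, c in itertools.product(range(nt), repeat=2):
            for left in all_tree_scores(term, rule, i, k, b, nt):
                for right in all_tree_scores(term, rule, k + 1, j, c, nt):
                    out.append(rule[(i, j, k, a, b, c)] + left + right)
    return out


random.seed(1)
n, nt = 3, 2  # 3 words, 2 nonterminal labels
term = [[random.gauss(0, 1) for _ in range(nt)] for _ in range(n)]
rule = {
    (i, j, k, a, b, c): random.gauss(0, 1)
    for i in range(n) for j in range(i + 1, n) for k in range(i, j)
    for a in range(nt) for b in range(nt) for c in range(nt)
}
scores = [s for a in range(nt) for s in all_tree_scores(term, rule, 0, n - 1, a, nt)]
print(len(scores), inside(term, rule, n, nt), logsumexp(scores))
```

The chart has O(n² · |N|) cells and each sums over O(n · |N|²) splits, so the rule tensor, not the chart, dominates memory for larger grammars.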
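On the up-sweep/down-sweep question above: in a textbook (Blelloch) scan, the up-sweep alone leaves the product of all elements at the root of the reduction tree, which is all a single sequence's partition function needs; the down-sweep only distributes the exclusive prefixes. The plain-Python sketch below runs both sweeps with log-semiring matrix multiplication as the associative operator and checks the prefixes against a sequential loop. All names are hypothetical, and it makes no claim about how the library organizes its own scan.

```python
import math
import random

NEG_INF = -math.inf


def logsumexp(xs):
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(A, B):
    """(A x B)[a][c] = logsumexp_b A[a][b] + B[b][c]: matrix product in the log semiring."""
    size = len(A)
    return [[logsumexp([A[a][b] + B[b][c] for b in range(size)]) for c in range(size)]
            for a in range(size)]


def log_identity(size):
    return [[0.0 if a == c else NEG_INF for c in range(size)] for a in range(size)]


def blelloch_exclusive_scan(xs, op, identity):
    """out[i] = xs[0] op ... op xs[i - 1] (out[0] = identity); op must be associative."""
    n = 1
    while n < len(xs):
        n *= 2
    a = list(xs) + [identity] * (n - len(xs))  # pad to a power of two
    # Up-sweep: the right node of each pair accumulates the total of its subtree.
    step = 1
    while step < n:
        for i in range(0, n, 2 * step):
            a[i + 2 * step - 1] = op(a[i + step - 1], a[i + 2 * step - 1])
        step *= 2
    total = a[n - 1]  # product of all elements; the up-sweep alone gives this
    # Down-sweep: each node passes its exclusive prefix down to its children.
    a[n - 1] = identity
    step = n // 2
    while step >= 1:
        for i in range(0, n, 2 * step):
            left, right = i + step - 1, i + 2 * step - 1
            left_total = a[left]
            a[left] = a[right]                   # left child: parent's prefix
            a[right] = op(a[right], left_total)  # right child: parent's prefix, then left subtree
        step //= 2
    return a[: len(xs)], total


random.seed(0)
size, length = 3, 5  # 3 labels, 5 edge matrices (padded to 8 inside the scan)
edges = [[[random.gauss(0, 1) for _ in range(size)] for _ in range(size)] for _ in range(length)]
prefixes, total = blelloch_exclusive_scan(edges, log_matmul, log_identity(size))

# Sequential reference: running log-semiring product.
running = log_identity(size)
expected = []
for m in edges:
    expected.append(running)
    running = log_matmul(running, m)
log_z = logsumexp([v for row in total for v in row])  # any start label, any end label
print(log_z, logsumexp([v for row in running for v in row]))
```

Because matrix multiplication is associative but not commutative, each combine keeps the earlier operand on the left.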
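On the HMM question above: with those parameters the exact MAP state sequence is all zeros for any observations. The emission matrix is uniform, so observations carry no information, and every path other than all zeros pays at least one factor of 0.01 (in the initial distribution or a transition) where the all-zeros path pays 0.99. The transition matrix is symmetric, so reading it as from-to or to-from cannot change this. The plain-Python Viterbi below, checked by brute force, makes the expected answer concrete; it is a reference sketch, not torch_struct.HMM, its row = "from state" convention is our assumption, and a fixed observation sequence stands in for torch.randint.

```python
import itertools
import math


def viterbi(log_init, log_trans, log_emit, obs):
    """Exact MAP state sequence of an HMM, in log space.

    Assumed conventions: log_trans[p][q] = log P(z_{n+1} = q | z_n = p),
    log_emit[q][x] = log P(x | z = q).
    """
    states = range(len(log_init))
    delta = [log_init[q] + log_emit[q][obs[0]] for q in states]
    backpointers = []
    for x in obs[1:]:
        best_prev = [max(states, key=lambda p: delta[p] + log_trans[p][q]) for q in states]
        delta = [delta[p] + log_trans[p][q] + log_emit[q][x] for q, p in zip(states, best_prev)]
        backpointers.append(best_prev)
    path = [max(states, key=lambda q: delta[q])]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    return path[::-1]


def brute_force_map(log_init, log_trans, log_emit, obs):
    # Score every state sequence and return the best one (exponential; for checking only)
    def score(z):
        s = log_init[z[0]] + log_emit[z[0]][obs[0]]
        for n in range(1, len(obs)):
            s += log_trans[z[n - 1]][z[n]] + log_emit[z[n]][obs[n]]
        return s
    return list(max(itertools.product(range(len(log_init)), repeat=len(obs)), key=score))


log = math.log
t = [[log(0.99), log(0.01)], [log(0.01), log(0.99)]]
e = [[log(0.50), log(0.50)], [log(0.50), log(0.50)]]
i = [log(0.99), log(0.01)]
x = [0, 1, 1, 0, 1, 0, 0, 1]  # any fixed observations; the result does not depend on them here
print(viterbi(i, t, e, x))  # [0, 0, 0, 0, 0, 0, 0, 0]
```

If a library's argmax disagrees with this reference on the same numbers, the likely suspects are parameter conventions (which axis of each matrix is which), not the decoding algorithm.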
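For the question about span confidence in a linear-chain CRF, one standard approach is forward-backward: the probability that positions start..end take a given label pattern is exp(alpha[start][y_start] + the inner edge scores + beta[end][y_end] - log Z). The plain-Python sketch below assumes a hypothetical BIO tag set (O = 0, B-PER = 1, I-PER = 2) and our own `phi[n][a][b]` layout. "The span is exactly one PER entity" additionally requires that the next tag, if any, is not I-PER, so that continuation probability is subtracted. It is checked by brute force and is not the library's API.

```python
import itertools
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def forward_backward(phi, num_labels):
    """phi[n][a][b]: log-potential of label a at position n followed by label b at n + 1."""
    length = len(phi) + 1
    alpha = [[0.0] * num_labels]
    for n in range(length - 1):
        alpha.append([logsumexp([alpha[n][a] + phi[n][a][b] for a in range(num_labels)])
                      for b in range(num_labels)])
    beta = [[0.0] * num_labels for _ in range(length)]
    for n in range(length - 2, -1, -1):
        beta[n] = [logsumexp([phi[n][a][b] + beta[n + 1][b] for b in range(num_labels)])
                   for a in range(num_labels)]
    return alpha, beta, logsumexp(alpha[-1])


def segment_prob(phi, num_labels, start, labels):
    """P(z[start], ..., z[start + len(labels) - 1] == labels)."""
    alpha, beta, log_z = forward_backward(phi, num_labels)
    end = start + len(labels) - 1
    inner = sum(phi[n][labels[n - start]][labels[n - start + 1]] for n in range(start, end))
    return math.exp(alpha[start][labels[0]] + inner + beta[end][labels[-1]] - log_z)


O, B_PER, I_PER = 0, 1, 2  # hypothetical BIO tag ids


def entity_prob(phi, start, end):
    """P(tokens start..end form exactly one PER entity)."""
    labels = [B_PER] + [I_PER] * (end - start)
    p = segment_prob(phi, 3, start, labels)
    if end + 1 < len(phi) + 1:  # the entity must not continue past `end`
        p -= segment_prob(phi, 3, start, labels + [I_PER])
    return p


def brute_force_entity_prob(phi, start, end):
    """Same probability by enumerating every tag sequence (exponential; for checking only)."""
    n_pos = len(phi) + 1
    seqs = list(itertools.product(range(3), repeat=n_pos))
    scores = [sum(phi[n][z[n]][z[n + 1]] for n in range(n_pos - 1)) for z in seqs]
    log_z = logsumexp(scores)
    total = 0.0
    for z, s in zip(seqs, scores):
        inside_span = z[start] == B_PER and all(z[k] == I_PER for k in range(start + 1, end + 1))
        closed = end + 1 == n_pos or z[end + 1] != I_PER
        if inside_span and closed:
            total += math.exp(s - log_z)
    return total


random.seed(0)
length = 5
phi = [[[random.gauss(0, 1) for _ in range(3)] for _ in range(3)] for _ in range(length - 1)]
print(entity_prob(phi, 1, 2), brute_force_entity_prob(phi, 1, 2))
```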
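To test a fix like the one proposed above for Eisner's algorithm, a brute-force oracle on tiny inputs helps. The plain-Python sketch below enumerates every head assignment, keeps the projective trees, and optionally keeps only those with exactly one word attached to the root; a corrected single-root DP should reproduce the single-root log Z. The `arc[h][d]` layout (index 0 is the root symbol) is our own assumption, not the DependencyCRF input format. For three words it finds 12 projective trees, 7 of which have a single root.

```python
import itertools
import math
import random


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def is_tree(heads):
    """heads[d - 1] is the head of word d; words are 1..n and 0 is the root symbol."""
    for d in range(1, len(heads) + 1):
        seen, h = set(), d
        while h != 0:  # follow heads upward; a repeat means a cycle
            if h in seen:
                return False
            seen.add(h)
            h = heads[h - 1]
    return True


def dominates(heads, h, k):
    """True if h is an ancestor of (or equal to) word k."""
    while k != 0:
        if k == h:
            return True
        k = heads[k - 1]
    return h == 0


def is_projective(heads):
    # Every word strictly between a head and its dependent must descend from that head.
    for d, h in enumerate(heads, start=1):
        if any(not dominates(heads, h, k) for k in range(min(h, d) + 1, max(h, d))):
            return False
    return True


def brute_force_log_partition(arc, single_root):
    """arc[h][d]: log-potential of head h -> dependent d (h = 0 is the root symbol)."""
    n = len(arc) - 1
    scores = []
    for heads in itertools.product(range(n + 1), repeat=n):
        if not (is_tree(heads) and is_projective(heads)):
            continue
        if single_root and heads.count(0) != 1:
            continue
        scores.append(sum(arc[h][d] for d, h in enumerate(heads, start=1)))
    return logsumexp(scores), len(scores)


random.seed(0)
n = 3
arc = [[random.gauss(0, 1) for _ in range(n + 1)] for _ in range(n + 1)]
log_z_any, count_any = brute_force_log_partition(arc, single_root=False)
log_z_one, count_one = brute_force_log_partition(arc, single_root=True)
print(count_any, count_one)  # 12 7
```

The same enumeration also answers the topk question for short inputs: with two words there are only 3 projective trees, so at most 3 of the top-k structures can be valid.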