Collection of NLP model explanations and accompanying analysis tools

Overview

Thermostat is a large collection of NLP model explanations and accompanying analysis tools.

  • Combines explainability methods from the captum library with Hugging Face's datasets and transformers.
  • Mitigates repetitive execution of common experiments in Explainable NLP and thus reduces the environmental impact and financial roadblocks.
  • Increases comparability and replicability of research.
  • Reduces the implementational burden.

This work is described in our paper accepted to EMNLP 2021 System Demonstrations:
Nils Feldhus, Robert Schwarzenberg, and Sebastian Möller.
Thermostat: A Large Collection of NLP Model Explanations and Analysis Tools. 2021.

arXiv pre-print available here: https://arxiv.org/abs/2108.13961

Installation

With pip

pip install thermostat-datasets

Usage

Downloading a dataset requires just two lines of code:

import thermostat
data = thermostat.load("imdb-bert-lig")

Thermostat datasets are addressed and loaded with an identifier string that combines three coordinates: dataset, model, and explainer. In this example, the dataset is IMDb (sentiment analysis of movie reviews), the model is a BERT model fine-tuned on the IMDb data, and the explanations are generated by a (Layer) Integrated Gradients explainer.
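
The three coordinates can be recovered from an identifier by splitting on hyphens. The helper below is a minimal sketch and not part of Thermostat's API (split_identifier and Coordinates are hypothetical names). It assumes that model and explainer abbreviations never contain a hyphen, which holds for the configurations listed in this README.

```python
from typing import NamedTuple


class Coordinates(NamedTuple):
    dataset: str
    model: str
    explainer: str


def split_identifier(identifier: str) -> Coordinates:
    """Split '<DATASET>-<MODEL>-<EXPLAINER>' into its three coordinates.

    Hypothetical helper, not part of the thermostat package.
    """
    # Split from the right so that the dataset part is kept whole.
    parts = identifier.rsplit("-", 2)
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"Expected <DATASET>-<MODEL>-<EXPLAINER>, got {identifier!r}"
        )
    return Coordinates(*parts)


print(split_identifier("imdb-bert-lig"))
```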

data then contains the following columns/features:

  • attributions (the attributions for each token for each data point; type: List of floats)
  • idx (the index of the instance in the dataset)
  • input_ids (the token IDs of the original dataset; type: List of ints)
  • label (the label of the original dataset; type: int)
  • predictions (the class logits of the classifier/downstream model; type: List of floats)
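
Because predictions holds raw class logits, the predicted class is the index of the largest logit. The sketch below uses a hand-written dictionary with made-up values in place of a real Thermostat instance; only the five field names are taken from the list above, and predicted_class is a hypothetical helper.

```python
# Stand-in for one row of a Thermostat dataset; all values are made up
# for illustration and do not come from a real configuration.
row = {
    "attributions": [0.0, 1.2, -0.3, 0.0],
    "idx": 0,
    "input_ids": [101, 6429, 3185, 102],
    "label": 1,
    "predictions": [-2.5, 3.1],
}


def predicted_class(logits):
    """Return the index of the largest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])


pred = predicted_class(row["predictions"])
print(f"label={row['label']} predicted={pred} correct={pred == row['label']}")
```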

This is the raw content stored in each of the instances of data:

(Image: raw contents of a single instance)

If we print data, we get more info such as the actual names of the dataset, the explainer and the model:

print(data)
> IMDb dataset, BERT model, Layer Integrated Gradients explanations
> Explainer: LayerIntegratedGradients
> Model: textattack/bert-base-uncased-imdb
> Dataset: imdb

Indexing an instance

We can simply index the loaded dataset like a list:

import thermostat
instance = thermostat.load("imdb-bert-lig")[429]

Visualizing attributions as a heatmap

We can apply .render() to every instance to display a heatmap visualization generated by the displaCy library.

instance.render()  # instance refers to the variable assigned in the last codebox

(Image: heatmap rendered as HTML)

Get simple tuple-based heatmap

The explanation attribute stores a tuple-based heatmap with the token, the attribution, and the token index as elements.

print(instance.explanation)  # instance refers to the variable assigned in the second to last codebox

> [('[CLS]', 0.0, 0),
 ('amazing', 2.3141794204711914, 1),
 ('movie', 0.06655970215797424, 2),
 ('.', -0.47832658886909485, 3),
 ('some', 0.15708176791667938, 4),
 ('of', -0.02931656688451767, 5),
 ('the', -0.08834744244813919, 6),
 ('script', -0.2660972774028778, 7),
 ('writing', -0.4021594822406769, 8),
 ('could', -0.19280624389648438, 9),
 ('have', -0.015477157197892666, 10),
 ('been', -0.21898044645786285, 11),
 ('better', -0.4095713794231415, 12),
 ...]  # abbreviated
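
Since the tuple-based heatmap is a plain list, it can be post-processed with ordinary Python. As a small sketch, the snippet below copies the entries printed above into a literal list and ranks them by absolute attribution; top_k is a hypothetical helper, not a Thermostat function.

```python
# The first entries of the tuple-based heatmap printed above:
# (token, attribution, token index).
explanation = [
    ("[CLS]", 0.0, 0),
    ("amazing", 2.3141794204711914, 1),
    ("movie", 0.06655970215797424, 2),
    (".", -0.47832658886909485, 3),
    ("some", 0.15708176791667938, 4),
    ("of", -0.02931656688451767, 5),
    ("the", -0.08834744244813919, 6),
    ("script", -0.2660972774028778, 7),
    ("writing", -0.4021594822406769, 8),
    ("could", -0.19280624389648438, 9),
    ("have", -0.015477157197892666, 10),
    ("been", -0.21898044645786285, 11),
    ("better", -0.4095713794231415, 12),
]


def top_k(heatmap, k=3):
    """Return the k entries with the largest absolute attribution."""
    return sorted(heatmap, key=lambda entry: abs(entry[1]), reverse=True)[:k]


for token, attribution, index in top_k(explanation):
    print(f"{index:>2}  {token:<8} {attribution:+.3f}")
```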

The heatmap attribute displays it as a pandas table:

print(instance.heatmap)

> token_index    0         1          2         3          4         5    \
token        [CLS]         i       went       and        saw      this   
attribution      0 -0.117371  0.0849944  0.165192  0.0362542 -0.029687   
text_field    text      text       text      text       text      text   

token_index       6         7         8          9          10         11   \
token           movie      last     night      after      being     coaxed   
attribution  0.533126  0.240222  0.171116 -0.0450005 -0.0103401  0.0166524   
text_field       text      text      text       text       text       text   

token_index        13         14          15         16         17   \
token               to         by           a        few    friends   
attribution  0.0269605 -0.0213463  0.00761083  0.0216749  0.0579834   
text_field        text       text        text       text       text   

# abbreviated

Modifying the load function

thermostat.load() is a wrapper around datasets.load_dataset(), so any keyword argument accepted by load_dataset() can also be passed to load(). The only exceptions are path, name and split, which are reserved. For example, to use a different cache directory, pass the cache_dir argument to thermostat.load().
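
One way to picture this wrapper behaviour is sketched below: extra keyword arguments are forwarded to an underlying loader, and the three reserved names are rejected. This is not Thermostat's actual implementation. load_sketch, fake_load_dataset, and the fixed values passed for path and split are assumptions made for illustration, and the stand-in loader only echoes its arguments so that the example runs without downloading anything.

```python
RESERVED = {"path", "name", "split"}


def fake_load_dataset(path, name=None, split=None, **kwargs):
    """Stand-in for datasets.load_dataset(); it only echoes its arguments."""
    return {"path": path, "name": name, "split": split, **kwargs}


def load_sketch(config_str, **kwargs):
    """Forward keyword arguments to the loader, except the reserved ones."""
    clash = RESERVED.intersection(kwargs)
    if clash:
        raise ValueError(f"Reserved argument(s) not allowed: {sorted(clash)}")
    # The wrapper sets the reserved arguments itself (values assumed here).
    return fake_load_dataset(
        path="thermostat", name=config_str, split="test", **kwargs
    )


print(load_sketch("imdb-bert-lig", cache_dir="/tmp/thermostat_cache"))
```

With the real library, the corresponding call would be thermostat.load("imdb-bert-lig", cache_dir="/tmp/thermostat_cache").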


Explainers

Name | captum implementation | Parameters
Layer Gradient x Activation (lgxa) | .attr.LayerGradientXActivation |
Layer Integrated Gradients (lig) | .attr.LayerIntegratedGradients | # samples = 25
LIME (lime) | .attr.LimeBase | # samples = 25, mask prob = 0.3
Occlusion (occ) | .attr.Occlusion | sliding window = 3
Shapley Value Sampling (svs) | .attr.ShapleyValueSampling | # samples = 25

Datasets + Models

Overview

✅ = Dataset is downloadable
⏭️ = Dataset is finished, but not uploaded yet
🔄 = Currently running on cluster (x n = number of jobs/screens)
⚠️ = Issue

IMDb

imdb is a sentiment analysis dataset with 2 classes (pos and neg). The available split is the test subset containing 25k examples.
Example configuration: imdb-xlnet-lig

Name | 🤗 model | lgxa | lig | lime | occ | svs
ALBERT (albert) | textattack/albert-base-v2-imdb | ✅ | ✅ | ✅ | ✅ | ✅
BERT (bert) | textattack/bert-base-uncased-imdb | ✅ | ✅ | ✅ | ✅ | ✅
ELECTRA (electra) | monologg/electra-small-finetuned-imdb | ✅ | ✅ | ✅ | ✅ | ✅
RoBERTa (roberta) | textattack/roberta-base-imdb | ✅ | ✅ | ✅ | ✅ | ✅
XLNet (xlnet) | textattack/xlnet-base-cased-imdb | ✅ | ✅ | ✅ | ✅ | ✅

MultiNLI

multi_nli is a textual entailment dataset. The available split is the validation_matched subset containing 9815 examples.
Example configuration: multi_nli-roberta-lime

Name | 🤗 model | lgxa | lig | lime | occ | svs
ALBERT (albert) | prajjwal1/albert-base-v2-mnli | ✅ | ✅ | ✅ | ✅ | ✅
BERT (bert) | textattack/bert-base-uncased-MNLI | ✅ | ✅ | ✅ | ✅ | ✅
ELECTRA (electra) | howey/electra-base-mnli | ✅ | ✅ | ✅ | ✅ | ✅
RoBERTa (roberta) | textattack/roberta-base-MNLI | ✅ | ✅ | ✅ | ✅ | ✅
XLNet (xlnet) | textattack/xlnet-base-cased-MNLI | ✅ | ✅ | ✅ | ✅ | ✅

XNLI

xnli is a textual entailment dataset. It provides the test set of MultiNLI through the "en" configuration. The fine-tuned models used here are the same as the MultiNLI ones. The available split is the test subset containing 5010 examples.
Example configuration: xnli-roberta-lime

Name | 🤗 model | lgxa | lig | lime | occ | svs
ALBERT (albert) | prajjwal1/albert-base-v2-mnli | ✅ | ✅ | ✅ | ✅ | ✅
BERT (bert) | textattack/bert-base-uncased-MNLI | ✅ | ✅ | ✅ | ✅ | ✅
ELECTRA (electra) | howey/electra-base-mnli | ✅ | ✅ | ✅ | ✅ | ✅
RoBERTa (roberta) | textattack/roberta-base-MNLI | ✅ | ✅ | ✅ | ✅ | ✅
XLNet (xlnet) | textattack/xlnet-base-cased-MNLI | ✅ | ✅ | ✅ | ✅ | ✅

AG News

ag_news is a news topic classification dataset. The available split is the test subset containing 7600 examples.
Example configuration: ag_news-albert-svs

Name | 🤗 model | lgxa | lig | lime | occ | svs
ALBERT (albert) | textattack/albert-base-v2-ag-news | ✅ | ✅ | ✅ | ✅ | ✅
BERT (bert) | textattack/bert-base-uncased-ag-news | ✅ | ✅ | ✅ | ✅ | ✅
RoBERTa (roberta) | textattack/roberta-base-ag-news | ✅ | ✅ | ✅ | ✅ | ✅

Contribute a dataset

New explanation datasets must follow the JSONL format and include the five fields attributions, idx, input_ids, label and predictions as described above in "Usage".
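
Before uploading, it can help to check each JSONL line against these five fields. The snippet below is a minimal sketch with a hypothetical helper (check_jsonl_line) and a made-up example line. The length check rests on the assumption that attributions are stored one per token ID, as the field descriptions in "Usage" suggest.

```python
import json

REQUIRED_FIELDS = ("attributions", "idx", "input_ids", "label", "predictions")


def check_jsonl_line(line):
    """Parse one JSONL line and check the five required fields."""
    record = json.loads(line)
    missing = [field for field in REQUIRED_FIELDS if field not in record]
    if missing:
        raise ValueError(f"Missing fields: {missing}")
    # Assumption: one attribution value per token ID.
    if len(record["attributions"]) != len(record["input_ids"]):
        raise ValueError("Expected one attribution per input id")
    return record


# Made-up example line; the values are for illustration only.
line = json.dumps({
    "attributions": [0.0, 0.7, -0.1, 0.0],
    "idx": 0,
    "input_ids": [101, 2307, 3185, 102],
    "label": 1,
    "predictions": [-1.9, 2.4],
})
print(check_jsonl_line(line)["idx"])
```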

Please follow the instructions for writing a dataset loading script in the official docs of datasets.

Provide the additional Thermostat metadata via the list of builder configs (see the Thermostat implementation of builder configs for reference).

Necessary fields include...

  • name : The unique identifier string, e.g. including the three coordinates <DATASET>-<MODEL>-<EXPLAINER>
  • dataset : The full name of the dataset, usually follows the naming convention in datasets, e.g. "imdb"
  • explainer : The full name of the explainer, usually follows the naming convention in captum, e.g. "LayerIntegratedGradients"
  • model : The full name of the model, usually follows the naming convention in transformers, e.g. "textattack/bert-base-uncased-imdb"
  • label_column : The name of the column in the JSONL file that contains the label, usually "label"
  • label_classes : The list of label names or classes, e.g. ["entailment", "neutral", "contradiction"] for NLI datasets
  • text_column : Either a string (if there is only one text column) or a list of strings that identify the column in the JSONL file that contains the text(s), e.g. "text" (IMDb) or ["premise", "hypothesis"] (NLI)
  • description : Should at least state the full names of the three coordinates, can optionally include more info such as hyperparameter choices
  • data_url : The URL to the data storage, e.g. a Google Drive link

plus features which you can copy from the codebox below:

features={"attributions": "attributions",
          "predictions": "predictions",
          "input_ids": "input_ids"}
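
To see the fields together, the sketch below collects them in a plain dictionary and checks that none is missing. The dictionary is only a stand-in for a real builder config object. The values follow the examples in this README, except for data_url (a placeholder, not a real link) and the order of label_classes (an assumption); missing_fields is a hypothetical helper.

```python
NECESSARY_FIELDS = {
    "name", "dataset", "explainer", "model", "label_column",
    "label_classes", "text_column", "description", "data_url", "features",
}

# Values follow the README examples. data_url is a placeholder and the
# order of label_classes is an assumption.
config = {
    "name": "imdb-bert-lig",
    "dataset": "imdb",
    "explainer": "LayerIntegratedGradients",
    "model": "textattack/bert-base-uncased-imdb",
    "label_column": "label",
    "label_classes": ["neg", "pos"],
    "text_column": "text",
    "description": "IMDb dataset, BERT model, "
                   "Layer Integrated Gradients explanations",
    "data_url": "https://example.com/imdb-bert-lig.jsonl",
    "features": {"attributions": "attributions",
                 "predictions": "predictions",
                 "input_ids": "input_ids"},
}


def missing_fields(cfg):
    """Return the necessary fields that are absent from cfg, sorted by name."""
    return sorted(NECESSARY_FIELDS.difference(cfg))


print(missing_fields(config))
```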

While debugging, you can wrap your data with the Thermopack class and see if it correctly parses your data:

import thermostat
from datasets import load_dataset
data = load_dataset('your_dataset')
thermostat.Thermopack(data)

If you're successful, follow the official instructions for sharing a community-provided dataset on the Hugging Face Hub.

Initially, all Thermostat contributions have to be loaded via the code example above. Please notify us of existing explanation datasets by creating an Issue with the tag Contribution; a maintainer of this repository will then add your dataset to the Thermostat configs so that everyone can access it via thermostat.load().


Cite Thermostat

@inproceedings{feldhus2021thermostat,
    title={Thermostat: A Large Collection of NLP Model Explanations and Analysis Tools},
    author={Nils Feldhus and Robert Schwarzenberg and Sebastian Möller},
    year={2021},
    editor = {Heike Adel and Shuming Shi},
    booktitle = {Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing: System Demonstrations},
}

Disclaimer

We give no warranties for the correctness of the heatmaps or any other part of the data. This is evolving work and will be hot-patched continuously.

The Thermostat project follows the ACL and ACM Code of Ethics.

Acknowledgements

The majority of the codebase, especially regarding the combination of transformers and captum, stems from our other recent project Empirical Explainers.

Comments
  • AttributeError: 'XLNetModel' object has no attribute 'embeddings'

    The following error was raised for "mnli-xlnet-lgxa":

    Traceback (most recent call last):
      File "run_explainer.py", line 57, in <module>
        explainer = getattr(thermex, f'Explainer{config["explainer"]["name"]}').from_config(config=config)
      File "/home/feldhus/thermostat/src/thermostat/explainers/grad.py", line 71, in from_config
        layer=res.model.base_model.embeddings)
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1094, in __getattr__
        raise AttributeError("'{}' object has no attribute '{}'".format(
    AttributeError: 'XLNetModel' object has no attribute 'embeddings'
    
    opened by nfelnlp 6
  • LIME: "normal_kernel_cuda" not implemented for 'Long'

    Some CUDA-related type error occurred at attribution time.

    2021-04-09 17:15:04,043 -explain - INFO - (Progress) Loaded explainer
    2021-04-09 17:15:04,043 -explain - INFO - (Progress) Initialized data loader
      0%|                                                 | 0/12500 [00:00<?, ?it/s]2021-04-09 17:15:04,052 -explain - INFO - (Progress) Processing batch 0 / instance 0
    50
      0%|                                                 | 0/12500 [00:12<?, ?it/s]
    Traceback (most recent call last):
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/attr/_core/lime.py", line 415, in attribute
        curr_sample = self.perturb_func(inputs, **kwargs)
      File "/home/nfel/PycharmProjects/thermostat/src/thermostat/explainers/lime.py", line 43, in perturb_func
        return original_input + torch.randn_like(original_input)
    RuntimeError: "normal_kernel_cuda" not implemented for 'Long'
    
    bug 
    opened by nfelnlp 4
  • Prediction Values in Dataset

    Hello, I am experimenting with the attribution values of Layer Integrated Gradients on the AlbertV2 model on the imdb dataset. Along the way, I noticed a mismatch between the predictions saved in each instance in a dataset and the output of the model if loaded and called separately. I hope the following code sample makes clear what I mean:

    import thermostat
    import torch
    from transformers import AutoModelForSequenceClassification
    
    data = thermostat.load("imdb-albert-lig")
    
    albert = AutoModelForSequenceClassification.from_pretrained(data.model_name, return_dict=False)
    albert.eval()
    
    sliced_data = data[:10]
    thermostat_preds = [instance.predictions for instance in sliced_data]
    
    batch_input = torch.tensor([instance.input_ids for instance in sliced_data])
    preds = albert(batch_input)[0]
    
    print("thermostat  ---  new inference".center(89, " "))
    for m, n in zip(thermostat_preds, preds.tolist()):
        print(m, " --- ", n)
    

    outputs:

    Loading Thermostat configuration: imdb-albert-lig
    Reusing dataset thermostat (/home/tim/.cache/huggingface/datasets/thermostat/imdb-albert-lig/1.0.1/0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b)
                                  thermostat  ---  new inference                             
    [-2.9755005836486816, 3.422632932662964]  ---  [0.5287008285522461, 0.11485755443572998]
    [-2.1383304595947266, 2.6063592433929443]  ---  [0.5254676938056946, 0.11936521530151367]
    [-2.891936779022217, 3.3106441497802734]  ---  [0.5285712480545044, 0.11505177617073059]
    [-3.0642969608306885, 3.240943670272827]  ---  [0.5282390117645264, 0.11541029810905457]
    [-3.1076266765594482, 3.050632953643799]  ---  [0.5288912057876587, 0.11466512084007263]
    [-2.8576371669769287, 3.023214101791382]  ---  [0.5283359885215759, 0.12018033862113953]
    [-1.8885599374771118, 2.4350857734680176]  ---  [0.5438214540481567, 0.12777245044708252]
    [-1.5720579624176025, 2.051628589630127]  ---  [0.5284878611564636, 0.11513140797615051]
    [-3.2173707485198975, 3.510160207748413]  ---  [0.5282713174819946, 0.11538228392601013]
    [-2.5653769969940186, 3.0336244106292725]  ---  [-2.5653772354125977, 3.033625841140747]
    

    I only checked for that particular dataset and model, but I am wondering what causes the mismatch here? My suspicion is that either the model changed or the input_ids from the dataset do not correspond to the exact input used to create the dataset. Maybe you can shed some light on this? Thank you already for this useful library and if there are further questions please contact me, Tim

    opened by tpatzelt 3
  • LIME token similarity kernel for input.shape[0] > 1

    The assertion in the token_similarity_kernel function of ExplainerLimeBase

    assert original_input.shape[0] == perturbed_input.shape[0]  == 1
    

    https://github.com/nfelnlp/thermostat/blob/24177342945e834552a6df956ae59fdf1e69335b/src/thermostat/explainers/lime.py#L47

    only works for IMDB so far. An error is thrown for MNLI, so I debugged it and found out that the two input shapes can still be equal, although they're not exactly 1. I assume this is because it has two text fields ("premise", "hypothesis") instead of one. The calculation below can still be performed with .shape[0]==2.

    Do you think removing the == 1 at the end of the assertion would be fine?

    help wanted 
    opened by nfelnlp 3
  • [InputXGradient] RuntimeError: One of the differentiated Tensors does not require grad

    Stuck at this error while implementing InputXGradient. Tested on DistilBERT and RoBERTa.

    Traceback (most recent call last):
      File "/home/nfel/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
        exec(exp, global_vars, local_vars)
      File "<input>", line 1, in <module>
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/log/__init__.py", line 35, in wrapper
        return func(*args, **kwargs)
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/attr/_core/input_x_gradient.py", line 117, in attribute
        gradients = self.gradient_func(
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/_utils/gradient.py", line 125, in compute_gradients
        grads = torch.autograd.grad(torch.unbind(outputs), inputs)
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 223, in grad
        return Variable._execution_engine.run_backward(
    RuntimeError: One of the differentiated Tensors does not require grad
      0%|                                                  | 0/1821 [07:25<?, ?it/s]
    Traceback (most recent call last):
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/_utils/gradient.py", line 125, in compute_gradients
        grads = torch.autograd.grad(torch.unbind(outputs), inputs)
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 223, in grad
        return Variable._execution_engine.run_backward(
    RuntimeError: One of the differentiated Tensors does not require grad
    
    
    bug 
    opened by nfelnlp 2
  • MNLI and XNLI downstream model performance very low

    IMDb and AG News accuracies when comparing true labels to predicted labels are reasonably high. However, almost all MNLI and XNLI models investigated so far have an extremely low accuracy.

    Here are some random configs that I investigated:

      • multi_nli-albert-occ: 6.48%
      • multi_nli-roberta-occ: 28.72%
      • multi_nli-xlnet-occ: 5.53%
      • multi_nli-bert-lime: 6.46%
      • xnli-bert-lig: 7.05%

    Curiously, the ELECTRA model (submitted by a different person than all the other models) is not affected:

      • multi_nli-electra-lgxa: 88.7%
      • xnli-electra-lig: 88.3%

    It's not clear to me yet if this is simply an issue stemming from an outdated label order as documented here: https://github.com/huggingface/transformers/pull/10203/

    However, the label orders that are mentioned there, both old and new, are also different from the label assignment given by datasets: https://huggingface.co/datasets/viewer/?dataset=xnli (multi_nli has the same labels)

    Note that the datasets version we use in thermostat is 1.5.0, so this might be a problem with the downstream models predicting different labels.

    Then I investigated in what way the predictions and labels of such a faulty model (xnli-bert-lgxa) align:

    import thermostat
    from collections import Counter
    
    bert_lgxa = thermostat.load("xnli-bert-lgxa")
    true_pred_comp_bert_lgxa = [(b_i.true_label['index'], b_i.predicted_label['index']) for b_i in bert_lgxa]
    Counter(true_pred_comp_bert_lgxa)
    
    >>> Counter({(2, 0): 1488,
             (0, 1): 1345,
             (1, 0): 153,
             (1, 2): 1411,
             (0, 2): 231,
             (0, 0): 93,
             (1, 1): 107,
             (2, 2): 153,
             (2, 1): 29})
    

    This leads me to assume that (2, 0), (0, 1) and (1, 2) are actually the correct predictions and (0, 0), (1, 1) and (2, 2) are wrong ones. If we sum up (2, 0), (0, 1) and (1, 2) and divide it by the sum of all (or the length of the XNLI dataset), we end up at 84.71% which is a much more reasonable number in my opinion.

    At the very least, this means that all thermostat subsets concerning MNLI and XNLI need to be redone (editing JSONL files and reuploading). Hopefully, this only means going through each JSONL and changing the values either by:

    1. Changing the true labels to the old standard. However, this means that we do not use the vanilla data from datasets anymore.
    2. Changing the predicted labels as well as the logits.

    I'm pretty positive that we don't need to run the explanations again.


    On a sidenote, I also considered the encode_pair function (which is only used by MNLI and XNLI in thermostat) not working correctly, but couldn't find any reference implementation stating that the way the two text fields are ingested might be wrong.

    invalid 
    opened by nfelnlp 1
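
The label-shift hypothesis from the issue above can be checked directly from the reported counts. The sketch below recomputes the accuracy once with the labels as stored and once with predicted indices remapped (0 to 2, 1 to 0, 2 to 1). The accuracy helper and the shifted mapping are illustrative names, not part of Thermostat.

```python
from collections import Counter

# (true label index, predicted label index) counts reported in the issue.
counts = Counter({(2, 0): 1488, (0, 1): 1345, (1, 0): 153,
                  (1, 2): 1411, (0, 2): 231, (0, 0): 93,
                  (1, 1): 107, (2, 2): 153, (2, 1): 29})


def accuracy(pair_counts, pred_to_true=None):
    """Share of pairs whose (optionally remapped) prediction equals the label."""
    total = sum(pair_counts.values())
    correct = sum(
        n for (true, pred), n in pair_counts.items()
        if (pred_to_true[pred] if pred_to_true else pred) == true
    )
    return correct / total


# Hypothesised label shift: predicted 0 means true 2, 1 means 0, 2 means 1.
shifted = {0: 2, 1: 0, 2: 1}

print(f"as stored: {accuracy(counts):.2%}")
print(f"remapped:  {accuracy(counts, shifted):.2%}")
```

The remapped figure matches the 84.71% stated in the issue, while the labels as stored yield 7.05%.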
  • ExplainerCaptum.get_inputs_and_additional_args and .get_forward_function need to be extensible

    if name_model in ['bert-base-cased', 'xlnet-base-cased']:
        assert 'input_ids' in batch, f'Input ids expected for {name_model} but not found.'
        assert 'attention_mask' in batch, f'Attention mask expected for {name_model} but not found.'
        assert 'token_type_ids' in batch, f'Token type ids expected for model {name_model} but not found.'
        input_ids = batch['input_ids']
        additional_forward_args = (batch['attention_mask'], batch['token_type_ids'])
        return input_ids, additional_forward_args
    elif name_model == 'textattack/roberta-base-imdb':  # TODO: Separate classes?
        assert 'input_ids' in batch, f'Input ids expected for {name_model} but not found.'
        assert 'special_tokens_mask' in batch, f'Special tokens mask expected for {name_model} but not found.'
        assert 'attention_mask' in batch, f'Attention mask expected for {name_model} but not found.'
        input_ids = batch['input_ids']
        additional_forward_args = (batch['special_tokens_mask'], batch['attention_mask'])
        return input_ids, additional_forward_args
    else:
        raise NotImplementedError
    

    This way of putting every case where the batch encoding is different together into one class is not manageable in the future.

    see commit

    enhancement 
    opened by nfelnlp 1
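
One possible direction for the extensibility problem raised above is a lookup table that maps a model family to the batch entries it needs, so that supporting a new model means adding one entry instead of another elif branch. This is a sketch under assumptions, not the fix adopted in Thermostat: the family names and the ADDITIONAL_ARGS registry are hypothetical, and plain lists stand in for tensors.

```python
# Required batch keys per model family, in the order they are passed on.
# The family names and the registry itself are illustrative assumptions.
ADDITIONAL_ARGS = {
    "bert": ("attention_mask", "token_type_ids"),
    "xlnet": ("attention_mask", "token_type_ids"),
    "roberta": ("special_tokens_mask", "attention_mask"),
}


def get_inputs_and_additional_args(family, batch):
    """Look up which batch entries a model family needs besides input_ids."""
    try:
        keys = ADDITIONAL_ARGS[family]
    except KeyError:
        raise NotImplementedError(
            f"No batch layout registered for {family!r}"
        ) from None
    for key in ("input_ids", *keys):
        assert key in batch, f"{key} expected for {family} but not found."
    return batch["input_ids"], tuple(batch[key] for key in keys)


# Plain lists stand in for tensors here.
batch = {
    "input_ids": [101, 102],
    "attention_mask": [1, 1],
    "token_type_ids": [0, 0],
}
print(get_inputs_and_additional_args("bert", batch))
```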
  • Selection of experiment_in unsafe

    experiment_in = [f for f in os.listdir(experiment_path) if "preprocess" in f and f.endswith('.jsonl')][0]
    

    can accidentally select another empty explainer file (from previous, sometimes unsuccessful debug runs). This needs to be fixed, probably by introducing config files (like in gxai) in which the experiment inputs (processed data) are hard-coded, i.e. they name the specific files including timestamps.

    see commit

    invalid 
    opened by nfelnlp 1
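
The issue above proposes config files with hard-coded inputs. As a lighter-weight alternative, which is swapped in here and is not what the issue suggests, the selection can ignore empty files and fail loudly when the match is ambiguous. The function name and the demo file names below are hypothetical.

```python
import os
import tempfile


def select_preprocess_file(experiment_path):
    """Return the single non-empty '*preprocess*.jsonl' file in a directory.

    Raises instead of silently picking the first match when zero or several
    candidates exist.
    """
    candidates = sorted(
        name for name in os.listdir(experiment_path)
        if "preprocess" in name
        and name.endswith(".jsonl")
        and os.path.getsize(os.path.join(experiment_path, name)) > 0
    )
    if len(candidates) != 1:
        raise RuntimeError(
            f"Expected exactly one non-empty preprocess file in "
            f"{experiment_path!r}, found {candidates}"
        )
    return candidates[0]


# Demo in a temporary directory: one empty leftover and one real file.
with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, "old_preprocess.jsonl"), "w").close()
    with open(os.path.join(demo_dir, "new_preprocess.jsonl"), "w") as handle:
        handle.write('{"idx": 0}\n')
    selected = select_preprocess_file(demo_dir)
print(selected)
```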
  • Improve heatmap visualization options (more color schemes, better readability)

    For one, there could be an option to choose colors other than red and blue, as inspired by this site.

    https://github.com/nfelnlp/thermostat/blob/b7f8e829ca2927121c7fa24eb977ecb91a92d017/src/thermostat/visualize.py#L46

    On top of that, the current color scheme is not perfect for readability, because some of the most salient words (very red or very blue) still have black text and cannot be read properly. Maybe it would be better to have a colored border around the word instead?

    https://github.com/nfelnlp/thermostat/blob/b7f8e829ca2927121c7fa24eb977ecb91a92d017/src/thermostat/visualize.py#L147-L152

    opened by nfelnlp 0
  • Missing assertion that num_labels in dataset corresponds to classification head shape of model

    The alignment of the classification head with the number of labels in datasets has not been a problem so far, but I left a TODO here to insert an assert that should check if res.num_labels corresponds with the shape of the loaded model's classification head.

    https://github.com/nfelnlp/thermostat/blob/b7f8e829ca2927121c7fa24eb977ecb91a92d017/src/thermostat/explain.py#L137-L138

    opened by nfelnlp 0
  • Find out if the cast to long for batch components in forward functions is necessary

    https://github.com/nfelnlp/thermostat/blob/a8180a2d83e1c3ec5f873dbf0ce0ab14026cf6bf/src/thermostat/explain.py#L54

            def bert_forward(input_ids, attention_mask, token_type_ids):
                input_model = {
                    'input_ids': input_ids.long(),
                    'attention_mask': attention_mask.long(),
                    'token_type_ids': token_type_ids.long(),
                }
                output_model = model(**input_model)[0]
                return output_model
    
            def roberta_forward(input_ids, attention_mask):
                input_model = {
                    'input_ids': input_ids.long(),
                    'attention_mask': attention_mask.long(),
                }
                output_model = model(**input_model)[0]
                return output_model
    
    question 
    opened by nfelnlp 1
  • [KernelShap] ZeroDivisionError: Weights sum to zero, can't be normalized

    KernelShap is not working yet and gave the following error:

    Traceback (most recent call last):
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/attr/_core/lime.py", line 484, in attribute
        self.interpretable_model.fit(DataLoader(dataset, batch_size=n_samples))
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/_utils/models/linear_model/model.py", line 303, in fit
        return super().fit(train_data=train_data, **kwargs)
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/_utils/models/linear_model/model.py", line 260, in fit
        return super().fit(
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/_utils/models/linear_model/model.py", line 117, in fit
        return self.train_fn(
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/captum/_utils/models/linear_model/train.py", line 329, in sklearn_train_linear_model
        sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/sklearn/linear_model/_base.py", line 525, in fit
        X, y, X_offset, y_offset, X_scale = self._preprocess_data(
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/sklearn/linear_model/_base.py", line 162, in _preprocess_data
        X_offset = np.average(X, axis=0, weights=sample_weight)
      File "<__array_function__ internals>", line 5, in average
      File "/home/nfel/PycharmProjects/thermostat/venv/lib/python3.8/site-packages/numpy/lib/function_base.py", line 409, in average
        raise ZeroDivisionError(
    ZeroDivisionError: Weights sum to zero, can't be normalized
    
    bug 
    opened by nfelnlp 1
Releases(1.0.2)
Owner
Speech and Language Technology (SLT) Group of the Berlin lab of the German Research Center for Artificial Intelligence (DFKI)