PyTorch-centric library for evaluating and enhancing the robustness of AI technologies

Overview

Responsible AI Toolbox

A library that provides high-quality, PyTorch-centric tools for evaluating and enhancing both the robustness and the explainability of AI models.

Check out our documentation for more information.

The rAI-toolbox works great with PyTorch Lightning and Hydra 🐉. Check out rai_toolbox.mushin to see how we use these frameworks to create efficient, configurable, and reproducible ML workflows with minimal boilerplate code.

Citation

Using rai_toolbox for your research? Please cite the following publication:

@article{soklaski2022tools,
  title={Tools and Practices for Responsible AI Engineering},
  author={Soklaski, Ryan and Goodwin, Justin and Brown, Olivia and Yee, Michael and Matterer, Jason},
  journal={arXiv preprint arXiv:2201.05647},
  year={2022}
}

Contributing

If you would like to contribute to this repo, please refer to our CONTRIBUTING.md document.

Disclaimer

DISTRIBUTION STATEMENT A. Approved for public release. Distribution is unlimited.

© 2022 MASSACHUSETTS INSTITUTE OF TECHNOLOGY

  • Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014)
  • SPDX-License-Identifier: MIT

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

A portion of this research was sponsored by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator and was accomplished under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein.

The software/firmware is provided to you on an As-Is basis.

Comments
  • Update workflows

    See example use here: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/workflow-docs/examples/MNIST-Translation-Robustness.ipynb

    • [x] Create base class for workflows
    • [x] Update docs
    • [x] Create tests for workflows
    opened by jgbos 7
  • Strange computational graph issue with `gradient_ascent` and `LightningModule`

    First, here is a simple example of running gradient_ascent that works without error:

    from functools import partial
    import torch as tr
    from torchvision import models
    from rai_toolbox.optim import L2ProjectedOptim
    from rai_toolbox.perturbations.solvers import gradient_ascent
    
    model = models.resnet18()
    data = tr.rand(10, 3, 100, 100, dtype=tr.float)
    target = tr.randint(0, 2, size=(10,))
    pert = partial(
        gradient_ascent, optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0
    )
    
    # run gradient ascent
    pert(model=model, data=data, target=target)
    

    Now set up and run the same thing using Trainer.predict:

    import pytorch_lightning as pl
    
    class Lit(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = model
            self.pert = pert
    
        def predict_step(self, batch, *args, **kwargs):
            data, target = batch
            data = self.pert(model=self.model, data=data, target=target)
            logits = self.model(data)
            return logits.sum()
    
    trainer = pl.Trainer()
    trainer.predict(
        Lit(),
        datamodule=pl.LightningDataModule.from_datasets(
            predict_dataset=tr.utils.data.TensorDataset(data, target),
            batch_size=1,
            num_workers=0,
        ),
    )
    

    Here we get the following error:

    ...
    /tmp/ipykernel_74682/1909129363.py in predict_step(self, batch, *args, **kwargs)
         27     def predict_step(self, batch, *args, **kwargs):
         28         data, target = batch
    ---> 29         data = self.pert(model=self, data=data, target=target)
         30         logits = self.model(data)
         31         return logits.sum()
    
    ~/projects/raiden/rai_toolbox/src/rai_toolbox/perturbations/solvers.py in gradient_ascent(model, data, target, optimizer, steps, perturbation_model, targeted, use_best, criterion, reduction_fn, **optim_kwargs)
        277             # Update the perturbation
        278             optim.zero_grad(set_to_none=True)
    --> 279             loss.backward()
        280             optim.step()
        281 
    
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
        394                 create_graph=create_graph,
        395                 inputs=inputs)
    --> 396         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
        397 
        398     def register_hook(self, hook):
    
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
        171     # some Python versions print out the first line of a multi-line function
        172     # calls in the traceback and some print out the last line
    --> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
        175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
    

    If I enter the debugger, everything seems to be set up correctly, except that pmodel(data) does not return a tensor with a grad_fn!

    #
    # pdb at `loss.backward()` line
    #
    > pmodel.delta.requires_grad
    True
    
    > tr.is_grad_enabled()
    True
    
    > pmodel.delta + data
    ... # tensor output without `grad_fn`
    
    # try reinitializing
    > perturbation_model(data)(data)
    ... # tensor output WITH `grad_fn`
    

    I have no idea how to debug this and find out what is wrong.

    @rsokl do you get this error in your environment?

    opened by jgbos 6
  • Docs: perturbation explanation

    Starting an explanation of our approach to data perturbations. I still intend to add more to this today, but feel free to take a look and let me know your thoughts on how it's going so far, especially what should or shouldn't be included.

    opened by oliviamb 4
  • CIFAR10-Adversarial-Perturbations.ipynb -- Standard rai-toolbox[mushin] install doesn't include dill module

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook attempts to load the pretrained models and fails with ModuleNotFoundError: No module named 'dill'. The dill module is not included in the standard rai-toolbox[mushin] install.

    Full error traceback below:

    ModuleNotFoundError                       Traceback (most recent call last)
    Input In [9], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = "mitll_cifar_l2_1_0.pt"
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/_utils.py:60, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58 log.info(f"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:713, in load(f, map_location, pickle_module, **pickle_load_args)
        711             return torch.jit.load(opened_file)
        712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    --> 713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:930, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
        928 unpickler = UnpicklerWrapper(f, **pickle_load_args)
        929 unpickler.persistent_load = persistent_load
    --> 930 result = unpickler.load()
        932 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
        934 offset = f.tell() if f_should_read_directly else None
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:746, in _legacy_load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
        744     except KeyError:
        745         pass
    --> 746 return super().find_class(mod_name, name)
    
    ModuleNotFoundError: No module named 'dill'
    

    Installing dill via pip install dill in the python environment corrects this error.
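
A notebook could surface this problem earlier, and with a clearer message, by checking for the module before the checkpoint is loaded. The following is a minimal sketch using only the standard library; the helper name `require_module` and the hint text are hypothetical and are not part of rai_toolbox:

```python
import importlib.util


def require_module(name: str, hint: str = "") -> None:
    """Raise a descriptive ImportError if the top-level module `name` is missing."""
    if importlib.util.find_spec(name) is None:
        msg = f"Required module {name!r} is not installed."
        if hint:
            msg += f" {hint}"
        raise ImportError(msg)


# `json` ships with Python, so this check passes silently.
require_module("json")

# In the notebook, one might call this before loading the checkpoint:
# require_module("dill", hint="Try `pip install dill`.")
```

The check only confirms that the module can be found; it does not import it or validate its version.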

    opened by miscpeeps 3
  • test for ensuring hydra ddp raises is raising for the wrong reason

    @jgbos

    In the following test:

    https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/3882320391c87b6bf6330a09471a9808fca01160/tests/test_mushin/test_lightning_hydra_ddp.py#L27-L40

    launch(Config, pl_main_task) raises TypeError because pl_main_task doesn't accept a single config (pyright warned me about this). I doubt this is what you meant to exercise in this test.

    I am confused about what this test is doing. Config = make_config(trainer=trainer, wrong_config_name=module, devices=2) makes it seem like we are making sure that launch fails for a config with a bad field name, but the test seems like it should be exercising DDP.

    bug test-suite 
    opened by rsokl 3
  • CIFAR10-Adversarial-Perturbations.ipynb -- Load pretrained CIFAR-10 models not included and incorrectly named

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook and tutorial reference mitll_cifar_l2_1_0.pt and mitll_cifar_nat.pt as pretrained CIFAR-10 models. These models are not included in the standard rai-toolbox[mushin] install (perhaps due to licensing, or a desire to always have the most up-to-date models?).

    The models downloaded from the URLs at the robustness GitHub repository are named cifar_l2_1_0.pt and cifar_nat.pt, which causes the following error on In[10] of CIFAR10-Adversarial-Perturbations.ipynb:

    FileNotFoundError                         Traceback (most recent call last)
    Input In [8], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = "mitll_cifar_l2_1_0.pt"
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/_utils.py:60, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58 log.info(f"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:699, in load(f, map_location, pickle_module, **pickle_load_args)
        696 if 'encoding' not in pickle_load_args.keys():
        697     pickle_load_args['encoding'] = 'utf-8'
    --> 699 with _open_file_like(f, 'rb') as opened_file:
        700     if _is_zipfile(opened_file):
        701         # The zipfile reader is going to advance the current file position.
        702         # If we want to actually tail call to torch.jit.load, we need to
        703         # reset back to the original position.
        704         orig_position = opened_file.tell()
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:231, in _open_file_like(name_or_buffer, mode)
        229 def _open_file_like(name_or_buffer, mode):
        230     if _is_path(name_or_buffer):
    --> 231         return _open_file(name_or_buffer, mode)
        232     else:
        233         if 'w' in mode:
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:212, in _open_file.__init__(self, name, mode)
        211 def __init__(self, name, mode):
    --> 212     super(_open_file, self).__init__(open(name, mode))
    
    FileNotFoundError: [Errno 2] No such file or directory: '/home/scott/.torch/models/mitll_cifar_l2_1_0.pt'
    

    The models must be renamed and manually copied to /home/$USER/.torch/models to proceed with the tutorial.
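
The manual rename-and-copy step can be scripted. The sketch below assumes the two checkpoints have already been downloaded into some local directory; the function name `install_checkpoints` is hypothetical, while the file names and the `~/.torch/models` destination come from the report and traceback above:

```python
import shutil
from pathlib import Path

# Names published by the robustness repository -> names the notebook expects
# (both pairs are taken from the report above).
RENAMES = {
    "cifar_l2_1_0.pt": "mitll_cifar_l2_1_0.pt",
    "cifar_nat.pt": "mitll_cifar_nat.pt",
}


def install_checkpoints(download_dir: Path, model_dir: Path) -> list:
    """Copy each downloaded checkpoint into `model_dir` under its expected name."""
    model_dir.mkdir(parents=True, exist_ok=True)
    installed = []
    for src_name, dst_name in RENAMES.items():
        src = download_dir / src_name
        if src.exists():  # skip files that were not downloaded
            dst = model_dir / dst_name
            shutil.copy2(src, dst)
            installed.append(dst)
    return installed
```

A call such as `install_checkpoints(Path("downloads"), Path.home() / ".torch" / "models")` would then place the files where `load_from_checkpoint` looks for them; the `downloads` directory is only a placeholder.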

    opened by miscpeeps 2
  • `zen` should not attempt to populate `*args` and `**kwargs`

    Previously zen would attempt to find a kwargs field in the config:

    Before:

    def f(x, **kwargs): return x
    
    cfg = make_config(x=1)
    
    zen(f)(cfg)  # AttributeError: 'Config' object has no attribute 'kwargs'
    

    Now zen skips *args and **kwargs:

    def f(x, **kwargs): return x
    
    cfg = make_config(x=1)
    
    zen(f)(cfg)  # returns 1
    

    In the future we might permit some configured behavior for populating these.
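
The skipping behaviour described above can be sketched with `inspect.signature`. This is an illustration only, not the actual `zen` implementation: `mini_zen` is a hypothetical stand-in, and a `SimpleNamespace` plays the role of the config produced by `make_config`:

```python
import inspect
from types import SimpleNamespace

_VARIADIC = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)


def mini_zen(func):
    """Wrap `func` so it is called with attributes pulled from a config object.

    Named parameters are looked up on the config; `*args` and `**kwargs`
    are skipped, so no `args` / `kwargs` attribute is ever requested.
    """
    params = [
        name
        for name, p in inspect.signature(func).parameters.items()
        if p.kind not in _VARIADIC
    ]

    def wrapper(cfg):
        return func(**{name: getattr(cfg, name) for name in params})

    return wrapper


def f(x, **kwargs):
    return x


cfg = SimpleNamespace(x=1)  # stands in for make_config(x=1)
result = mini_zen(f)(cfg)   # only `cfg.x` is read
```

Filtering on the parameter kind, rather than on the names `args` and `kwargs`, keeps the behaviour correct for functions that name their variadic parameters differently.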

    bug 
    opened by rsokl 2
  • Update gradient-descent solver

    • Renames: gradient_descent -> gradient_ascent
    • (bug fix) Ensures that returned loss always has the correct sign. Previously, when targeted=False the returned loss values would be negated relative to the actual loss landscape
    • Adds examples section to docs
    • Ensures that data and target can be any array-like input, not necessarily a tensor
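
The sign issue in the second bullet can be illustrated with a one-dimensional toy problem. This is a sketch under stated assumptions, not the toolbox's solver: `gradient_ascent_1d` is a hypothetical helper that maximises a loss by descending on its negation and then reports the loss on its original scale:

```python
def gradient_ascent_1d(loss_fn, grad_fn, x0, lr=0.1, steps=10):
    """Maximise `loss_fn` by running gradient descent on its negation."""
    x = x0
    for _ in range(steps):
        # Descent on -loss: x <- x - lr * d(-loss)/dx, i.e. x + lr * dloss/dx
        x = x - lr * (-grad_fn(x))
    # Report the loss itself, not the negated surrogate that was minimised.
    return x, loss_fn(x)


def loss(x):
    return -((x - 3.0) ** 2)  # concave, with a maximum of 0 at x = 3


def grad(x):
    return -2.0 * (x - 3.0)


x_final, final_loss = gradient_ascent_1d(loss, grad, x0=0.0, steps=50)
```

Returning `loss_fn(x)` rather than the value that the optimizer minimised is what keeps the reported number consistent with the actual loss landscape.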
    bug code quality 
    opened by rsokl 2
  • Bump pydata-sphinx-theme from 0.8.1 to 0.11.0 in /docs

    Bumps pydata-sphinx-theme from 0.8.1 to 0.11.0.

    Release notes

    Sourced from pydata-sphinx-theme's releases.

    v0.11.0

    What's Changed

    New Contributors

    Full Changelog: https://github.com/pydata/pydata-sphinx-theme/compare/v0.10.1...v0.11.0

    v0.11.0rc3

    What's Changed

    ... (truncated)

    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx-tabs from 3.3.1 to 3.4.1 in /docs

    Bumps sphinx-tabs from 3.3.1 to 3.4.1.

    Release notes

    Sourced from sphinx-tabs's releases.

    Version 3.4.1

    What's Changed

    Full Changelog: https://github.com/executablebooks/sphinx-tabs/compare/v3.4.0...v3.4.1

    Version 3.4.0

    What's Changed

    New Contributors

    Full Changelog: https://github.com/executablebooks/sphinx-tabs/compare/v3.3.1...v3.4.0

    Changelog

    Sourced from sphinx-tabs's changelog.

    3.4.1 - 2022-07-02

    Added

    • Weekly scheduled testing, to catch breaking changes in unpinned dependencies

    Changed

    • docutils version pin to allow use of version 0.18.x

    Removed

    • sphinx version pinning - only the latest version of sphinx will now be fully supported, but previous versions will work if sphinx dependencies (i.e. jinja2) are managed correctly. This is in line with the approach at sphinx
    • tests that were specific to older versions of sphinx and pygments
    • jinja2 version pinning, as this is now pinned in latest version of sphinx

    3.4.0 - 2022-06-26

    Added

    • Testing for sphinx 5
    • Testing for python 3.10

    Fixed

    • Fixed parsing of MyST content, where first line was being stripped
    • Typos in documentation
    • Failing regression tests

    Changed

    • Testing to use an up-to-date pytest version

    Removed

    • Testing for python 3.6 and sphinx versions 2 and 4 (see #164). Note that the package will likely continue to work fine with these, but this won't be assured by tests
    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx-codeautolink from 0.10.0 to 0.12.0 in /docs

    Bumps sphinx-codeautolink from 0.10.0 to 0.12.0.

    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx from 5.3.0 to 6.0.0 in /docs

    Bumps sphinx from 5.3.0 to 6.0.0.

    Release notes

    Sourced from sphinx's releases.

    v6.0.0

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    v6.0.0b2

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    v6.0.0b1

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    Changelog

    Sourced from sphinx's changelog.

    Release 6.0.0 (released Dec 29, 2022)

    Dependencies

    • #10468: Drop Python 3.6 support
    • #10470: Drop Python 3.7, Docutils 0.14, Docutils 0.15, Docutils 0.16, and Docutils 0.17 support. Patch by Adam Turner

    Incompatible changes

    • #7405: Removed the jQuery and underscore.js JavaScript frameworks.

      These frameworks are no longer automatically injected into themes from Sphinx 6.0. If you develop a theme or extension that uses the jQuery, $, or $u global objects, you need to update your JavaScript to modern standards, or use the mitigation below.

      The first option is to use the sphinxcontrib.jquery_ extension, which has been developed by the Sphinx team and contributors. To use this, add sphinxcontrib.jquery to the extensions list in conf.py, or call app.setup_extension("sphinxcontrib.jquery") if you develop a Sphinx theme or extension.

      The second option is to manually ensure that the frameworks are present. To re-add jQuery and underscore.js, you will need to copy jquery.js and underscore.js from the Sphinx repository_ to your static directory, and add the following to your layout.html:

      .. code-block:: html+jinja

      {%- block scripts %} {{ super() }} {%- endblock %}

      .. _sphinxcontrib.jquery: https://github.com/sphinx-contrib/jquery/

      Patch by Adam Turner.

    • #10471, #10565: Removed deprecated APIs scheduled for removal in Sphinx 6.0. See :ref:dev-deprecated-apis for details. Patch by Adam Turner.

    • #10901: C Domain: Remove support for parsing pre-v3 style type directives and roles. Also remove associated configuration variables c_allow_pre_v3 and c_warn_on_allowed_pre_v3. Patch by Adam Turner.

    Features added

    ... (truncated)

    dependencies python 
    opened by dependabot[bot] 1
  • Update madry example

    Currently we use Workflow.run within hydra.main, which no longer works: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/5348f7d6a96837d0f9d9ce1b5e71ebfdee6f4b88/experiments/madry/run.py#L29

    We should update this to leverage zen, while making sure that the plotting still works (i.e., that we give the workflow the necessary context to gather the xarray).

    opened by rsokl 0
  • Add Pickled Hydra Runs (e.g., Rerun) to Support PL `ddp`

    In this PR we will attempt to address two issues:

    1. Reproducible Hydra experiments purely from the run directory by pickling both the runtime configuration and the task function
      • An extension of Hydra's experimental rerun
    2. Solving Hydra+DDP for PyTorch Lightning ddp strategy by saving the task function
      • The current solution in HydraDDP has strong constraints on the expected task function. This limits what the user can do in their experiments.

    Hydra Rerun Capability

    Here we take advantage of Hydra Callbacks to save the runtime configuration and the desired task function. Currently our callback takes a task function on initialization, but future Hydra versions may allow Hydra to pass the task function to the callback methods.

    Callback implementation: MushinPickleJobCallback. This takes in a Hydra task function on initialization and saves the task function and runtime configuration in the hydra.runtime.output_dir folder. The pickled files are stored in:

    <hydra.runtime.output_dir>/config.pickle
    <hydra.runtime.output_dir>/task_fn.pickle
    

    This implementation uses cloudpickle to support pickling of the task function. The only downside of this approach is that the task function must be hashable for pickling and "instantiable" for Hydra from the command line, e.g., defining the task function in the notebook won't work.

    Note: Submitit is capable of pickling functions that were created in __main__, so this should be possible

    Execution: With the configuration and task function saved in the job directory, we can rerun any experiment using:

    $ python -m rai_toolbox.mushin._hydra_rerun +config=<path to config.pickle> +task_fn=<path to task_fn.pickle>
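
The save-then-rerun idea can be sketched with the standard library alone. This is not the `MushinPickleJobCallback` implementation: `save_job` and `rerun_job` are hypothetical helpers, stdlib `pickle` is swapped in for cloudpickle, and the builtin `sorted` stands in for a real task function that takes a single config argument. Only the two file names come from the description above:

```python
import pickle
import tempfile
from pathlib import Path


def save_job(output_dir: Path, config: dict, task_fn) -> None:
    """Write config.pickle and task_fn.pickle into `output_dir`."""
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "config.pickle").write_bytes(pickle.dumps(config))
    (output_dir / "task_fn.pickle").write_bytes(pickle.dumps(task_fn))


def rerun_job(output_dir: Path):
    """Reload both pickles and call the task function on the saved config."""
    config = pickle.loads((output_dir / "config.pickle").read_bytes())
    task_fn = pickle.loads((output_dir / "task_fn.pickle").read_bytes())
    return task_fn(config)


with tempfile.TemporaryDirectory() as tmp:
    job_dir = Path(tmp) / "job"
    # An importable function is used because stdlib pickle stores functions
    # by reference (module and qualified name), not by value.
    save_job(job_dir, {"lr": 0.1, "epochs": 3}, sorted)
    rerun_result = rerun_job(job_dir)  # sorted(config) -> sorted list of keys
```

Because stdlib `pickle` stores functions by reference, a function defined interactively may not be loadable from another process, which is in line with the notebook limitation noted above.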
    

    Lightning DDP

    Challenges this PR solves for Hydra+DDP:

    • Runs from notebook
    • Supports generic task functions (i.e., solves HydraDDP issue)
    • Task functions can run multiple Trainer methods (e.g., Trainer.fit followed by Trainer.test). HydraDDP does not support these types of task functions

    First we must configure our custom Hydra Callback, MushinPickleJobCallback:

    task_fn_cfg = builds(...)
    
    callback_cfg = dict(
        save_job_info=builds(MushinPickleJobCallback, task_fn=task_fn_cfg)
    )
    
    cs = ConfigStore.instance()
    cs.store(name="pickle_job", group="hydra/callbacks", node=callback_cfg)
    

    The Trainer strategy can then be configured with our custom Lightning ddp strategy, HydraRerunDDP:

    TrainerConfig = builds(Trainer, strategy=builds(HydraRerunDDP))
    

    We must set hydra/callbacks in the overrides to launch a job:

    task_fn = instantiate(task_fn_cfg)
    launch(Config, task_fn, overrides=["hydra/callbacks=pickle_job", ...])
    

    Notes

    • MushinPickleJobCallback will clean up the PL environment automatically at the end of a job.
    • See tests for examples.

    I plan to update this comment to better describe everything

    TODOS

    • [ ] Should we deprecate HydraDDP in favor of this
    • [ ] Can we pickle and use task functions built in a "main" setting like the notebook?
    • [ ] Structure of Hydra specific and Lightning specific code
    • [ ] More tests:
      • Validate results, not just pickle file available
      • Test Hydra rerun without Lightning
    opened by jgbos 1
  • Implements elastic-net attack

    Derived from: https://arxiv.org/pdf/1709.04114.pdf

    Here is a trivial scenario where we are merely perturbing the "logits" themselves so that the specified targets will be optimized for. Let's see that the longer we run the optimizer, the more the learned perturbation shrinks (while still amounting to a successful attack).

    >>> from rai_toolbox.perturbations.solvers import elastic_net_attack
    >>> logits = [[0.497, 0.503]]
    >>> target = [0]
    
    >>> for num_steps in [1, 10, 100]:
    ...     _, x_adv, _ = elastic_net_attack(
    ...         model=lambda x: x,
    ...         data=logits,
    ...         target=target,
    ...         beta=1e-3,
    ...         c=2,
    ...         steps=num_steps,
    ...         confidence=.01,
    ...         lr=0.5,
    ...     )
    ...     print(f"num-steps: {num_steps}\n{x_adv}")
    num-steps: 1
    tensor([[ 1.4960, -0.4960]])
    num-steps: 10
    tensor([[0.5062, 0.4938]])
    num-steps: 100
    tensor([[0.5018, 0.4982]])
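
One common way to handle an L1 penalty like the one weighted by `beta` is a soft-thresholding (shrinkage) step, which pulls each coordinate back toward its original value and zeroes out perturbations smaller than `beta`. The sketch below is a generic pure-Python illustration of that operator; it is an assumption for illustration and not necessarily how `elastic_net_attack` is implemented, and it omits any box constraint on the data:

```python
def soft_threshold(z, x0, beta):
    """Shrink each entry of `z` toward the matching entry of `x0` by `beta`."""
    out = []
    for z_i, x0_i in zip(z, x0):
        diff = z_i - x0_i
        if diff > beta:
            out.append(z_i - beta)
        elif diff < -beta:
            out.append(z_i + beta)
        else:
            out.append(x0_i)  # a perturbation no larger than beta is removed
    return out


x0 = [0.497, 0.503]    # original "logits" from the example above
z = [0.5062, 0.5031]   # hypothetical iterate before shrinkage
shrunk = soft_threshold(z, x0, beta=1e-3)
```

Here the first coordinate keeps most of its perturbation, while the second, which moved by less than `beta`, is reset to its original value.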
    
    opened by rsokl 0
  • Use fused multiply-add to apply `grad_scale` and `grad_bias`

    https://pytorch.org/docs/stable/generated/torch.add.html

    >>> a = torch.randn(4)
    >>> a
    tensor([ 0.0202,  1.0985,  1.3506, -0.6056])
    
    >>> b = torch.randn(4)
    >>> b
    tensor([-0.9732, -0.3497,  0.6245,  0.4022])
    >>> c = torch.randn(4, 1)
    >>> c
    tensor([[ 0.3743],
            [-1.7724],
            [-0.5811],
            [-0.8017]])
    >>> torch.add(b, c, alpha=10)
    tensor([[  2.7695,   3.3930,   4.3672,   4.1450],
            [-18.6971, -18.0736, -17.0994, -17.3216],
            [ -6.7845,  -6.1610,  -5.1868,  -5.4090],
            [ -8.9902,  -8.3667,  -7.3925,  -7.6147]])
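
Applied to the issue's use case, the same `alpha` argument can fold a scale and a bias into one call. This is only a sketch of the idea, not the library's implementation: it assumes `grad_scale` is a Python scalar (`alpha` must be a number) and `grad_bias` is a tensor that broadcasts against the gradient. The values are hypothetical.

```python
import torch

# Hypothetical stand-ins for a parameter's gradient and the two settings.
grad = torch.tensor([1.0, -2.0, 3.0])
grad_scale = 0.5                        # must be a plain number to be used as `alpha`
grad_bias = torch.full_like(grad, 0.1)  # tensor broadcastable with `grad`

# Two separate operations: multiply, then add.
two_step = grad * grad_scale + grad_bias

# One call: torch.add(input, other, alpha=a) computes input + a * other.
fused = torch.add(grad_bias, grad, alpha=grad_scale)

assert torch.allclose(two_step, fused)
print(fused)
```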
    
    opened by rsokl 0
Releases(v0.2.1)
  • v0.2.1(Jun 16, 2022)

    See changelog

    What's Changed

    • Fix TopQGradient device mismatch by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/64
    • Fixes to_xarray when target_job_dirs points to job that performed multirun over sequence values by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/68

    Full Changelog: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/compare/v0.2.0...v0.2.1
  • v0.2.0(Jun 1, 2022)

    See changelog for details

    What's Changed

    • zen should not attempt to populate *args and **kwargs by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/48
    • Workflow Improvement by @jgbos in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/47
    • Adds zen callbacks by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/49
    • Remove numpy dependency (defer to pytorch) by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/50
    • Make methods static where possible; simplify examples; cleanup formatting by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/51
    • Add working_subdir data variable to xarray by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/52
    • Improve parity between pre-step and post-step method names by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/54
    • fix typo in univ_adv_pert.rst by @Jasha10 in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/56
    • Update workflows by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/55
    • Adds Support for Lightning's Trainer.predict by @jgbos in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/53
    • Fix hypothesis by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/58
    • Add pre-task method to workflow by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/57
    • Deprecate ParamTransformingOptimizer.project by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/59
    • Add pre-release and nightly CI jobs by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/60
    • Ensure workflow overrides roundtrip by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/61
    • Deprecate evaluation_task in favor of task by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/62
    • Enable user-specified functions for loading metrics files by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/63

    New Contributors

    • @Jasha10 made their first contribution in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/56

    Full Changelog: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/compare/v0.1.1...v0.2.0