PyTorch-centric library for evaluating and enhancing the robustness of AI technologies

Overview

Responsible AI Toolbox

A library that provides high-quality, PyTorch-centric tools for evaluating and enhancing both the robustness and the explainability of AI models.

Check out our documentation for more information.

The rAI-toolbox works great with PyTorch Lightning and Hydra 🐉. Check out rai_toolbox.mushin to see how we use these frameworks to create efficient, configurable, and reproducible ML workflows with minimal boilerplate code.
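
The toolbox is installable from PyPI (e.g., pip install rai-toolbox, or pip install rai-toolbox[mushin] for the mushin extras).

For a quick taste of the perturbation-solving API, here is a minimal sketch adapted from an example that appears in the issues further down this page (the parameter values are illustrative, not recommended settings):

    from functools import partial

    import torch as tr
    from torchvision import models

    from rai_toolbox.optim import L2ProjectedOptim
    from rai_toolbox.perturbations.solvers import gradient_ascent

    # Solve for an L2-bounded input perturbation that maximizes the model's loss.
    model = models.resnet18()
    data = tr.rand(10, 3, 100, 100)
    target = tr.randint(0, 2, size=(10,))

    solve = partial(
        gradient_ascent, optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0
    )
    solve(model=model, data=data, target=target)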

Citation

Using rai_toolbox for your research? Please cite the following publication:

@article{soklaski2022tools,
  title={Tools and Practices for Responsible AI Engineering},
  author={Soklaski, Ryan and Goodwin, Justin and Brown, Olivia and Yee, Michael and Matterer, Jason},
  journal={arXiv preprint arXiv:2201.05647},
  year={2022}
}

Contributing

If you would like to contribute to this repo, please refer to our CONTRIBUTING.md document.

Disclaimer

DISTRIBUTION STATEMENT A. Approved for public release. Distribution is unlimited.

© 2022 MASSACHUSETTS INSTITUTE OF TECHNOLOGY

  • Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014)
  • SPDX-License-Identifier: MIT

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

A portion of this research was sponsored by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator and was accomplished under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein.

The software/firmware is provided to you on an As-Is basis.

Comments
  • Update workflows

    See example use here: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/workflow-docs/examples/MNIST-Translation-Robustness.ipynb

    • [x] Create base class for workflows
    • [x] Update docs
    • [x] Create tests for workflows
    opened by jgbos 7
  • Strange computational graph issue with `gradient_ascent` and `LightningModule`

    First, here's a simple example of running gradient_ascent that completes without error:

    from functools import partial
    import torch as tr
    from torchvision import models
    from rai_toolbox.optim import L2ProjectedOptim
    from rai_toolbox.perturbations.solvers import gradient_ascent
    
    model = models.resnet18()
    data = tr.rand(10, 3, 100, 100, dtype=tr.float)
    target = tr.randint(0, 2, size=(10,))
    pert = partial(
        gradient_ascent, optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0
    )
    
    # run gradient ascent
    pert(model=model, data=data, target=target)
    

    Now set up and run the same thing using Trainer.predict:

    import pytorch_lightning as pl
    
    class Lit(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = model
            self.pert = pert
    
        def predict_step(self, batch, *args, **kwargs):
            data, target = batch
            data = self.pert(model=self.model, data=data, target=target)
            logits = self.model(data)
            return logits.sum()
    
    trainer = pl.Trainer()
    trainer.predict(
        Lit(),
        datamodule=pl.LightningDataModule.from_datasets(
            predict_dataset=tr.utils.data.TensorDataset(data, target),
            batch_size=1,
            num_workers=0,
        ),
    )
    

    Here we get the following error:

    ...
    /tmp/ipykernel_74682/1909129363.py in predict_step(self, batch, *args, **kwargs)
         27     def predict_step(self, batch, *args, **kwargs):
         28         data, target = batch
    ---> 29         data = self.pert(model=self, data=data, target=target)
         30         logits = self.model(data)
         31         return logits.sum()
    
    ~/projects/raiden/rai_toolbox/src/rai_toolbox/perturbations/solvers.py in gradient_ascent(model, data, target, optimizer, steps, perturbation_model, targeted, use_best, criterion, reduction_fn, **optim_kwargs)
        277             # Update the perturbation
        278             optim.zero_grad(set_to_none=True)
    --> 279             loss.backward()
        280             optim.step()
        281 
    
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
        394                 create_graph=create_graph,
        395                 inputs=inputs)
    --> 396         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
        397 
        398     def register_hook(self, hook):
    
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
        171     # some Python versions print out the first line of a multi-line function
        172     # calls in the traceback and some print out the last line
    --> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
        175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
    

    If I drop into the debugger, everything seems to be set up correctly, except that pmodel(data) does not return a tensor with a grad_fn!

    #
    # pdb at `loss.backward()` line
    #
    > pmodel.delta.requires_grad
    True
    
    > tr.is_grad_enabled()
    True
    
    > pmodel.delta + data
    ... # tensor output without `grad_fn`
    
    # try reinitializing
    > perturbation_model(data)(data)
    ... # tensor output WITH `grad_fn`
    

    I have no idea how to debug this and find out what is wrong.

    @rsokl do you get this error in your environment?
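
    One possibility worth ruling out (an editorial guess, not something established in this thread): Lightning may run the predict loop under torch.inference_mode(), and tensors created under inference mode can never record a grad_fn, even if gradient tracking is re-enabled later. A quick probe of that behavior:

    import torch as tr

    # Minimal illustration of inference tensors (not the toolbox's code).
    x = tr.rand(3)
    with tr.inference_mode():
        y = x + 1

    print(y.is_inference())                # True: y was created under inference mode
    print(tr.is_inference_mode_enabled())  # False out here; is it True inside predict_step?

    # Checking data.is_inference() and tr.is_inference_mode_enabled() inside
    # predict_step / gradient_ascent would show whether this is the culprit.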

    opened by jgbos 6
  • Docs: perturbation explanation

    Starting an explanation of our approach to data perturbations. I still intend to add more to this today, but feel free to take a look and let me know your thoughts on how it's going so far, especially what should and shouldn't be included.

    opened by oliviamb 4
  • CIFAR10-Adversarial-Perturbations.ipynb -- Standard rai-toolbox[mushin] install doesn't include dill module

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook attempts to load the pretrained models and fails with ModuleNotFoundError: No module named 'dill'. The dill module is not included in the standard rai-toolbox[mushin] install.

    Full error traceback below:

    ModuleNotFoundError                       Traceback (most recent call last)
    Input In [9], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = "mitll_cifar_l2_1_0.pt"
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/_utils.py:60, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58 log.info(f"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:713, in load(f, map_location, pickle_module, **pickle_load_args)
        711             return torch.jit.load(opened_file)
        712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    --> 713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:930, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
        928 unpickler = UnpicklerWrapper(f, **pickle_load_args)
        929 unpickler.persistent_load = persistent_load
    --> 930 result = unpickler.load()
        932 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
        934 offset = f.tell() if f_should_read_directly else None
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:746, in _legacy_load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
        744     except KeyError:
        745         pass
    --> 746 return super().find_class(mod_name, name)
    
    ModuleNotFoundError: No module named 'dill'
    

    Installing dill via pip install dill in the Python environment corrects this error.

    opened by miscpeeps 3
  • test for ensuring hydra ddp raises is raising for the wrong reason

    @jgbos

    In the following test:

    https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/3882320391c87b6bf6330a09471a9808fca01160/tests/test_mushin/test_lightning_hydra_ddp.py#L27-L40

    launch(Config, pl_main_task) raises TypeError because pl_main_task doesn't accept a single config (pyright warned me about this). I doubt this is what you meant to exercise in this test.

    I am confused about what this test is doing. Config = make_config(trainer=trainer, wrong_config_name=module, devices=2) makes it seem like we are making sure that launch fails for a config with a bad field name, but the test seems like it should be exercising ddp.

    bug test-suite 
    opened by rsokl 3
  • CIFAR10-Adversarial-Perturbations.ipynb -- Load pretrained CIFAR-10 models not included and incorrectly named

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook and tutorial reference mitll_cifar_l2_1_0.pt and mitll_cifar_nat.pt as pretrained CIFAR-10 models. These models are not included in the standard rai-toolbox[mushin] install (perhaps due to licensing or a desire to have the most up-to-date models?).

    The models downloaded from the URLs in the robustness GitHub repository are named cifar_l2_1_0.pt and cifar_nat.pt, and will cause the following error at In[10] of CIFAR10-Adversarial-Perturbations.ipynb:

    FileNotFoundError                         Traceback (most recent call last)
    Input In [8], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = "mitll_cifar_l2_1_0.pt"
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/_utils.py:60, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58 log.info(f"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:699, in load(f, map_location, pickle_module, **pickle_load_args)
        696 if 'encoding' not in pickle_load_args.keys():
        697     pickle_load_args['encoding'] = 'utf-8'
    --> 699 with _open_file_like(f, 'rb') as opened_file:
        700     if _is_zipfile(opened_file):
        701         # The zipfile reader is going to advance the current file position.
        702         # If we want to actually tail call to torch.jit.load, we need to
        703         # reset back to the original position.
        704         orig_position = opened_file.tell()
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:231, in _open_file_like(name_or_buffer, mode)
        229 def _open_file_like(name_or_buffer, mode):
        230     if _is_path(name_or_buffer):
    --> 231         return _open_file(name_or_buffer, mode)
        232     else:
        233         if 'w' in mode:
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:212, in _open_file.__init__(self, name, mode)
        211 def __init__(self, name, mode):
    --> 212     super(_open_file, self).__init__(open(name, mode))
    
    FileNotFoundError: [Errno 2] No such file or directory: '/home/scott/.torch/models/mitll_cifar_l2_1_0.pt'
    

    The models must be renamed and manually copied to ~/.torch/models (i.e., /home/${USER}/.torch/models) to proceed with the tutorial.
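
    For reference, a small sketch of that manual step (the destination directory comes from the traceback above; the download location is a placeholder):

    from pathlib import Path
    import shutil

    # Destination that load_from_checkpoint looks in (per the traceback above).
    models_dir = Path.home() / ".torch" / "models"
    models_dir.mkdir(parents=True, exist_ok=True)

    # Map the robustness repo's download names to the names the notebook expects.
    renames = {
        "cifar_l2_1_0.pt": "mitll_cifar_l2_1_0.pt",
        "cifar_nat.pt": "mitll_cifar_nat.pt",
    }
    downloads = Path("~/Downloads").expanduser()  # placeholder: wherever the files were saved
    for src, dst in renames.items():
        shutil.copy(downloads / src, models_dir / dst)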

    opened by miscpeeps 2
  • `zen` should not attempt to populate `*args` and `**kwargs`

    Previously zen would attempt to find a kwargs field in the config:

    Before:

    def f(x, **kwargs): return x
    
    cfg = make_config(x=1)
    
    zen(f)(cfg)  # AttributeError: 'Config' object has no attribute 'kwargs'
    

    Now zen skips *args, **kwargs.

    def f(x, **kwargs): return x
    
    cfg = make_config(x=1)
    
    zen(f)(cfg)  # returns 1
    

    In the future we might permit some configured behavior for populating these.

    bug 
    opened by rsokl 2
  • Update gradient-descent solver

    • Renames: gradient_descent -> gradient_ascent
    • (bug fix) Ensures that returned loss always has the correct sign. Previously, when targeted=False the returned loss values would be negated relative to the actual loss landscape
    • Adds examples section to docs
    • Ensures that data and target can be any array-like input, not necessarily a tensor (see the sketch below)
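
    A hypothetical usage sketch of the array-like behavior, borrowing the solver's call signature from the example elsewhere on this page (the tiny model and the values are placeholders):

    import torch as tr

    from rai_toolbox.optim import L2ProjectedOptim
    from rai_toolbox.perturbations.solvers import gradient_ascent

    model = tr.nn.Linear(2, 2)

    # Plain Python lists are accepted in place of tensors for data/target.
    gradient_ascent(
        model=model,
        data=[[0.1, 0.2]],
        target=[1],
        optimizer=L2ProjectedOptim,
        epsilon=1.0,
        steps=1,
        lr=1.0,
    )
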
    bug code quality 
    opened by rsokl 2
  • Bump pydata-sphinx-theme from 0.8.1 to 0.11.0 in /docs

    Bumps pydata-sphinx-theme from 0.8.1 to 0.11.0.

    Release notes

    Sourced from pydata-sphinx-theme's releases.

    v0.11.0

    What's Changed

    New Contributors

    Full Changelog: https://github.com/pydata/pydata-sphinx-theme/compare/v0.10.1...v0.11.0

    v0.11.0rc3

    What's Changed

    ... (truncated)

    Commits

    Dependabot compatibility score

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx-tabs from 3.3.1 to 3.4.1 in /docs

    Bumps sphinx-tabs from 3.3.1 to 3.4.1.

    Release notes

    Sourced from sphinx-tabs's releases.

    Version 3.4.1

    What's Changed

    Full Changelog: https://github.com/executablebooks/sphinx-tabs/compare/v3.4.0...v3.4.1

    Version 3.4.0

    What's Changed

    New Contributors

    Full Changelog: https://github.com/executablebooks/sphinx-tabs/compare/v3.3.1...v3.4.0

    Changelog

    Sourced from sphinx-tabs's changelog.

    3.4.1 - 2022-07-02

    Added

    • Weekly scheduled testing, to catch breaking changes in unpinned dependencies

    Changed

    • docutils version pin to allow use of version 0.18.x

    Removed

    • sphinx version pinning - only the latest version of sphinx will now be fully supported, but previous versions will work if sphinx dependencies (i.e. jinja2) are managed correctly. This is in line with the approach taken by sphinx
    • tests that were specific to older versions of sphinx and pygments
    • jinja2 version pinning, as this is now pinned in latest version of sphinx

    3.4.0 - 2022-06-26

    Added

    • Testing for sphinx 5
    • Testing for python 3.10

    Fixed

    • Fixed parsing of MyST content, where first line was being stripped
    • Typos in documentation
    • Failing regression tests

    Changed

    • Testing to use an up-to-date pytest version

    Removed

    • Testing for python 3.6 and sphinx versions 2 and 4 (see #164). Note that the package will likely continue to work fine with these, but this won't be assured by tests
    Commits

    Dependabot compatibility score

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx-codeautolink from 0.10.0 to 0.12.0 in /docs

    Bumps sphinx-codeautolink from 0.10.0 to 0.12.0.

    Commits

    Dependabot compatibility score

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx from 5.3.0 to 6.0.0 in /docs

    Bumps sphinx from 5.3.0 to 6.0.0.

    Release notes

    Sourced from sphinx's releases.

    v6.0.0

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    v6.0.0b2

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    v6.0.0b1

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    Changelog

    Sourced from sphinx's changelog.

    Release 6.0.0 (released Dec 29, 2022)

    Dependencies

    • #10468: Drop Python 3.6 support
    • #10470: Drop Python 3.7, Docutils 0.14, Docutils 0.15, Docutils 0.16, and Docutils 0.17 support. Patch by Adam Turner

    Incompatible changes

    • #7405: Removed the jQuery and underscore.js JavaScript frameworks.

      These frameworks are no longer automatically injected into themes from Sphinx 6.0. If you develop a theme or extension that uses the jQuery, $, or $u global objects, you need to update your JavaScript to modern standards, or use the mitigation below.

      The first option is to use the sphinxcontrib.jquery_ extension, which has been developed by the Sphinx team and contributors. To use this, add sphinxcontrib.jquery to the extensions list in conf.py, or call app.setup_extension("sphinxcontrib.jquery") if you develop a Sphinx theme or extension.

      The second option is to manually ensure that the frameworks are present. To re-add jQuery and underscore.js, you will need to copy jquery.js and underscore.js from the Sphinx repository_ to your static directory, and add the following to your layout.html:

      .. code-block:: html+jinja

      {%- block scripts %}
          <script src="{{ pathto('_static/jquery.js', resource=True) }}"></script>
          <script src="{{ pathto('_static/underscore.js', resource=True) }}"></script>
          {{ super() }}
      {%- endblock %}

      .. _sphinxcontrib.jquery: https://github.com/sphinx-contrib/jquery/

      Patch by Adam Turner.

    • #10471, #10565: Removed deprecated APIs scheduled for removal in Sphinx 6.0. See :ref:dev-deprecated-apis for details. Patch by Adam Turner.

    • #10901: C Domain: Remove support for parsing pre-v3 style type directives and roles. Also remove associated configuration variables c_allow_pre_v3 and c_warn_on_allowed_pre_v3. Patch by Adam Turner.

    Features added

    ... (truncated)

    Commits

    Dependabot compatibility score

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    dependencies python 
    opened by dependabot[bot] 1
  • Update madry example

    Currently we use Workflow.run within hydra.main, which no longer works: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/5348f7d6a96837d0f9d9ce1b5e71ebfdee6f4b88/experiments/madry/run.py#L29

    We should update this to leverage zen, while making sure that the plotting still works (i.e., that we give the workflow the necessary context to gather the xarray).
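
    A rough sketch of the direction (hypothetical; the import location and the task signature are assumptions, following the zen examples elsewhere on this page):

    from rai_toolbox.mushin import zen  # assumed location of `zen` in this repo

    def madry_task(trainer, module):  # placeholder signature for the madry experiment
        ...

    # Wrapping with zen lets the task consume the composed Hydra config directly,
    # i.e., zen(madry_task)(cfg), rather than calling Workflow.run inside hydra.main.
    task_fn = zen(madry_task)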

    opened by rsokl 0
  • Add Pickled Hydra Runs (e.g., Rerun) to Support PL `ddp`

    In this PR we will attempt to address two issues:

    1. Reproducible Hydra experiments purely from the run directory by pickling both the runtime configuration and the task function
      • An extension of Hydra's experimental rerun
    2. Solving Hydra+DDP for PyTorch Lightning ddp strategy by saving the task function
      • The current solution in HydraDDP has strong constraints on the expected task function. This limits what the user can do in their experiments.

    Hydra Rerun Capability

    Here we take advantage of Hydra Callbacks to save the runtime configuration and the desired task function. Currently our callback takes a task function on initialization, but future Hydra versions may allow Hydra to pass the task function to the callback methods.

    Callback implementation: MushinPickleJobCallback. This takes in a Hydra task function on initialization and saves the task function and runtime configuration in the hydra.runtime.output_dir folder. The pickled files are stored in:

    <hydra.runtime.output_dir>/config.pickle
    <hydra.runtime.output_dir>/task_fn.pickle
    

    This implementation uses cloudpickle to support pickling of the task function. The only downside of this approach is that the task function must be hashable for pickling and "instantiable" for Hydra from the command line, e.g., defining the task function in the notebook won't work.

    Note: Submitit is capable of pickling functions that were created in __main__, so this should be possible

    Execution: With the configuration and task function saved in the job directory, we can rerun any experiment using:

    $ python -m rai_toolbox.mushin._hydra_rerun +config=<path to config.pickle> +task_fn=<path to task_fn.pickle>
    

    Lightning DDP

    Challenges this PR solves for Hydra+DDP:

    • Runs from notebook
    • Supports generic task functions (i.e., solves HydraDDP issue)
    • Task functions can run multiple Trainer methods (e.g., Trainer.fit followed by Trainer.test). HydraDDP does not support these types of task functions

    First we must configure our custom Hydra Callback, MushinPickleJobCallback:

    task_fn_cfg = builds(...)
    
    callback_cfg = dict(
        save_job_info=builds(MushinPickleJobCallback, task_fn=task_fn_cfg)
    )
    
    cs = ConfigStore.instance()
    cs.store(name="pickle_job", group="hydra/callbacks", node=callback_cfg)
    

    The Trainer strategy can then be configured with our custom Lightning ddp strategy, HydraRerunDDP:

    TrainerConfig = builds(Trainer, strategy=builds(HydraRerunDDP))
    

    We must set hydra/callbacks in the overrides to launch a job:

    task_fn = instantiate(task_fn_cfg)
    launch(Config, task_fn, overrides=["hydra/callbacks=pickle_job", ...])
    

    Notes

    • MushinPickleJobCallback will clean up the PL environment automatically at the end of a job.
    • See tests for examples.

    I plan to update this comment to better describe everything

    TODOS

    • [ ] Should we deprecate HydraDDP in favor of this
    • [ ] Can we pickle and use task functions built in a "main" setting like the notebook?
    • [ ] Structure of Hydra specific and Lightning specific code
    • [ ] More tests:
      • Validate results, not just that the pickle file is available
      • Test Hydra rerun without Lightning
    opened by jgbos 1
  • Implements elastic-net attack

    Derived from: https://arxiv.org/pdf/1709.04114.pdf

    Here is a trivial scenario where we are merely perturbing the "logits" themselves so that the specified targets will be optimized for. Note that the longer we run the optimizer, the more the learned perturbation shrinks (while still amounting to a successful attack).

    >>> from rai_toolbox.perturbations.solvers import elastic_net_attack
    >>> logits = [[0.497, 0.503]]
    >>> target = [0]
    
    >>> for num_steps in [1, 10, 100]:
    ...     _, x_adv, _ = elastic_net_attack(
    ...         model=lambda x: x,
    ...         data=logits,
    ...         target=target,
    ...         beta=1e-3,
    ...         c=2,
    ...         steps=num_steps,
    ...         confidence=.01,
    ...         lr=0.5,
    ...     )
    ...     print(f"num-steps: {num_steps}\n{x_adv}")
    num-steps: 1
    tensor([[ 1.4960, -0.4960]])
    num-steps: 10
    tensor([[0.5062, 0.4938]])
    num-steps: 100
    tensor([[0.5018, 0.4982]])
    
    opened by rsokl 0
  • Use fused multiply-add to apply `grad_scale` and `grad_bias`

    https://pytorch.org/docs/stable/generated/torch.add.html

    >>> a = torch.randn(4)
    >>> a
    tensor([ 0.0202,  1.0985,  1.3506, -0.6056])
    
    >>> b = torch.randn(4)
    >>> b
    tensor([-0.9732, -0.3497,  0.6245,  0.4022])
    >>> c = torch.randn(4, 1)
    >>> c
    tensor([[ 0.3743],
            [-1.7724],
            [-0.5811],
            [-0.8017]])
    >>> torch.add(b, c, alpha=10)
    tensor([[  2.7695,   3.3930,   4.3672,   4.1450],
            [-18.6971, -18.0736, -17.0994, -17.3216],
            [ -6.7845,  -6.1610,  -5.1868,  -5.4090],
            [ -8.9902,  -8.3667,  -7.3925,  -7.6147]])
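
    A hypothetical illustration of the change this PR describes: fold the grad_scale multiply and the grad_bias add into a single torch.add call with alpha, instead of two separate element-wise ops (names follow the PR title; the actual optimizer code may differ):

    import torch

    grad = torch.randn(4)
    grad_scale = 2.0
    grad_bias = torch.full_like(grad, 0.1)

    two_ops = grad * grad_scale + grad_bias
    fused = torch.add(grad_bias, grad, alpha=grad_scale)  # grad_bias + grad_scale * grad
    assert torch.allclose(two_ops, fused)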
    
    opened by rsokl 0
Releases (v0.2.1)
  • v0.2.1 (Jun 16, 2022)

    See changelog

    What's Changed

    • Fix TopQGradient device mismatch by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/64
    • Fixes to_xarray when target_job_dirs points to job that performed multirun over sequence values by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/68

    Full Changelog: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/compare/v0.2.0...v0.2.1

  • v0.2.0 (Jun 1, 2022)

    See changelog for details

    What's Changed

    • zen should not attempt to populate *args and **kwargs by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/48
    • Workflow Improvement by @jgbos in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/47
    • Adds zen callbacks by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/49
    • Remove numpy dependency (defer to pytorch) by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/50
    • Make methods static where possible; simplify examples; cleanup formatting by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/51
    • Add working_subdir data variable to xarray by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/52
    • Improve parity between pre-step and post-step method names by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/54
    • fix typo in univ_adv_pert.rst by @Jasha10 in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/56
    • Update workflows by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/55
    • Adds Support for Lightning's Trainer.predict by @jgbos in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/53
    • Fix hypothesis by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/58
    • Add pre-task method to workflow by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/57
    • Deprecate ParamTransformingOptimizer.project by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/59
    • Add pre-release and nightly CI jobs by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/60
    • Ensure workflow overrides roundtrip by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/61
    • Deprecate evaluation_task in favor of task by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/62
    • Enable user-specified functions for loading metrics files by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/63

    New Contributors

    • @Jasha10 made their first contribution in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/56

    Full Changelog: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/compare/v0.1.1...v0.2.0
