tree-math: mathematical operations for JAX pytrees

Overview

tree-math: mathematical operations for JAX pytrees

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an a object that arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means the are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see the "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near-term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

atol2) & (k < maxiter) def body_fun(value): x, r, gamma, p, k = value Ap = A(p) alpha = gamma / (p.conj() @ Ap) x_ = x + alpha * p r_ = r - alpha * Ap z_ = M(r_) gamma_ = r_.conj() @ z_ beta_ = gamma_ / gamma p_ = z_ + beta_ * p return x_, r_, gamma_, p_, k + 1 r0 = b - A(x0) p0 = z0 = M(r0) gamma0 = r0 @ z0 initial_value = (x0, r0, gamma0, p0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) return x_final">
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google
Google ❤️ Open Source
Google
Pytorch implementation of 'Fingerprint Presentation Attack Detector Using Global-Local Model'

RTK-PAD This is an official pytorch implementation of 'Fingerprint Presentation Attack Detector Using Global-Local Model', which is accepted by IEEE T

6 Aug 01, 2022
Accuracy Aligned. Concise Implementation of Swin Transformer

Accuracy Aligned. Concise Implementation of Swin Transformer This repository contains the implementation of Swin Transformer, and the training codes o

FengWang 77 Dec 16, 2022
Get started with Machine Learning with Python - An introduction with Python programming examples

Machine Learning With Python Get started with Machine Learning with Python An engaging introduction to Machine Learning with Python TL;DR Download all

Learn Python with Rune 130 Jan 02, 2023
My solution for the 7th place / 245 in the Umoja Hack 2022 challenge

Umoja Hack 2022 : Insurance Claim Challenge My solution for the 7th place / 245 in the Umoja Hack 2022 challenge Umoja Hack Africa is a yearly hackath

Souames Annis 17 Jun 03, 2022
GPU Accelerated Non-rigid ICP for surface registration

GPU Accelerated Non-rigid ICP for surface registration Introduction Preivous Non-rigid ICP algorithm is usually implemented on CPU, and needs to solve

Haozhe Wu 144 Jan 04, 2023
Pytorch implementation of the paper "Optimization as a Model for Few-Shot Learning"

Optimization as a Model for Few-Shot Learning This repo provides a Pytorch implementation for the Optimization as a Model for Few-Shot Learning paper.

Albert Berenguel Centeno 238 Jan 04, 2023
Incremental Cross-Domain Adaptation for Robust Retinopathy Screening via Bayesian Deep Learning

Incremental Cross-Domain Adaptation for Robust Retinopathy Screening via Bayesian Deep Learning Update (September 18th, 2021) A supporting document de

Taimur Hassan 1 Mar 16, 2022
Bayesian regularization for functional graphical models.

BayesFGM Paper: Jiajing Niu, Andrew Brown. Bayesian regularization for functional graphical models. Requirements R version 3.6.3 and up Python 3.6 and

0 Oct 07, 2021
Implementation of gaze tracking and demo

Predicting Customer Demand by Using Gaze Detecting and Object Tracking This project is the integration of gaze detecting and object tracking. Predict

2 Oct 20, 2022
Code for "LoRA: Low-Rank Adaptation of Large Language Models"

LoRA: Low-Rank Adaptation of Large Language Models This repo contains the implementation of LoRA in GPT-2 and steps to replicate the results in our re

Microsoft 394 Jan 08, 2023
Linescanning - Package for (pre)processing of anatomical and (linescanning) fMRI data

line scanning repository This repository contains all of the tools used during the acquisition and postprocessing of line scanning data at the Spinoza

Jurjen Heij 4 Sep 14, 2022
Dynamical Wasserstein Barycenters for Time Series Modeling

Dynamical Wasserstein Barycenters for Time Series Modeling This is the code related for the Dynamical Wasserstein Barycenter model published in Neurip

8 Sep 09, 2022
HEAM: High-Efficiency Approximate Multiplier Optimization for Deep Neural Networks

Approximate Multiplier by HEAM What's HEAM? HEAM is a general optimization method to generate high-efficiency approximate multipliers for specific app

4 Sep 11, 2022
Quasi-Dense Similarity Learning for Multiple Object Tracking, CVPR 2021 (Oral)

Quasi-Dense Tracking This is the offical implementation of paper Quasi-Dense Similarity Learning for Multiple Object Tracking. We present a trailer th

ETH VIS Research Group 327 Dec 27, 2022
Official implementation of NeuralFusion: Online Depth Map Fusion in Latent Space

NeuralFusion This is the official implementation of NeuralFusion: Online Depth Map Fusion in Latent Space. We provide code to train the proposed pipel

53 Jan 01, 2023
The Codebase for Causal Distillation for Language Models.

Causal Distillation for Language Models Zhengxuan Wu*,Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D.

Zen 20 Dec 31, 2022
Overview of architecture and implementation of TEDS-Net, as described in MICCAI 2021: "TEDS-Net: Enforcing Diffeomorphisms in Spatial Transformers to Guarantee TopologyPreservation in Segmentations"

TEDS-Net Overview of architecture and implementation of TEDS-Net, as described in MICCAI 2021: "TEDS-Net: Enforcing Diffeomorphisms in Spatial Transfo

Madeleine K Wyburd 14 Jan 04, 2023
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Phil Wang 109 Dec 28, 2022
EMNLP 2021 paper Models and Datasets for Cross-Lingual Summarisation.

This repository contains data and code for our EMNLP 2021 paper Models and Datasets for Cross-Lingual Summarisation. Please contact me at

9 Oct 28, 2022