tree-math: mathematical operations for JAX pytrees

Overview

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in a way that handles arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies and gives the user more control over the memory layouts used in computation. In practice, this can often make a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.
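To make the contrast concrete, here is a minimal sketch using only JAX utilities (not tree-math itself). The `state` and `update` dictionaries and the step size 0.1 are made-up illustration values. The first option flattens everything into a 1d array and unflattens afterwards; the second applies the same update leaf by leaf, so each array keeps its own shape.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical model state: two arrays of different shapes stored in a dict.
state = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
update = {"w": jnp.full((2, 3), 0.5), "b": jnp.ones(3)}

# Option 1: flatten to a 1d array, compute, then unflatten.
# ravel_pytree concatenates all leaves into a new buffer.
flat_state, unravel = ravel_pytree(state)
flat_update, _ = ravel_pytree(update)
stepped_flat = unravel(flat_state - 0.1 * flat_update)

# Option 2: apply the same update leaf by leaf, keeping each array's shape.
stepped_tree = jax.tree_util.tree_map(lambda s, u: s - 0.1 * u, state, update)
```

Both options produce the same values. The difference is that the second never materializes the concatenated 1d array, and this leaf-by-leaf pattern is what tree_math.Vector hides behind ordinary arithmetic operators.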

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an object that supports arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means they are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.
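For intuition, a full-precision dot product over two pytrees could be sketched with plain JAX as below. This is an illustration of the idea, not tree-math's actual implementation; the helper name `tree_dot_highest` is hypothetical, and it assumes both pytrees have the same structure.

```python
import jax
import jax.numpy as jnp


def tree_dot_highest(x, y):
    """Dot product of two pytrees with matching structure, at full precision."""
    xs = jax.tree_util.tree_leaves(x)
    ys = jax.tree_util.tree_leaves(y)
    # Flatten each pair of leaves and request the highest available precision,
    # instead of relying on the platform default.
    parts = [
        jnp.dot(jnp.ravel(a), jnp.ravel(b), precision=jax.lax.Precision.HIGHEST)
        for a, b in zip(xs, ys)
    ]
    return sum(parts)


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
result = tree_dot_highest(x, x)  # 1*1 + 2*2 + 3*3
```

On CPU and GPU the precision argument usually makes little or no difference for float32 inputs. It matters mostly on TPUs, where the default matrix-multiplication precision is lower.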

In the near term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS, which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0.conj() @ z0  # conjugate, matching gamma_ in body_fun
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final

Owner: Google