Treeo

A small library for creating and manipulating custom JAX Pytree classes

  • Light-weight: has no dependencies other than jax.
  • Compatible: Treeo Tree objects are compatible with any jax function that accepts Pytrees.
  • Standards-based: treeo.field is built on top of python's dataclasses.field.
  • Flexible: Treeo is compatible with both dataclass and non-dataclass classes.

Treeo lets you create class-based Pytrees so your custom objects can interact seamlessly with JAX. Uses of Treeo range from creating simple JAX-aware utility classes to serving as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of Treex and shares a lot in common with flax.struct.

Documentation | User Guide

Installation

Install using pip:

pip install treeo

Basics

With Treeo you can easily define your own custom Pytree classes by inheriting from Treeo's Tree class and using the field function to declare which fields are nodes (children) and which are static (metadata):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True) # I am a node field!
    name: str = to.field(node=False) # I am a static field!

field is just a wrapper around dataclasses.field so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well. Since all Tree instances are Pytrees, they work with the various functions from the jax library as expected:

p = Person(height=jnp.array(1.8), name="John")

# Trees can be jitted!
jax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')

# Trees can be mapped!
jax.tree_util.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')
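
To make the node/static distinction concrete, here is a minimal sketch that does not use Treeo at all. It registers a hypothetical PersonSketch class by hand through jax.tree_util.register_pytree_node_class, exposing height as a child (node) and name as auxiliary data (static). It only illustrates the general Pytree protocol; that Treeo's internals look like this is an assumption.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class PersonSketch:
    """Hypothetical stand-in for a Tree: `height` is a node, `name` is static."""

    def __init__(self, height, name):
        self.height = height
        self.name = name

    def tree_flatten(self):
        # First element: children that JAX traces and maps over.
        # Second element: auxiliary (static) data stored in the tree definition.
        return (self.height,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from the static data and the (possibly new) children.
        return cls(height=children[0], name=aux_data)


p = PersonSketch(height=jnp.array(1.8), name="John")

# Only the node field is visible as a leaf; the static field is untouched.
leaves = tree_util.tree_leaves(p)
doubled = tree_util.tree_map(lambda x: 2 * x, p)
```

Here the mapped function only ever sees height, while name is carried along unchanged as part of the tree structure.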

Kinds

Treeo also includes a kind system that lets you give semantic meaning to fields (what a field represents within your application). A kind is just a type you pass to field via its kind argument:

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

Kinds are very useful as a filtering mechanism via treeo.filter:

model = BatchNorm(...)

# select only Parameters, mean is filtered out
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)

Nothing behaves like None in Python, but it is a special value that is used to represent the absence of a value within Treeo.
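
The idea of filtering by kind can be sketched in plain Python, without Treeo or JAX. In this illustration the kind is stored in dataclasses.field metadata and a hypothetical filter_by_kind helper replaces every non-matching field with a NOTHING sentinel. The names BatchNormSketch, filter_by_kind and NOTHING, and the use of issubclass for matching, are assumptions of this sketch and not Treeo's actual implementation.

```python
from dataclasses import dataclass, field, fields, replace


class Parameter: pass
class BatchStat: pass


class _Nothing:
    """Sentinel marking a value that was filtered out."""

    def __repr__(self):
        return "Nothing"


NOTHING = _Nothing()


@dataclass
class BatchNormSketch:
    # The kind is attached as field metadata (an assumption of this sketch).
    scale: float = field(metadata={"kind": Parameter})
    mean: float = field(metadata={"kind": BatchStat})


def filter_by_kind(tree, kind):
    """Return a copy of `tree` where fields of a different kind become NOTHING."""
    updates = {
        f.name: NOTHING
        for f in fields(tree)
        if not issubclass(f.metadata["kind"], kind)
    }
    return replace(tree, **updates)


model = BatchNormSketch(scale=1.0, mean=0.0)
params = filter_by_kind(model, Parameter)
print(params)  # BatchNormSketch(scale=1.0, mean=Nothing)
```

The original object is left untouched; the filtered copy keeps the same structure, which is what makes it possible to recombine it later.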

Treeo also offers the merge function, which lets you rejoin filtered Trees with logic similar to Python's dict.update, but applied recursively:

def loss_fn(params, model, ...):
    # add traced params to model
    model = to.merge(model, params)
    ...

# gradient only w.r.t. params
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)
grads = jax.grad(loss_fn)(params, model, ...)
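
The "dict.update, but recursive" logic described above can be sketched with nested dictionaries standing in for Trees. This is a plain-Python illustration under the assumption that later values win unless they are the NOTHING sentinel; the merge function and NOTHING object below are hypothetical stand-ins, not Treeo's own code.

```python
NOTHING = object()  # stand-in sentinel for a filtered-out value


def merge(base, *others):
    """Recursively update `base` with `others`, skipping NOTHING values."""
    result = dict(base)
    for other in others:
        for key, value in other.items():
            if value is NOTHING:
                # Absent value: keep whatever is already in the result.
                continue
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                # Nested structure: merge recursively instead of overwriting.
                result[key] = merge(result[key], value)
            else:
                result[key] = value
    return result


model = {"scale": 1.0, "mean": 0.0, "sub": {"w": 2.0, "count": 3}}
params = {"scale": 5.0, "mean": NOTHING, "sub": {"w": 7.0, "count": NOTHING}}

merged = merge(model, params)
print(merged)  # {'scale': 5.0, 'mean': 0.0, 'sub': {'w': 7.0, 'count': 3}}
```

This mirrors the pattern in loss_fn: the filtered params carry fresh values for some fields, and everything marked as absent is taken from the full model.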

For a more in-depth tour check out the User Guide.

Examples

A simple Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Character(to.Tree):
    position: jnp.ndarray = to.field(node=True)    # node field
    name: str = to.field(node=False, opaque=True)  # static field

character = Character(position=jnp.array([0, 0]), name='Adam')

# character can freely pass through jit
@jax.jit
def update(character: Character, velocity, dt) -> Character:
    character.position += velocity * dt
    return character

character = update(character, velocity=jnp.array([1.0, 0.2]), dt=0.1)

A Stateful Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Counter(to.Tree):
    n: jnp.ndarray = to.field(default=jnp.array(0), node=True) # node
    step: int = to.field(default=1, node=False) # static

    def inc(self):
        self.n += self.step

counter = Counter(step=2) # Counter(n=jnp.array(0), step=2)

@jax.jit
def update(counter: Counter):
    counter.inc()
    return counter

counter = update(counter) # Counter(n=jnp.array(2), step=2)


Full Example - Linear Regression

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import treeo as to


class Linear(to.Tree):
    w: jnp.ndarray = to.node()
    b: jnp.ndarray = to.node()

    def __init__(self, din, dout, key):
        self.w = jax.random.uniform(key, shape=(din, dout))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b


@jax.value_and_grad
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y) ** 2)

    return loss


def sgd(param, grad):
    return param - 0.1 * grad


@jax.jit
def train_step(model, x, y):
    loss, grads = loss_fn(model, x, y)
    model = jax.tree_util.tree_map(sgd, model, grads)

    return loss, model


x = np.random.uniform(size=(500, 1))
y = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))

key = jax.random.PRNGKey(0)
model = Linear(1, 1, key=key)

for step in range(1000):
    loss, model = train_step(model, x, y)
    if step % 100 == 0:
        print(f"loss: {loss:.4f}")

X_test = np.linspace(x.min(), x.max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(x, y, c="k", label="data")
plt.plot(X_test, y_pred, c="b", linewidth=2, label="prediction")
plt.legend()
plt.show()
Comments
  • Use field kinds within tree_map

    Firstly, thanks for creating Treeo - it's a fantastic package.

    Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE:

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class Parameter:
        def transform(self):
            return jnp.exp(self)
    
    
    @dataclass
    class Model(to.Tree):
        lengthscale: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=Parameter
        )
    

    Is there a way that I could do something similar to the following pseudocode snippet:

    m = Model()
    jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
    
    opened by thomaspinder 10
  • Stacking of Treeo.Tree

    I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    @dataclass
    class Person(to.Tree):
        height: jnp.array = to.field(node=True) # I am a node field!
        age_static: jnp.array = to.field(node=False) # I am a static field!, I should not be updated.
        name: str = to.field(node=False) # I am a static field!
    
    persons = [
        Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
        Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
        Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
    ]
    
    # Stack (struct of arrays instead of list of structs)
    jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    

    However, this fails with the following exception:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In[1], line 18
         11     name: str = to.field(node=False) # I am a static field!
         13 persons = [
         14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
         15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
         16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
         17 ]
    ---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
    

    Versions used:

    • JAX: 0.3.20
    • Treeo: 0.0.10

    From a certain perspective this is expected, because jax.tree_map does not apply to static (node=False) fields. In that sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still stack objects like this when they have static fields. Has anyone tried something similar and come up with a nice solution?

    opened by peterroelants 3
  • Jitting twice for a class method

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class A(to.Tree):
        X: jnp.array = to.field(node=True)
        
        def __init__(self):
            self.X = jnp.ones((50, 50))
    
        @jax.jit
        def f(self, Y):
            return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)
    
    Y = jnp.ones(2)
    for i in range(5):
        print(A.f._cache_size())
        a = A()
        a.f(Y)
    

    The output of the above is 0 1 2 2 2 with jax 0.3.15. No idea what's happening. It seems to work fine with 0.3.10 and the output is 0 1 1 1 1. Thanks.

    opened by pipme 2
  • Change Mutable API

    Changes

    • Previously self.mutable(*args, method=method, **kwargs)
    • Is now...... self.mutable(method=method)(*args, **kwargs)
    • Opaque API is removed
    • inplace argument is now only available for apply.
    • Immutable.{mutable, toplevel_mutable} methods are removed.
    fix 
    opened by cgarciae 1
  • Improve mutability support

    Changes

    • Fixes issues with immutability in compact context
    • The make_mutable context manager and the mutable function now expose a toplevel_only: bool argument.
    • Adds a _get_unbound_method private function in utils.
    feature 
    opened by cgarciae 1
  • Bug Fixes from 0.0.11

    Changes

    • Fixes an issue that disabled mutability inside __init__ for Immutable classes when TreeMeta's constructor method is overloaded.
    • Fixes the Apply.apply mixin method.

    Closes cgarciae/treex#68

    fix 
    opened by cgarciae 1
  • Adds support for immutable Trees

    Changes

    • Adds an Immutable mixin that can make Trees effectively immutable (as far as python permits).
    • Immutable contains the .replace and .mutable methods that let you manipulate state in a functionally pure fashion.
    • Adds the mutable function transformation / decorator, which lets you turn functions that perform mutable operations into pure functions.
    opened by cgarciae 1
  • Add the option of using add_field_info inside map

    This PR addresses the comments made in #2. An additional argument is added to map so that a boolean field_info flag can be passed. When true, jax.tree_map is carried out under the with add_field_info(): context manager.

    Tests have been added to check for correct function application on classes containing Trees with mixed kind types.

    A brief section has been added to the documentation to reflect the above changes.

    opened by thomaspinder 1
  • Get all unique kinds

    Hi,

    Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

    class KindOne: pass
    class KindTwo: pass
    
    @dataclass
    class SubModel(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindOne
        )
    
    
    @dataclass 
    class Model(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindTwo
        )
    
    m = Model()
    
    m.unique_kinds() # [KindOne, KindTwo]
    
    opened by thomaspinder 1
  • Compact

    Changes

    • Removes opaque_is_equal, same functionality available through opaque.
    • Adds the compact decorator, which enables the definition of Tree subnodes at runtime.
    • Adds the Compact mixin that adds the first_run property and the get_field method.
    opened by cgarciae 0
  • Relax jax/jaxlib version constraints

    Now that jax 0.3.0 and jaxlib 0.3.0 have been released the version constraints in pyproject.toml are outdated.

    https://github.com/cgarciae/treeo/blob/a402f3f69557840cfbee4d7804964b8e2c47e3f7/pyproject.toml#L16-L17

    This corresponds to the version constraint jax<0.3.0,>=0.2.18 (https://python-poetry.org/docs/dependency-specification/#caret-requirements). Now that jax v0.3.0 has been released (https://github.com/google/jax/releases/tag/jax-v0.3.0) this doesn't work with the latest version. I think the same applies to jaxlib as well, since it also got upgraded to v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).

    opened by samuela 4
  • TracedArrays treated as nodes by default

    Currently, for convenience, all undeclared non-Tree fields are set to static fields, as most fields actually are. However, in more complex applications a traced array might be passed where a static field is usually expected.

    A simple solution is to change the current node policy to treat any field containing a TracedArray as a node, which would be the same as the current policy for Tree fields.

    opened by cgarciae 0
Releases (0.2.1)
Owner
Cristian Garcia
ML Engineer at Quansight, working on Treex and Elegy.