JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

Overview

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.

Included features

This library is currently CPU-only, but GPU support is in the works using the cuFINUFFT library.

Type 1 and 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward-mode, reverse-mode, and higher-order differentiation, as well as batching using vmap.
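To give a feel for why these transforms are differentiable with respect to the non-uniform points, here is a minimal, dependency-free sketch. It assumes the 1D type 1 definition from the FINUFFT docs, f[k] = sum_j c[j] * exp(±i k x[j]), with modes ordered from -(N // 2) upward. The helpers type1_direct and type1_dx are hypothetical names used only for illustration; this is not how the package implements its differentiation rules.

```python
import cmath


def type1_direct(n_modes, c, x, iflag=1):
    """Direct O(N * M) sum: f[k] = sum_j c[j] * exp(sign * 1j * k * x[j])."""
    sign = 1 if iflag >= 0 else -1
    # Assumed mode ordering: k runs from -(N // 2) up to (N - 1) // 2.
    ks = range(-(n_modes // 2), (n_modes + 1) // 2)
    return [
        sum(cj * cmath.exp(sign * 1j * k * xj) for cj, xj in zip(c, x))
        for k in ks
    ]


def type1_dx(n_modes, c, x, j, iflag=1):
    """Analytic derivative of every mode with respect to the single point x[j]."""
    sign = 1 if iflag >= 0 else -1
    ks = range(-(n_modes // 2), (n_modes + 1) // 2)
    # Only the j-th term of the sum depends on x[j]; differentiating the
    # exponential pulls down a factor of sign * 1j * k.
    return [sign * 1j * k * c[j] * cmath.exp(sign * 1j * k * x[j]) for k in ks]


c = [1.0 + 0.5j, -0.3 + 2.0j, 0.7 - 1.1j]
x = [0.4, 2.5, 5.1]

# Central finite difference in x[1] as an independent check.
h = 1e-6
x_plus = [x[0], x[1] + h, x[2]]
x_minus = [x[0], x[1] - h, x[2]]
fd = [
    (p - m) / (2 * h)
    for p, m in zip(type1_direct(6, c, x_plus), type1_direct(6, c, x_minus))
]
exact = type1_dx(6, c, x, j=1)
max_err = max(abs(a - b) for a, b in zip(fd, exact))
```

The derivative is the same kind of sum with each mode scaled by i k, which suggests that gradients with respect to the points can be written in terms of further transforms.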

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. At runtime, you'll need numpy, scipy, and jax. To set up such an environment, you can use conda (but you're welcome to use whatever workflow works for you!):

conda create -n jax-finufft -c conda-forge python=3.9 numpy scipy fftw
conda activate jax-finufft
python -m pip install "jax[cpu]"

Then you can install from source using (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/dfm/jax-finufft
cd jax-finufft
python -m pip install .

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.
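For small problems, it can help to compare against a direct evaluation of the sum that a type 1 transform approximates. The sketch below is pure Python and assumes the FINUFFT definition f[k] = sum_j c[j] * exp(±i k x[j]), with the sign set by iflag and modes ordered from -(N // 2) upward. The name type1_direct is hypothetical and not part of this package.

```python
import cmath


def type1_direct(n_modes, c, x, iflag=1):
    """Direct O(N * M) evaluation of f[k] = sum_j c[j] * exp(+/- 1j * k * x[j])."""
    sign = 1 if iflag >= 0 else -1
    # Assumed mode ordering: k runs from -(N // 2) up to (N - 1) // 2.
    ks = range(-(n_modes // 2), (n_modes + 1) // 2)
    return [
        sum(cj * cmath.exp(sign * 1j * k * xj) for cj, xj in zip(c, x))
        for k in ks
    ]


# A single unit-strength source at x = 0 contributes equally to every mode.
flat = type1_direct(4, [1.0 + 0.0j], [0.0])

# Flipping iflag flips the sign of the exponent, so for real strengths
# the two outputs are complex conjugates of each other.
x = [0.3, 1.9, 4.4]
c = [1.0, -2.0, 0.5]
plus = type1_direct(5, c, x, iflag=1)
minus = type1_direct(5, c, x, iflag=-1)
```

The direct sum costs O(N * M) operations, so it is only practical as a correctness check on small inputs.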

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D
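As a rough illustration of the multi-dimensional argument convention, the following pure-Python sketch evaluates a 2D type 1 sum directly. It assumes that the output is indexed as f[kx][ky] with shape (Nx, Ny) and that each axis uses the same mode ordering as in 1D. The name type1_direct_2d is hypothetical.

```python
import cmath


def type1_direct_2d(shape, c, x, y, iflag=1):
    """Direct sum: f[kx][ky] = sum_j c[j] * exp(sign * 1j * (kx * x[j] + ky * y[j]))."""
    nx, ny = shape
    sign = 1 if iflag >= 0 else -1
    # Assumed mode ordering on each axis: -(N // 2) up to (N - 1) // 2.
    kxs = range(-(nx // 2), (nx + 1) // 2)
    kys = range(-(ny // 2), (ny + 1) // 2)
    return [
        [
            sum(
                cj * cmath.exp(sign * 1j * (kx * xj + ky * yj))
                for cj, xj, yj in zip(c, x, y)
            )
            for ky in kys
        ]
        for kx in kxs
    ]


c = [1.0 + 0.0j, 2.0 - 1.0j]
x = [0.3, 1.7]
y = [2.2, 4.0]
f = type1_direct_2d((4, 3), c, x, y)

# The (kx, ky) = (0, 0) mode is just the sum of the strengths. With the
# assumed ordering it sits at index [2][1] for a (4, 3) output.
zero_mode = f[2][1]
```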

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
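The two transform types are closely related: a type 2 transform evaluates a Fourier series at the non-uniform points, and with the opposite sign of iflag it is the adjoint of the type 1 transform. The pure-Python sketch below checks this numerically using direct sums. It assumes the FINUFFT definitions and mode ordering; type1_direct, type2_direct, and inner are hypothetical helper names.

```python
import cmath


def modes(n):
    # Assumed mode ordering: k runs from -(n // 2) up to (n - 1) // 2.
    return range(-(n // 2), (n + 1) // 2)


def type1_direct(n_modes, c, x, iflag=1):
    """Non-uniform points to modes: f[k] = sum_j c[j] * exp(sign * 1j * k * x[j])."""
    sign = 1 if iflag >= 0 else -1
    return [
        sum(cj * cmath.exp(sign * 1j * k * xj) for cj, xj in zip(c, x))
        for k in modes(n_modes)
    ]


def type2_direct(f, x, iflag=-1):
    """Modes to non-uniform points: c[j] = sum_k f[k] * exp(sign * 1j * k * x[j])."""
    sign = 1 if iflag >= 0 else -1
    return [
        sum(fk * cmath.exp(sign * 1j * k * xj) for fk, k in zip(f, modes(len(f))))
        for xj in x
    ]


def inner(u, v):
    """Complex inner product, conjugating the first argument."""
    return sum(a.conjugate() * b for a, b in zip(u, v))


x = [0.2, 1.1, 3.7, 5.9]
c = [1.0 + 2.0j, -0.5 + 0.1j, 0.3 - 0.7j, 2.0 + 0.0j]
f = [0.4 - 1.0j, 1.5 + 0.2j, -0.8 + 0.0j, 0.1 + 0.9j, -1.2 - 0.3j]

# Adjoint identity: <type1(c), f> == <c, type2(f)> when the iflag signs are opposite.
lhs = inner(type1_direct(len(f), c, x, iflag=1), f)
rhs = inner(c, type2_direct(f, x, iflag=-1))
gap = abs(lhs - rhs)
```

This relationship is one reason a type 2 transform needs no output shape argument: the number of modes is read from the shape of f.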

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Comments
  • batching issue

    Hi,

    I get the following error when I try to batch nufft2 in Jax.

    process_primitive(self, primitive, tracers, params)
        161     frame = self.get_frame(vals_in, dims_in)
        162     batched_primitive = self.get_primitive_batcher(primitive, frame)
    --> 163     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
        164     if primitive.multiple_results:
        165       return map(partial(BatchTracer, self), val_out, dim_out)

    TypeError: batch() got an unexpected keyword argument 'output_shape'

    it seems like this is caused by

    nufft2(source, iflag, eps, *points)
         57
         58     return jnp.reshape(
    ---> 59         nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
         60         expected_output_shape,
         61     )

    Is there something I am doing wrong?

    Thanks for your help!

    opened by samaktbo 13
  • problem batching functions that have multiple arguments

    Hi,

    I am posting this here to give a clearer description of the problem that I am having.

    When I run the following snippet, I would like A to be a 4 by 100 array whose i-th row is the output of linear_func(q, X[i, :]).

    import numpy as np
    import jax.numpy as jnp
    from jax import vmap
    from jax_finufft import nufft2
    
    rng = np.random.default_rng(seed=314)
    
    d=10 
    L_tilde = 10
    L = 100
    
    qr = rng.standard_normal((d, L_tilde+1))
    qi = rng.standard_normal((d, L_tilde+1))
    q = jnp.array(qr + 1j * qi)
    X = jnp.array(rng.uniform(low=0.0, high=1.0, size=(4, L)))
    
    def linear_func(q, x):
      v = jnp.ones(shape=(1, L_tilde))
    
      return jnp.matmul(v, nufft2(q, x, eps=1e-6, iflag=-1))
    
    batched = vmap(linear_func, in_axes=(None, 0), out_axes=0)
    
    A = batched(q, X)
    

    However, when I run the snippet, I get the error posted below. You said last time that there could be a workaround for unbatched arguments, but I could not figure it out.

    Here is the error:

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-25-c23bc4fa7eb4> in <module>()
    ----> 1 test = batched(q, X)
    
    30 frames
    UnfilteredStackTrace: TypeError: '<' not supported between instances of 'NoneType' and 'int'
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/jax_finufft/ops.py in <genexpr>(.0)
        281     else:
        282         mx = args[0].ndim - ndim - 1
    --> 283     assert all(a < mx for a in axes)
        284     assert all(a == axes[0] for a in axes[1:])
        285     return prim.bind(*args, **kwargs), axes[0]
    
    TypeError: '<' not supported between instances of 'NoneType' and 'int'
    

    I hope this gives a clearer picture than what I had last time. Thanks so much for your help!

    opened by samaktbo 5
  • Installation issue

    I am having trouble installing jax-finufft even with the instructions on the home page. The installation fails with

    ERROR: Failed building wheel for jax-finufft
    Failed to build jax-finufft
    ERROR: Could not build wheels for jax-finufft, which is required to install pyproject.toml-based projects

    opened by samaktbo 5
  • Error when differentiating nufft1 with respect to points only

    Hi Dan,

    First, thank you for releasing this package, I was very glad to find it!

    I noted the error below when attempting to differentiate nufft1 with respect to points. I am very new to JAX so I could be mistaken, but I don't believe this is the intended behavior:

    from jax_finufft import nufft1, nufft2
    import numpy as np
    import jax.numpy as jnp
    from jax import grad
    M = 100000
    N = 200000
    
    x = 2 * np.pi * np.random.uniform(size=M)
    c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
    
    def norm_nufft1(c,x):
        f = nufft1(N, c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    def norm_nufft2(c,x):
        f = nufft2( c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    grad(norm_nufft2,argnums =(1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(1))(c,x) # Throws error
    

    The error is below:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in <module>
         22 grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
         23 grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    ---> 24 grad(norm_nufft1,argnums =(1))(c,x) # Throws error
         25 
         26 
    
        [... skipping hidden 10 frame]
    
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in norm_nufft1(c, x)
         11 
         12 def norm_nufft1(c,x):
    ---> 13     f = nufft1(N, c, x, eps=1e-6, iflag=1)
         14     return jnp.linalg.norm(f)
         15 
    
        [... skipping hidden 21 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in nufft1(output_shape, source, iflag, eps, *points)
         39 
         40     return jnp.reshape(
    ---> 41         nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
         42         expected_output_shape,
         43     )
    
        [... skipping hidden 3 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in jvp(type_, prim, args, tangents, output_shape, iflag, eps)
        248 
        249         axis = -2 if type_ == 2 else -ndim - 1
    --> 250         output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
        251 
        252         expand_shape = (
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis)
       3405   if hasattr(arrays[0], "concatenate"):
       3406     return arrays[0].concatenate(arrays[1:], axis)
    -> 3407   axis = _canonicalize_axis(axis, ndim(arrays[0]))
       3408   arrays = _promote_dtypes(*arrays)
       3409   # lax.concatenate can be slow to compile for wide concatenations, so form a
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/util.py in canonicalize_axis(axis, num_dims)
        277   axis = operator.index(axis)
        278   if not -num_dims <= axis < num_dims:
    --> 279     raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
        280   if axis < 0:
        281     axis = axis + num_dims
    
    ValueError: axis -2 is out of bounds for array of dimension 1
    
    bug 
    opened by ma-gilles 4
  • Speed up `vmap`s where none of the points are batched

    If none of the points are batched in a vmap, we should be able to get a faster computation by stacking the transforms into a single transform and then reshaping. It might be worth making the stacked axes into an explicit parameter rather than just trying to infer it, but that would take some work on the interface.

    opened by dfm 0
  • WIP: Adding support for non-batched dimensions in `vmap`

    In most cases, we'll just need to broadcast all the inputs out to the right shapes (this shouldn't be too hard), but when none of the points get mapped, we can get a bit of a speed up by stacking the transforms. This starts to implement that logic, but it's not quite ready yet.

    opened by dfm 0
  • Find a publication venue

    @lgarrison and I have been chatting about the possibility of writing a paper describing what we're doing here. Something like "Differentiable programming with NUFFTs" or "NUFFTs for machine learning applications" or something. This issue is here to remind me to look into possible venues for this.

    opened by dfm 5
  • Starting to add GPU support using cuFINUFFT

    So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!

    Keeping @lgarrison in the loop.

    opened by dfm 10
  • Support Type 3?

    There aren't currently any plans to support the Type 3 transform for a few reasons:

    • I'm not totally sure of the use cases,
    • The logic will probably be somewhat more complicated than the existing implementations, and
    • cuFINUFFT doesn't seem to support Type 3.

    Do we want to work around these issues?

    opened by dfm 0
  • Add support for handling errors

    Currently, if any of the finufft methods fail with a non-zero error code, we just ignore it and keep on trucking. How should we handle this? It looks like the JAX convention is currently to set everything to NaN when an op fails, since propagating errors up from XLA can be a bit of a pain.

    opened by dfm 0
  • Add GPU support

    Via cufinufft. We'll need to figure out how to get XLA's handling of CUDA streams to play nice with cufinufft (this is way above my pay grade). Some references:

    1. A simple example of how to write an XLA compatible CUDA kernel: https://github.com/dfm/extending-jax/blob/main/lib/kernels.cc.cu
    2. The source code for how JAX wraps cuBLAS: https://github.com/google/jax/blob/main/jaxlib/cublas_kernels.cc
    3. There is a tensorflow implementation that might have some useful context: https://github.com/mrphys/tensorflow-nufft/blob/master/tensorflow_nufft/cc/kernels/nufft_kernels.cu.cc

    Perhaps @lgarrison is interested :D

    opened by dfm 0
Releases(v0.0.3)
  • v0.0.3(Dec 10, 2021)

    What's Changed

    • Fix segfault when batching multiple transforms by @dfm in https://github.com/dfm/jax-finufft/pull/11
    • Generalize the behavior of vmap by @dfm in https://github.com/dfm/jax-finufft/pull/12

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.2...v0.0.3

  • v0.0.2(Nov 12, 2021)

    • Faster differentiation using stacked transforms
    • Better error checking for vmap

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.1...v0.0.2

  • v0.0.1(Nov 8, 2021)

Owner
Dan Foreman-Mackey