JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

Overview

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.
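For orientation, the 1D transforms and the derivative identities behind such differentiation rules can be sketched as follows. The notation is an assumption made here for illustration (M points x_j, strengths c_j, an even number N of modes indexed by k, and sign s = iflag = ±1); it follows the FINUFFT conventions as described in its docs rather than quoting this package's source.

```latex
\begin{align}
\text{type 1:}\quad f_k &= \sum_{j=1}^{M} c_j\, e^{i s k x_j},
  & k &= -N/2, \dots, N/2 - 1, \\
\text{type 2:}\quad c_j &= \sum_{k=-N/2}^{N/2-1} f_k\, e^{i s k x_j},
  & j &= 1, \dots, M.
\end{align}
% Both transforms are linear in c (type 1) or f (type 2).
% Derivatives with respect to the non-uniform points:
\begin{align}
\frac{\partial f_k}{\partial x_j} &= i s k\, c_j\, e^{i s k x_j}, &
\frac{\partial c_j}{\partial x_j} &= \sum_{k=-N/2}^{N/2-1} (i s k\, f_k)\, e^{i s k x_j}.
\end{align}
```

Under these identities, a perturbation of the points in a type 1 transform is itself a type 1 transform (of the strengths c_j δx_j, scaled by i s k), and in a type 2 transform it is a type 2 transform of the coefficients i s k f_k, scaled by δx_j.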

Included features

This library is currently CPU-only, but GPU support is in the works using the cuFINUFFT library.

Type 1 and type 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward-mode, reverse-mode, and higher-order differentiation, as well as batching using vmap.
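To make these terms concrete, here is a minimal sketch of the JAX transformations in question (grad, jvp, hessian, vmap). It uses a slow direct-sum stand-in, the hypothetical helper type1_direct, instead of this package's nufft1, so it runs with plain JAX. The mode ordering it assumes (k from -N/2 to N/2 - 1) is FINUFFT's default as I understand it.

```python
from functools import partial

import jax
import jax.numpy as jnp


def type1_direct(n_modes, c, x, iflag=1):
    """Slow O(N*M) stand-in for a 1D type 1 transform (not the library call)."""
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * iflag * k[:, None] * x[None, :]) @ c


def loss(c, x):
    # Real-valued scalar, so that jax.grad applies directly.
    f = type1_direct(8, c, x)
    return jnp.sum(f.real**2 + f.imag**2)


key_x, key_c = jax.random.split(jax.random.PRNGKey(0))
x = 2 * jnp.pi * jax.random.uniform(key_x, (5,))
c = jax.random.normal(key_c, (5,)) + 0j

# Reverse mode: gradient with respect to the non-uniform points.
dloss_dx = jax.grad(loss, argnums=1)(c, x)

# Forward mode: directional derivative along the tangent v.
v = jnp.arange(5, dtype=x.dtype)
_, jvp_out = jax.jvp(lambda x_: loss(c, x_), (x,), (v,))

# Higher order: Hessian with respect to the points.
hess = jax.hessian(loss, argnums=1)(c, x)

# Batching: map over a leading axis of strengths, sharing the points.
batched = jax.vmap(partial(type1_direct, 8), in_axes=(0, None))
C = jnp.stack([c, 2 * c, 3 * c])
F = batched(C, x)  # shape (3, 8)
```

With the package installed, the same transformations are meant to apply to nufft1 and nufft2 in place of the stand-in.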

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. At runtime, you'll need numpy, scipy, and jax. To set up such an environment, you can use conda (but you're welcome to use whatever workflow works for you!):

conda create -n jax-finufft -c conda-forge python=3.9 numpy scipy fftw
conda activate jax-finufft
python -m pip install "jax[cpu]"

Then you can install from source using (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/dfm/jax-finufft
cd jax-finufft
python -m pip install .

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.
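As a sanity check on what the type 1 transform computes, a direct O(N·M) sum in NumPy can serve as a reference for small problems. This is a sketch under assumptions: nufft1_direct is a hypothetical helper, not part of this package, and it assumes FINUFFT's default mode ordering (k from -N/2 to N/2 - 1).

```python
import numpy as np


def nufft1_direct(n_modes, c, x, iflag=1):
    """Direct sum: f[k] = sum_j c[j] * exp(1j * iflag * k * x[j])."""
    # Assumed mode ordering: k = -N/2, ..., N/2 - 1 (for even N).
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return np.exp(1j * iflag * np.outer(k, x)) @ c


rng = np.random.default_rng(0)
M, N = 50, 16
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f_ref = nufft1_direct(N, c, x)

# On a uniform grid the sum reduces to an unnormalized, shifted inverse FFT.
x_grid = 2 * np.pi * np.arange(N) / N
c_grid = rng.standard_normal(N) + 1j * rng.standard_normal(N)
f_grid = nufft1_direct(N, c_grid, x_grid)
f_fft = np.fft.fftshift(np.fft.ifft(c_grid)) * N
```

For small M and N, the output of nufft1(N, c, x, eps=1e-6, iflag=1) should agree with f_ref to roughly the requested tolerance, provided the ordering assumption above holds.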

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
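The relationship between the two types can be checked with direct sums: a type 2 transform with sign -s is the adjoint of a type 1 transform with sign +s. The sketch below uses hypothetical NumPy helpers (nufft1_direct, nufft2_direct), not this package's functions, and assumes modes ordered from -N/2 to N/2 - 1.

```python
import numpy as np


def mode_indices(n_modes):
    # Assumed ordering: k = -N/2, ..., N/2 - 1 (for even N).
    return np.arange(-(n_modes // 2), (n_modes + 1) // 2)


def nufft1_direct(n_modes, c, x, *, iflag):
    """Type 1: f[k] = sum_j c[j] * exp(1j * iflag * k * x[j])."""
    return np.exp(1j * iflag * np.outer(mode_indices(n_modes), x)) @ c


def nufft2_direct(f, x, *, iflag):
    """Type 2: c[j] = sum_k f[k] * exp(1j * iflag * k * x[j])."""
    return np.exp(1j * iflag * np.outer(x, mode_indices(f.shape[0]))) @ f


rng = np.random.default_rng(1)
M, N = 40, 12
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
g = rng.standard_normal(N) + 1j * rng.standard_normal(N)

# Adjoint identity: <g, A c> == <A^H g, c>, where A is the type 1 operator.
lhs = np.vdot(g, nufft1_direct(N, c, x, iflag=1))
rhs = np.vdot(nufft2_direct(g, x, iflag=-1), c)
```

This pairing is why reverse-mode differentiation of one transform type naturally involves the other.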

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Comments
  • batching issue

    Hi,

    I get the following error when I try to batch nufft2 in JAX.

    process_primitive(self, primitive, tracers, params)
        161     frame = self.get_frame(vals_in, dims_in)
        162     batched_primitive = self.get_primitive_batcher(primitive, frame)
    --> 163     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
        164     if primitive.multiple_results:
        165       return map(partial(BatchTracer, self), val_out, dim_out)

    TypeError: batch() got an unexpected keyword argument 'output_shape'

    It seems like this is caused by:

    nufft2(source, iflag, eps, *points)
         57
         58     return jnp.reshape(
    ---> 59         nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
         60         expected_output_shape,
         61     )

    Is there something I am doing wrong?

    Thanks for your help!

    opened by samaktbo 13
  • problem batching functions that have multiple arguments

    Hi,

    I am posting this here to give a clearer description of the problem that I am having.

    When I run the following snippet, I would like A to be a 4 by 100 array whose i-th row is the output of linear_func(q, X[i, :]).

    import numpy as np
    import jax.numpy as jnp
    from jax import vmap
    from jax_finufft import nufft2
    
    rng = np.random.default_rng(seed=314)
    
    d=10 
    L_tilde = 10
    L = 100
    
    qr = rng.standard_normal((d, L_tilde+1))
    qi = rng.standard_normal((d, L_tilde+1))
    q = jnp.array(qr + 1j * qi)
    X = jnp.array(rng.uniform(low=0.0, high=1.0, size=(4, L)))
    
    def linear_func(q, x):
      v = jnp.ones(shape=(1, L_tilde))
    
      return jnp.matmul(v, nufft2(q, x, eps=1e-6, iflag=-1))
    
    batched = vmap(linear_func, in_axes=(None, 0), out_axes=0)
    
    A = batched(q, X)
    

    However, when I run the snippet, I get the error posted below. You had said last time that there could be a workaround for unbatched arguments, but I could not figure it out.

    Here is the error:

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-25-c23bc4fa7eb4> in <module>()
    ----> 1 test = batched(q, X)
    
    30 frames
    UnfilteredStackTrace: TypeError: '<' not supported between instances of 'NoneType' and 'int'
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/jax_finufft/ops.py in <genexpr>(.0)
        281     else:
        282         mx = args[0].ndim - ndim - 1
    --> 283     assert all(a < mx for a in axes)
        284     assert all(a == axes[0] for a in axes[1:])
        285     return prim.bind(*args, **kwargs), axes[0]
    
    TypeError: '<' not supported between instances of 'NoneType' and 'int'
    

    I hope this gives a clearer picture than what I had last time. Thanks so much for your help!

    opened by samaktbo 5
  • Installation issue

    I am having trouble installing jax-finufft, even when following the instructions on the home page. The installation fails with:

    ERROR: Failed building wheel for jax-finufft
    Failed to build jax-finufft
    ERROR: Could not build wheels for jax-finufft, which is required to install pyproject.toml-based projects

    opened by samaktbo 5
  • Error when differentiating nufft1 with respect to points only

    Hi Dan,

    First, thank you for releasing this package, I was very glad to find it!

    I noted the error below when attempting to differentiate nufft1 with respect to points. I am very new to JAX so I could be mistaken, but I don't believe this is the intended behavior:

    from jax_finufft import nufft1, nufft2
    import numpy as np
    import jax.numpy as jnp
    from jax import grad
    M = 100000
    N = 200000
    
    x = 2 * np.pi * np.random.uniform(size=M)
    c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
    
    def norm_nufft1(c,x):
        f = nufft1(N, c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    def norm_nufft2(c,x):
        f = nufft2( c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    grad(norm_nufft2,argnums =(1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(1))(c,x) # Throws error
    

    The error is below:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in <module>
         22 grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
         23 grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    ---> 24 grad(norm_nufft1,argnums =(1))(c,x) # Throws error
         25 
         26 
    
        [... skipping hidden 10 frame]
    
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in norm_nufft1(c, x)
         11 
         12 def norm_nufft1(c,x):
    ---> 13     f = nufft1(N, c, x, eps=1e-6, iflag=1)
         14     return jnp.linalg.norm(f)
         15 
    
        [... skipping hidden 21 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in nufft1(output_shape, source, iflag, eps, *points)
         39 
         40     return jnp.reshape(
    ---> 41         nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
         42         expected_output_shape,
         43     )
    
        [... skipping hidden 3 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in jvp(type_, prim, args, tangents, output_shape, iflag, eps)
        248 
        249         axis = -2 if type_ == 2 else -ndim - 1
    --> 250         output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
        251 
        252         expand_shape = (
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis)
       3405   if hasattr(arrays[0], "concatenate"):
       3406     return arrays[0].concatenate(arrays[1:], axis)
    -> 3407   axis = _canonicalize_axis(axis, ndim(arrays[0]))
       3408   arrays = _promote_dtypes(*arrays)
       3409   # lax.concatenate can be slow to compile for wide concatenations, so form a
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/util.py in canonicalize_axis(axis, num_dims)
        277   axis = operator.index(axis)
        278   if not -num_dims <= axis < num_dims:
    --> 279     raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
        280   if axis < 0:
        281     axis = axis + num_dims
    
    ValueError: axis -2 is out of bounds for array of dimension 1
    
    bug 
    opened by ma-gilles 4
  • Speed up `vmap`s where none of the points are batched

    If none of the points are batched in a vmap, we should be able to get a faster computation by stacking the transforms into a single transform and then reshaping. It might be worth making the stacked axes into an explicit parameter rather than just trying to infer it, but that would take some work on the interface.

    opened by dfm 0
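The stacking idea in the issue above can be illustrated with a toy NumPy direct sum. This is not code from this package: type1_direct is a hypothetical stand-in, and the point is only that, when the points are shared, flattening the batch axes, evaluating once, and reshaping gives the same result as looping over the batch.

```python
import numpy as np


def type1_direct(n_modes, c, x):
    """Slow direct-sum stand-in for a single 1D type 1 transform."""
    k = np.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return np.exp(1j * np.outer(k, x)) @ c


rng = np.random.default_rng(2)
B1, B2, M, N = 2, 3, 30, 8
x = 2 * np.pi * rng.uniform(size=M)  # points shared by every transform
C = rng.standard_normal((B1, B2, M)) + 1j * rng.standard_normal((B1, B2, M))

# Naive batching: one independent transform per batch element.
looped = np.array(
    [[type1_direct(N, C[a, b], x) for b in range(B2)] for a in range(B1)]
)

# Stacked: build the kernel once, flatten the batch axes, then reshape back.
k = np.arange(-(N // 2), (N + 1) // 2)
kernel = np.exp(1j * np.outer(k, x))  # shape (N, M)
stacked = (C.reshape(-1, M) @ kernel.T).reshape(B1, B2, N)
```

In a real NUFFT the saving would come from reusing the point-dependent setup across the stacked transforms rather than from a dense kernel matrix.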
  • WIP: Adding support for non-batched dimensions in `vmap`

    In most cases, we'll just need to broadcast all the inputs out to the right shapes (this shouldn't be too hard), but when none of the points get mapped, we can get a bit of a speed up by stacking the transforms. This starts to implement that logic, but it's not quite ready yet.

    opened by dfm 0
  • Find a publication venue

    @lgarrison and I have been chatting about the possibility of writing a paper describing what we're doing here. Something like "Differentiable programming with NUFFTs" or "NUFFTs for machine learning applications" or something. This issue is here to remind me to look into possible venues for this.

    opened by dfm 5
  • Starting to add GPU support using cuFINUFFT

    So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!

    Keeping @lgarrison in the loop.

    opened by dfm 10
  • Support Type 3?

    There aren't currently any plans to support the Type 3 transform for a few reasons:

    • I'm not totally sure of the use cases,
    • The logic will probably be somewhat more complicated than the existing implementations, and
    • cuFINUFFT doesn't seem to support Type 3.

    Do we want to work around these issues?

    opened by dfm 0
  • Add support for handling errors

    Currently, if any of the FINUFFT methods fail with a non-zero error code, we just ignore it and keep on trucking. How should we handle this? It looks like the JAX convention is currently to set everything to NaN when an op fails, since propagating errors up from XLA can be a bit of a pain.

    opened by dfm 0
  • Add GPU support

    Via cufinufft. We'll need to figure out how to get XLA's handling of CUDA streams to play nice with cufinufft (this is way above my pay grade). Some references:

    1. A simple example of how to write an XLA compatible CUDA kernel: https://github.com/dfm/extending-jax/blob/main/lib/kernels.cc.cu
    2. The source code for how JAX wraps cuBLAS: https://github.com/google/jax/blob/main/jaxlib/cublas_kernels.cc
    3. There is a tensorflow implementation that might have some useful context: https://github.com/mrphys/tensorflow-nufft/blob/master/tensorflow_nufft/cc/kernels/nufft_kernels.cu.cc

    Perhaps @lgarrison is interested :D

    opened by dfm 0
Releases (v0.0.3)
  • v0.0.3 (Dec 10, 2021)

    What's Changed

    • Fix segfault when batching multiple transforms by @dfm in https://github.com/dfm/jax-finufft/pull/11
    • Generalize the behavior of vmap by @dfm in https://github.com/dfm/jax-finufft/pull/12

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.2...v0.0.3

  • v0.0.2 (Nov 12, 2021)

    • Faster differentiation using stacked transforms
    • Better error checking for vmap

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.1...v0.0.2

  • v0.0.1 (Nov 8, 2021)

Owner
Dan Foreman-Mackey