code

I am enthusiastic about designing tools for astronomy, especially if they enable fast, differentiable forward modeling of observations. You can find all my code on GitHub; a few highlights are described below.


Gaussian processes are powerful priors for continuous fields, but scaling them to millions or billions of parameters remains a challenge. This package uses Vecchia’s approximation with a new generation order and an efficient implementation, achieving linear scaling with very low memory overhead. It effortlessly handles arbitrary point distributions with a large dynamic range, and the inverse and determinant of the approximate covariance are available exactly. It is written in JAX, with a faster custom CUDA extension that supports derivatives. Try it out!

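graph = gp.build_graph(points, n0=100, k=10)  # conditioning graph over the points
covariance = (cov_bins, cov_values)  # covariance function as a lookup table
xi = jr.normal(rng, (points.shape[0],))  # white noise, one standard normal per point
values = gp.generate(graph, covariance, xi)  # correlated GP sample at the points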
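
The idea behind Vecchia’s approximation can be sketched in a few lines: order the points, then generate each value from a Gaussian conditioned only on its k nearest predecessors, so each step costs a fixed k×k solve and the conditioning work grows linearly with the number of points. This is a toy under stated assumptions, not the package's implementation: the squared-exponential `kernel`, the `nugget`, the helper name `vecchia_sample`, and the brute-force O(n²) neighbor search are all hypothetical, and it keeps the points in the order given rather than using the package's generation order.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# float64 keeps the small linear solves in this toy accurate
jax.config.update("jax_enable_x64", True)


def kernel(x, y, scale=0.3):
    """Squared-exponential covariance between point sets x (n, d) and y (m, d).
    Assumed for illustration only; it is not the package's covariance format."""
    d2 = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2 / scale**2)


def vecchia_sample(points, xi, k, nugget=1e-6):
    """Hypothetical helper: turn white noise xi into a GP sample, conditioning
    each point only on its k nearest predecessors in the given order.
    The neighbor search is brute force (O(n^2)), unlike a real implementation."""
    n = points.shape[0]
    values = jnp.zeros(n)
    # The first point has no predecessors, so it is drawn unconditionally.
    var0 = kernel(points[:1], points[:1])[0, 0] + nugget
    values = values.at[0].set(jnp.sqrt(var0) * xi[0])
    for i in range(1, n):
        prev = points[:i]
        d2 = jnp.sum((prev - points[i]) ** 2, axis=-1)
        nbr = jnp.argsort(d2)[: min(k, i)]  # indices of the k nearest earlier points
        x_i = points[i : i + 1]
        K_nn = kernel(prev[nbr], prev[nbr]) + nugget * jnp.eye(nbr.shape[0])
        K_in = kernel(x_i, prev[nbr])[0]
        w = jnp.linalg.solve(K_nn, K_in)  # kriging weights
        mean = w @ values[nbr]  # conditional mean given the neighbors
        var = kernel(x_i, x_i)[0, 0] + nugget - w @ K_in  # conditional variance
        values = values.at[i].set(mean + jnp.sqrt(var) * xi[i])
    return values


points = jr.uniform(jr.PRNGKey(0), (100, 2))
xi = jr.normal(jr.PRNGKey(1), (points.shape[0],))
values = vecchia_sample(points, xi, k=10)
```

Conditioning on all predecessors (k equal to the number of points) reproduces an exact Cholesky draw from the same covariance; a small fixed k trades a little accuracy for cost that no longer grows with the number of earlier points.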

Neighbor searches with k-d trees are important for many scientific applications, but JAX lacks a native implementation, partly because the task is not well suited to array programming. This package uses two GPU-friendly tree algorithms [1, 2] to perform efficient neighbor searches using only JAX primitives. It is not quite as fast as a lower-level implementation, but it is easier to integrate with JAX and runs wherever JAX runs. An experimental CUDA extension is also available for faster queries.

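tree = jk.build_tree(points)  # build the k-d tree once
counts = jk.count_neighbors(tree, queries, r=0.1)  # points within radius 0.1 of each query
neighbors, distances = jk.query_neighbors(tree, queries, k=10)  # 10 nearest points to each query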
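
When checking neighbor queries on small inputs, an exact brute-force reference is handy. It computes every query-point distance, so it needs O(n_queries × n_points) memory, which is exactly the cost a tree avoids. The function names below are hypothetical and not part of the package, and the nearest-first ordering is this sketch's own convention; whether `jk.query_neighbors` orders its results the same way is an assumption to verify.

```python
import jax
import jax.numpy as jnp


def pairwise_sq_dist(points, queries):
    """Squared distances between every query and every point, shape (n_queries, n_points)."""
    return jnp.sum((queries[:, None, :] - points[None, :, :]) ** 2, axis=-1)


def brute_force_count(points, queries, r):
    """Number of points within distance r of each query."""
    return jnp.sum(pairwise_sq_dist(points, queries) <= r**2, axis=1)


def brute_force_knn(points, queries, k):
    """Indices and distances of the k nearest points to each query, nearest first."""
    # lax.top_k returns the largest entries, so negate the squared distances.
    neg_d2, idx = jax.lax.top_k(-pairwise_sq_dist(points, queries), k)
    return idx, jnp.sqrt(-neg_d2)


points = jnp.array([[0.0], [1.0], [3.0], [6.0]])
queries = jnp.array([[0.9], [5.0]])
print(brute_force_count(points, queries, r=1.0))  # [2 1]
neighbors, distances = brute_force_knn(points, queries, k=2)
```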

This small tool makes it easier to read and write probabilistic models in astrophysics using JAX:

@purify
def model():
    x = param(3)
    return jnp.sin(x)

params = model.normal(rng)
model(params)

Parameters are only mentioned once, right next to where they are used. This avoids initialization methods and object-oriented programming while retaining automatic parameter management. Several other features are convenient for the kinds of forward models needed in science. Under the hood, a Jaxpr interpreter evaluates the model, similar to Oryx.