LouisDesdoigts/dLux

Improve `filter_jit` Usage Model.

Closed this issue · 44 comments

Hi all,
Now that the apertures provide a more complex interface, it would make sense to set things up so that certain parameters can be marked as static on command. However, this interacts very badly with the np.asarray(param).astype(dtype) syntax that we currently have in the constructors. The problem is that Array-typed parameters cannot be hashed and as a result cannot be marked as static.

For example, in my Toliman forward model I do not want to learn the radius of the circular aperture, its compression or its strain; I am only interested in the global rotation of the spider (this may sound dumb, but I think it matters because of the background stars). However, I cannot mark these attributes as static using filter_jit because they are Array types.

I do not have a good solution to this off the top of my head. In general this leads to conflicts with the zodiax::ExtendedBase::get_args method. The following MWE demonstrates my problem.

import functools as ft

import dLux as dl
import equinox as eqx
import jax

# `pupil_data` and `wavefront` are assumed to be defined elsewhere.
circ: object = dl.CircularAperture(1.)

@ft.partial(eqx.filter_jit, args=(True, circ.get_args("centre"), True))
def dl_loss(data: float, model: object, wavefront: object) -> float:
    psf: float = model(wavefront)
    return jax.lax.integer_pow(data - psf, 2).sum()

dl_loss(pupil_data, circ, wavefront)

This gives me the error,

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[72], line 10
      7     psf: float = model(wavefront)
      8     return jax.lax.integer_pow(data - psf, 2).sum()
---> 10 dl_loss(pupil_data, dl.CircularAperture(1.), wavefront)

File ~/anaconda3/envs/dLux/lib/python3.10/site-packages/equinox/jit.py:82, in _JitWrapper.__call__(_JitWrapper__self, *args, **kwargs)
     81 def __call__(__self, *args, **kwargs):
---> 82     return __self._fun_wrapper(False, args, kwargs)

File ~/anaconda3/envs/dLux/lib/python3.10/site-packages/equinox/jit.py:78, in _JitWrapper._fun_wrapper(self, is_lower, args, kwargs)
     76     return self._cached.lower(dynamic, static)
     77 else:
---> 78     dynamic_out, static_out = self._cached(dynamic, static)
     79     return combine(dynamic_out, static_out.value)

ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'dl_loss' while trying to hash an object of type <class 'tuple'>, (((<function dl_loss at 0x7f3cc2554d30>,), PyTreeDef(*)), ((Array([0., 0.], dtype=float32), Array([1., 1.], dtype=float32), Array(0., dtype=float32), False, Array(1., dtype=float32), Array(1., dtype=float32)), PyTreeDef(((None, CustomNode(CircularAperture[(('centre', 'strain', 'compression', 'rotation', 'occulting', 'softening', 'radius'), ('name',), ('CircularAperture',))], [None, *, *, *, *, *, *]), CustomNode(Wavefront[(('wavelength', 'radius', 'npix', 'pixel_scale'), (), ())], [None, None, None, None])), {}))), <function is_array at 0x7f3db2985d80>). The error was:
TypeError: unhashable type: 'Array'

Regards

Jordan

This will, I think, require some rather complex discussion, so I will do my best to flesh out the context of the issue as I understand it and we can go from there. Marking arguments as static in jit compilation can make things much faster or have very little effect. The benefit of doing so is inconsistent, but I have never seen it make something slower.

In general it makes sense that it is faster, because the compiler knows more about the operations it is going to perform. I am hoping that by marking the aperture parameters we do not want to learn as static, the compiler will be clever enough to do things like eliminate the conditionals entirely (dead code removal is a pretty simple compiler optimisation). While that might not accelerate things too much, I would expect to see some improvement.

eqx.filter_jit has a default behaviour of marking python types as static and tracing only the jax arrays. We have been using this to our advantage by constructing the classes so that the parameters we want to learn are stored as arrays and the ones we do not want to learn are stored as python types. This simplifies the interface to the compiler because you do not have to worry about manually constructing the args boolean pytree (this is NEVER fun!).
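As a minimal sketch of that default behaviour (assuming eqx.is_array is the filter, as in current equinox versions):

import equinox as eqx
import jax.numpy as np

# eqx.filter_jit traces leaves for which eqx.is_array is True and treats
# everything else (python floats, ints, bools, strings) as static.
print(eqx.is_array(np.asarray(1.0)))  # True  -> traced
print(eqx.is_array(1.0))              # False -> static
print(eqx.is_array(True))             # False -> static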

You don't need to mark any of these as static at all... equinox does this for us, and you only need to mark things as static that determine array shapes etc., which none of those parameters change. AFAIK the performance should be the same either way; static arguments only determine when the function is recompiled due to array dimensionality/logic changes, which can't be done with array types anyway since they aren't hashable. The data is never going to change shape during optimisation, so it doesn't need to be static.

If you only want to learn a single parameter this is done with the arguments passed to filter_grad, not filter_jit.
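For instance, a minimal sketch of restricting gradients to a single leaf using the eqx.partition/eqx.combine filter-spec pattern from the equinox docs (circ, pupil_data and wavefront are assumed from the MWE above):

import equinox as eqx
import jax

# Gradients wrt `rotation` only; everything else is held constant.
filter_spec = jax.tree_util.tree_map(lambda _: False, circ)
filter_spec = eqx.tree_at(lambda ap: ap.rotation, filter_spec, replace=True)

def loss(diff_model, static_model, data, wavefront):
    model = eqx.combine(diff_model, static_model)
    psf = model(wavefront)
    return ((data - psf) ** 2).sum()

diff_model, static_model = eqx.partition(circ, filter_spec)
grads = eqx.filter_jit(eqx.filter_grad(loss))(diff_model, static_model, pupil_data, wavefront)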

Also, as an aside, the spider orientation has no relation to the background stars. It's fixed wrt the pupil/detector orientation and affects all the PSFs.

My previous comment is the context, now we need to consider solutions.

  1. Stop making learnable parameters jax types. This has some merit for a zero-dimensional tensor like rotation = 0., but what about setting the coefficients of an AberratedAperture? It is less obvious there, but still possible by making a list like [float(v) for v in normal(PRNGKey(0), (n,))], although this is much more complex.
  2. Do nothing. We are unlikely to gain huge performance improvements.

You only need to mark shape-determining parameters as static, but you can mark any parameter (type allowing) as static.

There is more than one non-shape-determining parameter that we mark as static, for example occulting in the aperture.

Also, equinox has a default function, is_array, that it uses to decide which arguments are static, but this function can be replaced, or you can just provide your own pytree. I'm building an example now. It is just taking a really long time because I underestimated it.

Marking arguments as static that don't affect shapes or logic flow doesn't increase the speed of the compiled function; it only changes when the function is recompiled. The compiled XLA function won't be different if the data parameter is marked static; it'll only recompile the python function if a different data array is passed into the function.

occulting doesn't affect array shapes but it does change the logic flow, so we use a python boolean and equinox automatically treats this as static, which is what we want.
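As a hedged sketch of why that works (a toy transmission function, not the dLux implementation):

import jax
import jax.numpy as np
from functools import partial

@partial(jax.jit, static_argnums=1)
def transmission(distances, occulting: bool):
    t = np.clip(distances, 0.0, 1.0)  # stand-in for the softening profile
    # `occulting` is a python bool, resolved at trace time: only the
    # branch actually taken is traced and compiled.
    return 1.0 - t if occulting else t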

Furthermore, you can't make jax array data types static anyway because they aren't hashable, so there is no way to make any learnable floating-point parameters static.

I'm pretty sure that jax treats static arguments as constants. Is this not correct?

We could switch occulting to a lax.cond and it would still be marked as static.

Static arguments are only used to build the hash key for the compiled XLA function cache; they don't change the XLA function under the hood.

I don't think that is quite right. Compilation goes python -> jaxpr -> XLA/MHLO -> optimisations -> assembly -> binary. The static arguments are hashed so that recompilation can be performed if the value of a static argument changes. The optimisations are independent of that.
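For concreteness, a small sketch of the recompile-on-new-static-value behaviour (the print is a standard trick for observing when tracing happens):

import jax
import jax.numpy as np
from functools import partial

@partial(jax.jit, static_argnums=1)
def g(x, flag: bool):
    print("tracing")  # python side effect: runs only while tracing
    return x + 1 if flag else x - 1

g(np.ones(3), True)   # prints "tracing", compiles
g(np.ones(3), True)   # cache hit: no print, no recompile
g(np.ones(3), False)  # new static value -> new cache entry, prints again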

Yes, but let's say that a static argument is treated as a constant. This would be the difference between these two functions:

def f(x, a):
    return x*a
def f(x):
    return x*10

The runtime difference between these is a trivial O(1) lookup for the a parameter. All the array operations are the same, and so the optimisations should be the same.

I'm most interested in the pruning of conditionals, i.e.

def f(x, a):
    if (a == 1.).all():
        return x          # branch 1
    else:
        return x * a      # branch 2

def f(x):
    return x              # branch 1

and the like.

Besides I have seen constants vs variables make a big difference. I'm just trying to make that example.

But you're probably right and I will probably feel like an idiot pretty soon.

Yeah, this is part of the benefit of keeping python logicals over conds for actual static parameters (i.e. booleans): the jit-compiled function will only compile that single branch.

Yeah, I'd be interested in the example, but either way I don't think there will be any significant benefit in the long run, especially if we need to make some syntax change explicitly marking things as static vs not with extra parameters.

OK, I got it working.

Here is the example, which is a 25% speed increase on my machine.

import jax
import jax.numpy as np

# `coords` and `hypotenuse` are assumed helpers (a centred coordinate grid
# and an elementwise 2-norm over the leading axis, as used elsewhere in dLux).

radius: float = 1.
npix: int = 128
nsoft: int = 3
x: float = 0.
y: float = 0.
rotation: float = 0.
pixel_scale: float = 2. * radius / npix

def circ_ap_func(
        radius: float, 
        x: float, 
        y: float,
        rotation: float, 
        nsoft: float,
        pixel_scale: float) -> float:
    # Casting arguments to safe types. 
    centre: float = np.asarray([x, y]).astype(float)
    radius: float = np.asarray(radius).astype(float)
    rotation: float = np.asarray(rotation).astype(float)
    nsoft: float = np.asarray(nsoft).astype(float)
    
    # Organising coords
    ccoords: float = coords(npix, radius)
    
    # Translation 
    ccoords: float = ccoords - centre[:, None, None]
        
    # Rotation
    sin_alpha: float = jax.lax.sin(rotation)
    cos_alpha: float = jax.lax.cos(rotation)
    x: float = jax.lax.index_in_dim(ccoords, 0)
    y: float = jax.lax.index_in_dim(ccoords, 1)
    new_x: float = x * cos_alpha - y * sin_alpha
    new_y: float = x * sin_alpha + y * cos_alpha
    ccoords: float = jax.lax.concatenate([new_x, new_y], 0)        
        
    # Transformation 
    rho: float = hypotenuse(ccoords)
        
    # Linear softening
    distances: float = radius - rho
    lower: float = jax.lax.full_like(distances, 0., dtype=float)
    upper: float = jax.lax.full_like(distances, 1., dtype=float)
    inside: float = jax.lax.max(distances, lower)
    scaled: float = inside / nsoft / pixel_scale
    aperture: float = jax.lax.min(scaled, upper)
    return aperture

jit_circ_ap_func: callable = jax.jit(circ_ap_func)
static_jit_circ_ap_func: callable = jax.jit(circ_ap_func, inline=True, static_argnums=(1, 2, 3, 4))
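For reference, a sketch of how such a comparison can be timed (the block_until_ready calls are an assumption about methodology, added so we don't measure only JAX's async dispatch):

# Warm up both compiled versions first, then time the cached calls in IPython:
jit_circ_ap_func(radius, x, y, rotation, nsoft, pixel_scale).block_until_ready()
static_jit_circ_ap_func(radius, x, y, rotation, nsoft, pixel_scale).block_until_ready()

%timeit jit_circ_ap_func(radius, x, y, rotation, nsoft, pixel_scale).block_until_ready()

%timeit static_jit_circ_ap_func(radius, x, y, rotation, nsoft, pixel_scale).block_until_ready()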

Interesting, I wouldn't have suspected that level of improvement; I'm curious if this is coming from function-overhead improvements. How does the speed change if you use different array sizes (i.e. 512, 2048, etc.)?

Either way operations like this will still be dwarfed by the MFT calculations so won't be worth any syntactical/interface changes in the long run.

It gets better with the size increase. This is 512:
[screenshot: timing results at 512]

Interesting...

For the time being, you know what I'm going to say - compute is going to be dominated by matrix multiplications/FFTs, and so we should consider this interesting and try and engineer the most painless approach to the instruments we need to model.

[screenshot: timing results at 1024]
Ten times faster for 1024. I also realised that the last two runs had all arguments marked as static.

Did we ever implement that two sided MFT?

Turns out we were doing it all along!

It was what we implemented from the start. The poppy code didn't describe this properly, so we were under the impression that we hadn't.

OK also, this is current dLux for an AnnularAperture vs the propagator.
[screenshot: AnnularAperture vs propagator timings]
The resolution is 1024, and I honestly don't know what the difference is at this point that is causing the code I showed above to run roughly ten times faster. But we cannot keep dismissing things as not-expensive relative to the MFT. The MFT runs in millisecond time at the sizes we care about.

So after talking to @Jordan-Dennis about this some more, I think there might be a relatively simple way to improve the speed of jitted models using this. The goal is to create an extra function that is called on models before they are passed to jitted equinox functions. This can be done because python floats and ints are hashable, and hence can be set as static in jitted functions. Currently we store all floats as 0d arrays, so they can't be marked static; however, if we cast them to python floats and generate a boolean pytree of their locations, we can theoretically pass this new model with the boolean pytree to eqx.filter_jit in order to set them as actually static. This could improve MFT performance, for example, since most of the parameters are floats or ints.

Furthermore, because all of our wavelength calculations are done independently, it might be possible to create Source objects that store wavelengths as lists of floats, which can then be set as static, if we replace vmaps with tree_maps. This may not be possible, but given the speed improvements shown in this thread I think it is at least worth investigating.

I will have a play around and see if this is even generally possible, but would love your thoughts on this. If it can be made as I have described, this should actually be only a single extra line for potentially big speed improvements. I'll prototype something quickly and share results here.

Alright, so I did some testing, and it seems like this gives functionally zero speed improvement.

Implementation is simple: you essentially create a jit_model which has all 0d arrays replaced with floats, which are then automatically marked as static by eqx.filter_jit because it uses the eqx.is_array function to determine static arguments. I.e.:

import jax
import jax.numpy as np
import equinox as eqx

def float_from_0d(x):
    # Cast 0d jax arrays to hashable python floats; leave everything else alone.
    if isinstance(x, np.ndarray):
        return float(x) if x.ndim == 0 else x
    else:
        return x

def get_jit_model(model, args):
    # Split off the learnable leaves, float-cast the rest, and recombine.
    opt, non_opt = eqx.partition(model, args)
    float_model = jax.tree_map(float_from_0d, non_opt)
    return eqx.combine(opt, float_model)

args = model.get_args(parameters)
jit_model = get_jit_model(model, args)

I've verified that this does actually mark extra parameters as static like so

import jax
import jax.numpy as np
import equinox as eqx
from zodiax import Base

is_flipped = lambda leaf0, leaf1: leaf0 != leaf1

def print_flipped(pytree):
    if isinstance(pytree, Base):
        pytree = pytree.__dict__    
    for key, value in pytree.items():
        if isinstance(value, (Base, dict)):
            print_flipped(value)
        else:
            if isinstance(value, bool) and value:
                print(key)

# Find jit static mapping
model_is_array = jax.tree_map(eqx.is_array, model)
jit_model_is_array = jax.tree_map(eqx.is_array, jit_model)

# Find which ones are different
diff_tree = jax.tree_map(is_flipped, model_is_array, jit_model_is_array)

# Examine
n = np.array(jax.tree_util.tree_leaves(diff_tree)).sum()
print(f"{n} extra static parameters:")
print_flipped(diff_tree)

In my testing only a handful of extra values were set as static (diameter, pixel_scale_out, flux, etc.), and I tested over a series of array sizes; the results were functionally identical in all cases (although possibly with a lower variance, fwiw).

I tested over simpler optics such as Toliman and more complex ones with detector effects such as JWST. Results remained the same.


With that said, this could possibly be improved by casting 1d arrays to lists of floats and replacing vmaps with tree_maps, and there may be other cases where the speed actually improves, but seemingly not in our typical use case. I'm thinking of maybe adding this in as a function anyway since it is trivial, but it would more likely belong in zodiax than dLux since it is pytree-general. A sketch of the list-of-floats idea is below.
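A hedged sketch of that idea (monochromatic_psf is a hypothetical stand-in for the per-wavelength propagation):

import jax
import jax.numpy as np

wavelengths = [5e-7, 6e-7, 7e-7]  # python floats: hashable, so static under filter_jit

def monochromatic_psf(wavelength):
    # hypothetical per-wavelength propagation
    return wavelength * np.ones((4, 4))

# vmap version: wavelengths must be packed into a traced array
psfs_vmap = jax.vmap(monochromatic_psf)(np.asarray(wavelengths))

# tree_map version: wavelengths stay as python floats (static at trace time)
psfs_tree = np.stack(jax.tree_util.tree_map(monochromatic_psf, wavelengths))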

Let me know what you think

I think it will depend on the type of calculation as to how much speed can be gained. I'm not surprised this did not work so well for the MFT, since it is (essentially) a matrix multiplication. There is not much opportunity for the compiler to optimise a matmul compared to a series of other operations.

There is also the problem of the compiler targeting different architectures with different amounts of optimisation. @jakevdp implied that not a lot of work has gone into the CPU compiler optimisations. This means that in my case static arguments give the compiler more room to do work, whereas on GPU these optimisations might already be getting applied.

I think we will need a better idea of what code is actually running. If I get a chance today I will examine the mhlo output for an apertures example and try and report back what I find. As I said we are getting into architecture territory, however, static arguments should never be slower, so it could still be worth it.

Yeah I think we could get more speed improvements in the aperture module. Dynamic apertures are definitely slower than desired even when jitted. I think it would unfortunately require re-writing a fair bit of the core code to be able to work with both lists and arrays.

Either way I think I'll add it in to Zodiax since I'm sure there are other packages that could be easily accelerated by this.

I wouldn't rush in too fast. I started this issue before I spoke to @jakevdp. It might even be worth raising an issue with the jax team asking what they think about the merits. I'm going to do this now, actually, since I am curious.

I mean, I can always put it in and then improve/change it later. It's not going to make things slower, after all.

No, you are correct there. I'm particularly interested to see what they have to say about the different supported architectures. While I am writing my discussion, are there any rules about what I can say? I figure there aren't since this is a public repository, but it doesn't hurt to ask.

Nothing that I'm aware of!

Since when has there been syntax highlighting for jaxprs?
[screenshot: syntax-highlighted jaxpr]

So without even going to the mhlo level, we can see significant changes in the jaxpr in the case of static vs non-static.
[screenshot: jaxpr diff, static vs non-static]
In this case the change is because the instantiation of an array in this line of code can be skipped. This happens to be functionally equivalent to passing in a jax.Array instead of python floats.
[screenshot: the np.asarray(...).astype(...) constructor line]
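A minimal illustration of that jaxpr difference (an illustrative function, not the dLux code):

import jax
import jax.numpy as np
from functools import partial

def f(x, rotation):
    rotation = np.asarray(rotation).astype(float)
    return x * jax.lax.cos(rotation)

x = np.ones(4)
print(jax.make_jaxpr(f)(x, 0.3))                    # conversion + cos appear as ops
print(jax.make_jaxpr(partial(f, rotation=0.3))(x))  # rotation folded to a constant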

OK, so I spent way too long looking at this but didn't get to set up the old example where it made a difference. Here is my code if someone wants to pursue this further.

import jax.numpy as np
import jax.lax as jl
import jax
import re
import difflib

def get_mhlo(
        func: callable, 
        *args: object, 
        static_argnums: tuple = None, 
        **kwargs: object
    ) -> str:
    comp_func: callable = jax.jit(func, static_argnums = static_argnums)
    mhlo: str = comp_func.lower(*args, **kwargs).compile().as_text()
    return re.sub("metadata={.*}", "", mhlo)

def print_diff(mhlo: str, comp: str) -> list:
    mhlo_lines: list = mhlo.splitlines()
    comp_lines: list = comp.splitlines()
    diff: iter = difflib.unified_diff(mhlo_lines, comp_lines)
    for line in diff:
        print(line)

def mesh(npix: int) -> float:
    centre: float = (npix - 1.0) / 2.0
    shape: tuple =  (1, npix, npix) 
    x: float = jl.broadcasted_iota(float, shape, 1)
    y: float = jl.broadcasted_iota(float, shape, 2)
    return jl.concatenate([x, y], 0) - centre

def circular_aperture(coordinates: float, radius: float, xcentre: float, ycentre: float) -> float:
    # Translate, then compare radial distances against the aperture radius.
    trans: float = coordinates - np.array([[[xcentre]], [[ycentre]]])
    pythags: float = jl.integer_pow(trans, 2)
    radii: float = jl.sqrt(jl.reduce(pythags, 0., jl.add, [0]))
    aperture: float = jl.lt(radii, radius).astype(float)
    return aperture

coordinates: float = mesh(1024)

dynamic: callable = jax.jit(circular_aperture)
static: callable = jax.jit(circular_aperture, static_argnums=(1, 2, 3))

%%timeit
dynamic(coordinates, 8., 0., 0.)

%%timeit
static(coordinates, 8., 0., 0.)

dynamic_mhlo: str = get_mhlo(circular_aperture, coordinates, 8., 0., 0.)
static_mhlo: str = get_mhlo(circular_aperture, coordinates, 8., 0., 0., static_argnums = (1, 2, 3))

print_diff(dynamic_mhlo, static_mhlo)

Here is a link to my jax discussion.

So I just went back and reviewed my previous example, but with my newer comparison code. Annoyingly, I got a contradiction with my earlier findings: I found that the greatest speed benefits were to be had at lower resolutions. This makes a little more sense to me, as most of the optimisations will not be scalable. However, I did not use exactly the same code. I have found the original code and plan to rerun it to see what my results are. I will post the results when I get them, but it isn't a priority for me.

Yeah this is a low-priority issue, can let it burn in the background

Done in Zodiax experimental.