xarray-contrib/xbatcher

Friendlier API for multiple inputs/outputs

cmdupuis3 opened this issue · 2 comments

Currently, there doesn't seem to be a nice way to feed a batch to a model with multiple inputs/outputs. This list comprehension works, but I think there could be a more elegant solution.

    model.fit([batch[x] for x in (sc.stencil_2D + sc.stencil_3D + sc.vanilla)],
              [batch[x] for x in sc.target])

Ideally, indexing a batch with a list of variable names would return a list back, so I could write something like

    model.fit(batch[sc.stencil_2D + sc.stencil_3D + sc.vanilla],
              batch[sc.target])

without the list comprehension boilerplate.
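For context, the comprehension is needed in the first place because indexing an xarray Dataset with a list of names returns another Dataset, while Keras-style multi-input models expect a plain list of arrays. A minimal, self-contained illustration (the variable names here are made up):

```python
import numpy as np
import xarray as xr

# Toy batch with two input variables and one target (names are arbitrary).
batch = xr.Dataset(
    {
        "a": (("nlat", "nlon"), np.zeros((4, 4))),
        "b": (("nlat", "nlon"), np.ones((4, 4))),
        "t": (("nlat", "nlon"), np.full((4, 4), 2.0)),
    }
)

# Indexing with a list of names yields a Dataset, not a list of arrays...
subset = batch[["a", "b"]]  # an xarray.Dataset

# ...so feeding a multi-input model still requires an explicit conversion.
inputs = [subset[name].values for name in ["a", "b"]]
```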

@cmdupuis3 - I know this has sat for a while, but do you think you could expand a bit on your use case? Even better, could you provide a simple demo that shows how you are creating your batches and what data shapes you expect to pass to your model? For example, it is difficult to tell what sc is supposed to be in your example.

@jhamman Ah, yeah, that is a little unclear. sc is just a struct holding lists of variable names. Maybe instead, you can think of it like this:

    model.fit([batch[x] for x in (var_list1 + var_list2 + var_list3)],
              [batch[x] for x in var_list4])

(changed to)

    model.fit(batch[var_list1 + var_list2 + var_list3],
              batch[var_list4])
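One possible shape for the friendlier API is a thin wrapper whose `__getitem__` returns a list of DataArrays when given a list of names. This is purely a hypothetical sketch of the requested behaviour, not anything xbatcher currently provides:

```python
import numpy as np
import xarray as xr

class ListIndexBatch:
    """Hypothetical wrapper: batch[list_of_names] -> list of DataArrays."""

    def __init__(self, batch):
        self._batch = batch  # the underlying xarray.Dataset

    def __getitem__(self, key):
        if isinstance(key, (list, tuple)):
            # Return a plain list, ready to hand to model.fit(...)
            return [self._batch[name] for name in key]
        return self._batch[key]

# Usage with made-up variable names:
ds = xr.Dataset({v: (("nlat", "nlon"), np.zeros((2, 2))) for v in "abcd"})
b = ListIndexBatch(ds)
inputs = b[["a", "b", "c"]]   # list of three DataArrays
target = b["d"]               # single key: unchanged behaviour
```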

So in my case, I have something like this (note that I need to use the squeeze_batch_dim option I added in #39 ):

    bgen = xb.BatchGenerator(
        ds,
        input_dims={'nlon': nlons, 'nlat': nlats},
        input_overlap={'nlon': halo_size, 'nlat': halo_size},
        squeeze_batch_dim=False,
    )

    for batch in bgen:
        # trim the halo from the target region
        sub = {'nlon': range(halo_size, nlons - halo_size),
               'nlat': range(halo_size, nlats - halo_size)}
        ...
        batch_stencil_2D = [batch[x.name] for x in sc.variable]
        batch_target     = [batch[x.name][sub] for x in sc.target]
        ...
        model.compile(loss='mae', optimizer='Adam', metrics=['mae', 'mse', 'accuracy'])
        model.fit(batch_stencil_2D, batch_target, batch_size=32, epochs=1, verbose=0)
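The sub dict in the loop above trims the halo from the targets before fitting; indexing a DataArray with a {dim: indexer} mapping behaves like .isel. A small stand-alone check of that trimming step (the sizes here are made up):

```python
import numpy as np
import xarray as xr

halo_size = 1
nlats = nlons = 4
da = xr.DataArray(np.arange(16.0).reshape(4, 4), dims=("nlat", "nlon"))

# Same construction as `sub` in the snippet: drop halo_size cells per edge.
sub = {'nlon': range(halo_size, nlons - halo_size),
       'nlat': range(halo_size, nlats - halo_size)}

trimmed = da[sub]
print(trimmed.shape)  # (2, 2)
```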