lanl/scico

Linear operator multiplication with singleton array not supported


This issue distills the bug component of #442, which also addresses some broader questions.

In essence, the problem is that:

  • some scico functions, such as linop.operator_norm, return singleton arrays rather than scalar floats (this is ubiquitous because jax.numpy functions return arrays even for scalar input, e.g. jax.numpy.sqrt(4.0) returns an array, not a float)
  • in certain contexts, such a value is multiplied by a linear operator within internal calculations, e.g. when it is passed as the lam parameter of loss.SquaredL2Loss.prox, and linear operators do not support multiplication by a singleton array.
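The failure mode can be reproduced without scico. The sketch below uses a hypothetical toy operator, `StrictScaledIdentity` (not scico's `LinearOperator`), whose multiplication accepts only scalar numbers; a NumPy 0-d array stands in for the singleton array returned by `jax.numpy` functions.

```python
import numbers

import numpy as np


class StrictScaledIdentity:
    """Hypothetical toy operator x -> c * x (not scico's implementation).

    Mimics the reported behaviour: only scalar numbers are accepted as
    multipliers, so even a one-element array is rejected.
    """

    def __init__(self, c):
        self.c = c

    def __call__(self, x):
        return self.c * np.asarray(x)

    def __mul__(self, other):
        if not isinstance(other, numbers.Number):
            raise TypeError(f"cannot multiply operator by {type(other).__name__}")
        return StrictScaledIdentity(self.c * other)


A = StrictScaledIdentity(3.0)
lam = np.array(0.5)  # singleton (0-d) array, standing in for e.g. jax.numpy.sqrt output

print((A * 0.5).c)  # 1.5: a Python float is accepted
try:
    A * lam  # the same value wrapped in a singleton array is rejected
except TypeError as err:
    print("TypeError:", err)  # TypeError: cannot multiply operator by ndarray
```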

A quick fix would be to cast lam to a float within loss.SquaredL2Loss.prox, but this does not address the broader problem: the same issue is equally likely to arise in other contexts.
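The quick fix amounts to one cast at the top of the prox. The sketch below assumes a hypothetical `squared_l2_prox` for f(x) = (1/2) ||y - A x||^2 with a scaled-identity operator A = cI, for which the prox of lam * f has the closed form x = (v + lam c y) / (1 + lam c^2); it is not scico's `SquaredL2Loss.prox`. The final `try` block shows why the fix is only local: other code that multiplies the operator by the same singleton array still fails.

```python
import numbers

import numpy as np


class StrictScaledIdentity:
    """Hypothetical toy operator x -> c * x accepting only scalar multipliers."""

    def __init__(self, c):
        self.c = c

    def __mul__(self, other):
        if not isinstance(other, numbers.Number):
            raise TypeError(f"cannot multiply operator by {type(other).__name__}")
        return StrictScaledIdentity(self.c * other)


def squared_l2_prox(v, lam, y, A):
    """Hypothetical prox of lam * f, with f(x) = 0.5 * ||y - A x||^2 and A = c I.

    Closed form: x = (v + lam * c * y) / (1 + lam * c**2).
    """
    lam = float(lam)  # quick fix: a 0-d array becomes a Python float
    lamA = A * lam  # internal operator-scalar product that fails without the cast
    return (np.asarray(v) + lamA.c * np.asarray(y)) / (1.0 + lamA.c * A.c)


A = StrictScaledIdentity(2.0)
lam = np.array(0.5)  # singleton array, e.g. as produced by a jax.numpy computation
x = squared_l2_prox(np.array([1.0, 2.0]), lam, np.array([3.0, 4.0]), A)  # x is [4/3, 2.0]

try:
    A * lam  # elsewhere, without the cast, the same value still fails
except TypeError as err:
    print("TypeError:", err)
```

In JAX code there is a further limitation: `float()` cannot be applied to a traced value inside `jax.jit`, so a cast like this only works outside compiled functions.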

A more general solution would be to add support for multiplication by singleton arrays to linear operators. However, this additional flexibility is not reflected in the current type annotations, and properly accounting for it is likely to require significant changes.
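A sketch of the more general option, again with a hypothetical toy operator rather than scico's `LinearOperator`: multiplication accepts any scalar number or any array with exactly one element, and a hypothetical `ScalarLike` alias indicates the kind of annotation change involved. The check is duck-typed (`size` and `reshape`) so that it would also accept a `jax.Array`, but only NumPy is exercised here. Reshaping to a 0-d array, rather than calling `float()`, keeps the value an array, which matters under `jax.jit` where conversion to a concrete Python float is not possible.

```python
from __future__ import annotations

import numbers
from typing import Union

import numpy as np

# Hypothetical alias; a JAX codebase would also list jax.Array here.
ScalarLike = Union[numbers.Number, np.ndarray]


def is_scalar_like(x) -> bool:
    """True for scalar numbers and for array-likes with exactly one element."""
    if isinstance(x, numbers.Number):
        return True
    return getattr(x, "size", None) == 1 and hasattr(x, "reshape")


class ScaledIdentity:
    """Hypothetical toy operator x -> c * x that also accepts singleton arrays."""

    # Tell NumPy to defer, so that `array * op` calls op.__rmul__.
    __array_ufunc__ = None

    def __init__(self, c):
        self.c = c

    def __call__(self, x):
        return self.c * np.asarray(x)

    def __mul__(self, other: ScalarLike) -> ScaledIdentity:
        if not is_scalar_like(other):
            return NotImplemented
        if not isinstance(other, numbers.Number):
            other = other.reshape(())  # shape (1,), (1, 1), ... -> 0-d array
        return ScaledIdentity(self.c * other)

    __rmul__ = __mul__


A = ScaledIdentity(3.0)
for lam in (0.5, np.array(0.5), np.array([0.5])):
    # Every form of the same scalar now yields the same operator.
    print((A * lam)(np.array([1.0, 2.0])), (lam * A).c)
```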

An alternative approach would be to review the functions nominally expected to return scalars, explicitly casting to float where necessary. Part of this would be adopting a coherent policy on when to use numpy functions instead of jax.numpy functions, since the latter always return an array (a singleton array when the result is a scalar). Again, this would require comprehensive changes.
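For reference, NumPy ufuncs applied to a Python float or a 0-d array return a NumPy scalar such as `np.float64`, which is a subclass of Python `float`; by contrast, as noted above, `jax.numpy.sqrt(4.0)` returns an array. The sketch below checks the NumPy side and shows the explicit-cast pattern on a hypothetical scalar-returning function, `spectral_norm`, which is only an illustration and not scico's `linop.operator_norm`.

```python
import numpy as np

r = np.sqrt(4.0)
print(type(r).__name__)  # float64
print(isinstance(r, float))  # True: np.float64 subclasses Python float
print(isinstance(np.sqrt(np.array(4.0)), float))  # True: ufuncs return scalars for 0-d input
print(isinstance(np.array(2.0), float))  # False: an explicit 0-d array is not a float


def spectral_norm(A) -> float:
    """Hypothetical scalar-returning function that casts its result explicitly.

    np.linalg.norm already returns np.float64 here; the float() cast is what
    would guarantee a plain float if the value were computed with jax.numpy.
    """
    return float(np.linalg.norm(np.asarray(A), 2))


print(spectral_norm(np.diag([3.0, 4.0])))  # largest singular value of diag(3, 4)
```

The cast makes the return type predictable for callers, but applied to a JAX array it forces a transfer to the host and cannot be used inside `jax.jit`, which is part of why a coherent policy, rather than ad hoc casts, would be needed.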

After discussion: the correct approach is for singleton arrays to be supported as input wherever floats are accepted.