flaport/sax

`IndexError`


Hi Floris, we have a notebook demonstrating sax. It errors at cell [18] for some users but not for others. The dependencies differ only slightly between environments, and we can't figure out what causes this discrepancy.

Do you have any suggestions for things to look into here? We're stumped after testing several different dependencies.

This is the stack trace:

```
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[17], line 2
      1 # how to pass specific parmaeters to each of the sub-functions for the instances
----> 2 s = circuit_fn(splitter={"params": params0}, combiner={"params": 0 * params0}, beta=3, phase_sifter=dict(phi=2.0))

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/saxtypes.py:267, in sdict.<locals>.wrapper(**kwargs)
    265 @functools.wraps(model)
    266 def wrapper(**kwargs):
--> 267     return sdict(model(**kwargs))

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/circuit.py:226, in _flat_circuit.<locals>._circuit(**settings)
    223 for inst_name, model in inst2model.items():
    224     instances[inst_name] = model(**full_settings.get(inst_name, {}))
--> 226 S = evaluate_fn(analyzed, instances)
    227 return S

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/backends/klu.py:122, in evaluate_circuit_klu(analyzed, instances)
    117     idx += len(ports_map)
    119 Sx = jnp.concatenate(
    120     [jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1
    121 )
--> 122 CSx = Sx[..., mask]
    123 Ix = jnp.ones((*batch_shape, n_col))
    124 I_CSx = jnp.concatenate([-CSx, Ix], -1)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/array.py:319, in ArrayImpl.__getitem__(self, idx)
    317   return lax_numpy._rewriting_take(self, idx)
    318 else:
--> 319   return lax_numpy._rewriting_take(self, idx)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4152, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4146     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4147         dtypes.issubdtype(aval.dtype, np.integer) and
   4148         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4149         isinstance(arr.shape[0], int)):
   4150       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4153 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4154                unique_indices, mode, fill_value)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4230, in _split_index_for_jit(idx, shape)
   4226   raise TypeError(f"JAX does not support string indexing; got {idx=}")
   4228 # Expand any (concrete) boolean indices. We can then use advanced integer
   4229 # indexing logic to handle them.
-> 4230 idx = _expand_bool_indices(idx, shape)
   4232 leaves, treedef = tree_flatten(idx)
   4233 dynamic = [None] * len(leaves)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4552, in _expand_bool_indices(idx, shape)
   4550     expected_shape = shape[start: start + _ndim(i)]
   4551     if i_shape != expected_shape:
-> 4552       raise IndexError("boolean index did not match shape of indexed array in index "
   4553                        f"{dim_number}: got {i_shape}, expected {expected_shape}")
   4554     out.extend(np.where(i))
   4555 else:

IndexError: boolean index did not match shape of indexed array in index 1: got (18,), expected (14,)
```
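The last frame is ordinary boolean-mask indexing: the mask's length must equal the size of the axis it indexes. A minimal NumPy reproduction follows, assuming (as the sizes in the message suggest) that `mask` in `klu.py` was built for S-matrices with 18 entries in total, while the instances evaluated at call time only produced 14. `Sx` and `mask` here are stand-ins, not the real SAX arrays.

```python
import numpy as np

# Stand-in for the concatenated S-matrix values `Sx`: batch of 1, 14 entries,
# i.e. what the instances actually produced at call time.
Sx = np.ones((1, 14))

# Stand-in for `mask`: built earlier for a layout with 18 entries.
mask = np.zeros(18, dtype=bool)
mask[:5] = True

try:
    CSx = Sx[..., mask]  # same indexing expression as klu.py line 122
except IndexError as err:
    print(f"IndexError: {err}")
```

NumPy's wording differs from JAX's, but both reject a mask whose length does not match the indexed axis.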

And here are the installed packages for the erroring case (Python 3.11 on Ubuntu):

Package                       Version         Editable project location
----------------------------- --------------- ---------------------------------------------------------------
absl-py                       2.1.0
accessible-pygments           0.0.4
alabaster                     0.7.16
annotated-types               0.6.0
anyio                         4.2.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
astroid                       3.0.3
asttokens                     2.4.1
async-lru                     2.0.4
attrs                         23.2.0
autograd                      1.6.2
Babel                         2.14.0
beautifulsoup4                4.12.3
black                         23.12.1
bleach                        6.1.0
boto3                         1.23.1
botocore                      1.26.10
cached-property               1.5.2
cachetools                    5.3.2
certifi                       2024.2.2
cffi                          1.16.0
cfgv                          3.4.0
chardet                       5.2.0
charset-normalizer            3.3.2
chex                          0.1.82
click                         8.0.3
cloudpickle                   3.0.0
colorama                      0.4.6
comm                          0.2.1
commonmark                    0.9.1
contourpy                     1.2.0
coverage                      7.4.1
cycler                        0.12.1
dask                          2023.10.1
dataclasses-json              0.6.4
debugpy                       1.8.0
decorator                     5.1.1
defusedxml                    0.7.1
dill                          0.3.8
distlib                       0.3.8
docutils                      0.20.1
entrypoints                   0.4
etils                         1.6.0
exceptiongroup                1.2.0
executing                     2.0.1
fastcore                      1.5.29
fastjsonschema                2.19.1
filelock                      3.13.1
flax                          0.7.4
flow360scripts                0.0.1
fonttools                     4.47.2
fqdn                          1.5.1
fsspec                        2024.2.0
future                        0.18.3
gdspy                         1.6.13
gdstk                         0.9.50
gitdb                         4.0.11
GitPython                     3.1.41
gltflib                       1.0.13
grcwa                         0.1.2
h11                           0.14.0
h2                            4.1.0
h5netcdf                      1.0.2
h5py                          3.10.0
hpack                         4.0.0
httpcore                      1.0.3
httpx                         0.26.0
hyperframe                    6.0.1
identify                      2.5.33
idna                          3.6
imagesize                     1.4.1
importlib-metadata            6.11.0
importlib-resources           6.1.1
iniconfig                     2.0.0
ipykernel                     6.28.0
ipython                       8.21.0
ipywidgets                    8.1.1
isoduration                   20.11.0
isort                         5.13.2
jax                           0.4.14
jaxlib                        0.4.14
jaxtyping                     0.2.25
jedi                          0.19.1
Jinja2                        3.1.3
jmespath                      1.0.1
json5                         0.9.14
jsonpointer                   2.4
jsonschema                    4.21.1
jsonschema-specifications     2023.12.1
jupyter                       1.0.0
jupyter_client                8.6.0
jupyter-console               6.6.3
jupyter_core                  5.7.1
jupyter-events                0.9.0
jupyter-lsp                   2.2.2
jupyter_server                2.12.5
jupyter-server-mathjax        0.2.6
jupyter_server_terminals      0.5.2
jupyterlab                    4.0.12
jupyterlab_pygments           0.3.0
jupyterlab_server             2.25.2
jupyterlab-widgets            3.0.9
kiwisolver                    1.4.5
klujax                        0.2.4
locket                        1.0.0
markdown-it-py                3.0.0
MarkupSafe                    2.1.3
marshmallow                   3.20.2
matplotlib                    3.8.2
matplotlib-inline             0.1.6
mccabe                        0.7.0
mdit-py-plugins               0.4.0
mdurl                         0.1.2
memory-profiler               0.61.0
mistune                       3.0.2
ml-dtypes                     0.3.2
mpmath                        1.3.0
msgpack                       1.0.7
multiprocess                  0.70.16
mypy-extensions               1.0.0
myst-parser                   2.0.0
natsort                       8.4.0
nbclient                      0.8.0
nbconvert                     7.16.1
nbdime                        4.0.1
nbformat                      5.9.2
nbsphinx                      0.9.3
nest_asyncio                  1.6.0
networkx                      2.8.8
nodeenv                       1.8.0
notebook                      7.0.8
notebook_shim                 0.2.3
numpy                         1.26.3
opt-einsum                    3.3.0
optax                         0.1.9
orbax-checkpoint              0.5.3
orjson                        3.9.13
overrides                     7.7.0
packaging                     23.2
pandas                        2.2.0
pandocfilters                 1.5.0
parso                         0.8.3
partd                         1.4.1
pathspec                      0.12.1
pexpect                       4.9.0
pillow                        10.2.0
pip                           23.3.1
pkgutil_resolve_name          1.3.10
platformdirs                  4.2.0
pluggy                        1.4.0
ply                           3.11
pre-commit                    3.6.0
prometheus_client             0.20.0
prompt-toolkit                3.0.43
protobuf                      4.25.2
psutil                        5.9.0
ptyprocess                    0.7.0
pure-eval                     0.2.2
pybind11                      2.11.1
pycparser                     2.21
pydantic                      2.6.3
pydantic_core                 2.16.3
pydata-sphinx-theme           0.15.2
Pygments                      2.17.2
PyJWT                         2.8.0
pylint                        3.0.3
PyMieScatt                    1.8.1.1
pyparsing                     3.1.1
pyproject-api                 1.6.1
pyroots                       0.5.0
pyrsistent                    0.20.0
pyswarms                      1.3.0
pytest                        8.0.0
pytest-timeout                2.2.0
python-dateutil               2.8.2
python-json-logger            2.0.7
pytz                          2024.1
PyYAML                        6.0.1
pyzmq                         25.1.2
qtconsole                     5.5.1
QtPy                          2.4.1
referencing                   0.33.0
requests                      2.28.2
responses                     0.24.1
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          12.5.1
rpds-py                       0.17.1
Rtree                         1.0.1
ruff                          0.2.1
s3transfer                    0.5.2
sax                           0.12.1
scipy                         1.12.0
Send2Trash                    1.8.2
setuptools                    68.2.2
shapely                       2.0.2
signac                        2.1.0
six                           1.16.0
smmap                         5.0.1
sniffio                       1.3.0
snowballstemmer               2.2.0
soupsieve                     2.5
Sphinx                        7.2.6
sphinx-book-theme             1.1.0
sphinx-copybutton             0.5.2
sphinx-sitemap                2.5.1
sphinx-tabs                   3.4.5
sphinxcontrib-applehelp       1.0.8
sphinxcontrib-devhelp         1.0.6
sphinxcontrib-htmlhelp        2.0.5
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.7
sphinxcontrib-serializinghtml 1.1.10
sphinxemoji                   0.3.1
stack-data                    0.6.2
sympy                         1.12
synced-collections            1.0.0
tensorstore                   0.1.53
terminado                     0.18.0
tidy3d                        2.6.0rc1        /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_frontend
tidy3d-beta                   1.9.0
tidy3d-denormalizer           0.1.0           /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d-denormalizer
tidy3d_pipeline               2.6.0rc1        /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_pipeline
tinycss2                      1.2.1
tmm                           0.1.8
toml                          0.10.2
tomli                         2.0.1
tomlkit                       0.12.3
toolz                         0.12.1
tornado                       6.4
tox                           4.12.1
tqdm                          4.66.1
traitlets                     5.14.1
trimesh                       3.20.0
typeguard                     2.13.3
types-python-dateutil         2.8.19.20240106
typing_extensions             4.9.0
typing-inspect                0.9.0
typing-utils                  0.1.0
tzdata                        2023.4
uri-template                  1.3.0
urllib3                       1.26.18
virtualenv                    20.25.0
vtk                           9.2.6
wcwidth                       0.2.13
webcolors                     1.13
webencodings                  0.5.1
websocket-client              1.7.0
wheel                         0.41.2
widgetsnbextension            4.0.9
xarray                        2023.12.0
zipp                          3.17.0

@momchil-flex

I was able to reproduce the issue. The good news is that I've seen the error before. The bad news is that I haven't been able to fully understand it yet. I'll keep looking.

Ok, I figured out the issue.

One of SAX's optimizations assumes that the shape of the S-matrix returned by a model function (i.e., its set of ports) never changes. Your `component` function, however, takes a `shape` argument, so the shape of its S-matrix can vary from call to call. Unfortunately, wrapping it in a `functools.partial` won't get around this, because SAX does some additional introspection of the models when it constructs the circuit.
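To make the failure mode concrete, here is a toy illustration of the assumption described above. It is not SAX's actual code: `component`, `CACHED_LAYOUT` and `evaluate` are hypothetical names, and S-matrices are modelled as plain dicts keyed by `(port, port)` pairs. The layout is recorded once (in this toy, from a call with default arguments) and reused for every later evaluation, much like a precomputed mask.

```python
from typing import Dict, List, Tuple

# Toy S-matrix: maps (port_a, port_b) -> complex amplitude.
SDict = Dict[Tuple[str, str], complex]


def component(beta: float = 5.0, shape: Tuple[int, int] = (2, 2)) -> SDict:
    """Hypothetical model whose port set depends on `shape` (beta is unused)."""
    n_in, n_out = shape
    amp = complex((1 / n_out) ** 0.5)
    S: SDict = {}
    for i in range(n_in):
        for j in range(n_out):
            # Reciprocal transmission between every input and every output.
            S[(f"in{i}", f"out{j}")] = amp
            S[(f"out{j}", f"in{i}")] = amp
    return S


# The "optimization": record the S-matrix layout once and reuse it afterwards.
CACHED_LAYOUT: List[Tuple[str, str]] = sorted(component())


def evaluate(**settings) -> List[complex]:
    """Flatten the S-matrix using the cached layout; fail if the layout changed."""
    S = component(**settings)
    if sorted(S) != CACHED_LAYOUT:
        raise IndexError(
            f"S-matrix layout changed: expected {len(CACHED_LAYOUT)} entries, got {len(S)}"
        )
    return [S[key] for key in CACHED_LAYOUT]
```

Calling `evaluate(beta=3.0)` works, while `evaluate(shape=(1, 2))` raises, because the 1x2 variant has 4 entries instead of the 8 the layout was recorded with.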

An easy workaround is to write each fixed-shape variant as a separate, named function instead of a partial whenever the output shape depends on the input parameters. Something like this:

```python
import sax

# `component`, `phase_shifter` and `params0` are defined earlier in the notebook.


# Fixed-shape wrappers: each model now always returns an S-matrix with the same ports.
def component1x2(params=params0, beta=5):
    return component(params=params, beta=beta, shape=(1, 2))


def component2x2(params=params0, beta=5):
    return component(params=params, beta=beta, shape=(2, 2))


circuit_fn, _ = sax.circuit(
    netlist={
        "instances": {
            "splitter": component1x2,
            "phase_shifter": phase_shifter,
            "combiner": component2x2,
        },
        "connections": {
            "splitter,out0": "phase_shifter,in",
            "phase_shifter,out": "combiner,in0",
            "splitter,out1": "combiner,in1",
        },
        "ports": {
            "in": "splitter,in0",
            "out0": "combiner,out0",
            "out1": "combiner,out1",
        },
    }
)

circuit_fn
```

This solves the issue.
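One way to spot such models before building a circuit is to call each one with a few different settings and compare the ports it exposes. The sketch below assumes S-matrices in dict form, keyed by `(port, port)` pairs; `ports`, `has_fixed_ports`, `coupler` and `coupler1x2` are hypothetical helpers for illustration, not part of SAX.

```python
from typing import Callable, Dict, Iterable, Set, Tuple

SDict = Dict[Tuple[str, str], complex]


def ports(S: SDict) -> Set[str]:
    """All port names that appear in an S-matrix dict."""
    return {port for pair in S for port in pair}


def has_fixed_ports(model: Callable[..., SDict], settings: Iterable[dict]) -> bool:
    """True if `model` exposes the same ports for every settings dict tried."""
    port_sets = [ports(model(**s)) for s in settings]
    return all(ps == port_sets[0] for ps in port_sets)


def coupler(beta: float = 5.0, shape: Tuple[int, int] = (2, 2)) -> SDict:
    """Hypothetical stand-in for the notebook's `component` (beta is unused here)."""
    n_in, n_out = shape
    return {(f"in{i}", f"out{j}"): 1.0 + 0j for i in range(n_in) for j in range(n_out)}


def coupler1x2(beta: float = 5.0) -> SDict:
    # Shape fixed inside the function, following the workaround above.
    return coupler(beta=beta, shape=(1, 2))


print(has_fixed_ports(coupler, [{}, {"shape": (1, 2)}]))  # False: ports change
print(has_fixed_ports(coupler1x2, [{}, {"beta": 3.0}]))  # True: ports are fixed
```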

I will work on better error messages to handle this case in a future release.

Thanks! Really appreciate you looking into that. Out of curiosity, why do you think it works in some environments and not others, even when the sax and jax versions are the same?

Not really sure... I seem to have the problem in all my sax environments... Maybe at some point you were optimizing a 2x2x2x2 instead of a 1x2x2x2?

One thing I just realized is that it might also depend on whether klujax is installed. When klujax is installed, SAX defaults to the KLU backend (`backend='klu'` in `sax.circuit`), which is significantly faster, at least for large circuits, instead of the alternative algorithm by Gunnar Filipsson (`backend='fg'` in `sax.circuit`). The latter backend has less strict requirements on the shapes of the models, but SAX's implementation of it is generally slower (although I doubt that matters for the circuit in your notebook, which is very small).
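To see which of the two backends an environment would fall into under the rule described above, it is enough to check whether `klujax` is importable. `default_backend` is a hypothetical helper that only mirrors that rule; it does not call SAX, and SAX's own selection logic may differ in details.

```python
import importlib.util


def default_backend() -> str:
    """Hypothetical helper: 'klu' if klujax can be imported, otherwise 'fg'."""
    # find_spec returns None when the package is not installed in this environment.
    return "klu" if importlib.util.find_spec("klujax") is not None else "fg"


print(default_backend())
```

If two environments report different results, passing the backend explicitly to `sax.circuit` (e.g. `backend='fg'`) makes the comparison like-for-like.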

Interesting. I'm pretty sure klujax was installed in both environments with the same version, but I'm not 100% positive.
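A quick way to settle this is to run the same snippet in each environment (inside the notebook kernel, so it reflects the environment that actually runs the code) and compare the output. It uses only `importlib.metadata` from the standard library; the package list is just a suggestion.

```python
from importlib.metadata import PackageNotFoundError, version


def report(packages=("sax", "jax", "jaxlib", "klujax", "numpy")) -> dict:
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in report().items():
    print(f"{name:8s} {ver or 'not installed'}")
```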