`IndexError`
Closed this issue · 6 comments
Hi Floris, we have a notebook demonstrating SAX. It errors at cell [18] for some users but not for others. The dependencies differ only slightly between environments, and we can't figure out what is causing the discrepancy.
Do you have any suggestions for things to look into here? We're stumped after testing several different dependency versions.
Here is the stack trace:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[17], line 2
1 # how to pass specific parmaeters to each of the sub-functions for the instances
----> 2 s = circuit_fn(splitter={"params": params0}, combiner={"params": 0 * params0}, beta=3, phase_sifter=dict(phi=2.0))
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/saxtypes.py:267, in sdict.<locals>.wrapper(**kwargs)
265 @functools.wraps(model)
266 def wrapper(**kwargs):
--> 267 return sdict(model(**kwargs))
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/circuit.py:226, in _flat_circuit.<locals>._circuit(**settings)
223 for inst_name, model in inst2model.items():
224 instances[inst_name] = model(**full_settings.get(inst_name, {}))
--> 226 S = evaluate_fn(analyzed, instances)
227 return S
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/backends/klu.py:122, in evaluate_circuit_klu(analyzed, instances)
117 idx += len(ports_map)
119 Sx = jnp.concatenate(
120 [jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1
121 )
--> 122 CSx = Sx[..., mask]
123 Ix = jnp.ones((*batch_shape, n_col))
124 I_CSx = jnp.concatenate([-CSx, Ix], -1)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/array.py:319, in ArrayImpl.__getitem__(self, idx)
317 return lax_numpy._rewriting_take(self, idx)
318 else:
--> 319 return lax_numpy._rewriting_take(self, idx)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4152, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
4146 if (isinstance(aval, core.DShapedArray) and aval.shape == () and
4147 dtypes.issubdtype(aval.dtype, np.integer) and
4148 not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
4149 isinstance(arr.shape[0], int)):
4150 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
4153 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
4154 unique_indices, mode, fill_value)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4230, in _split_index_for_jit(idx, shape)
4226 raise TypeError(f"JAX does not support string indexing; got {idx=}")
4228 # Expand any (concrete) boolean indices. We can then use advanced integer
4229 # indexing logic to handle them.
-> 4230 idx = _expand_bool_indices(idx, shape)
4232 leaves, treedef = tree_flatten(idx)
4233 dynamic = [None] * len(leaves)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4552, in _expand_bool_indices(idx, shape)
4550 expected_shape = shape[start: start + _ndim(i)]
4551 if i_shape != expected_shape:
-> 4552 raise IndexError("boolean index did not match shape of indexed array in index "
4553 f"{dim_number}: got {i_shape}, expected {expected_shape}")
4554 out.extend(np.where(i))
4555 else:
IndexError: boolean index did not match shape of indexed array in index 1: got (18,), expected (14,)
And here is the `pip freeze` output for the failing environment (Python 3.11 on Ubuntu):
Package Version Editable project location
----------------------------- --------------- ---------------------------------------------------------------
absl-py 2.1.0
accessible-pygments 0.0.4
alabaster 0.7.16
annotated-types 0.6.0
anyio 4.2.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
astroid 3.0.3
asttokens 2.4.1
async-lru 2.0.4
attrs 23.2.0
autograd 1.6.2
Babel 2.14.0
beautifulsoup4 4.12.3
black 23.12.1
bleach 6.1.0
boto3 1.23.1
botocore 1.26.10
cached-property 1.5.2
cachetools 5.3.2
certifi 2024.2.2
cffi 1.16.0
cfgv 3.4.0
chardet 5.2.0
charset-normalizer 3.3.2
chex 0.1.82
click 8.0.3
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
commonmark 0.9.1
contourpy 1.2.0
coverage 7.4.1
cycler 0.12.1
dask 2023.10.1
dataclasses-json 0.6.4
debugpy 1.8.0
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.8
distlib 0.3.8
docutils 0.20.1
entrypoints 0.4
etils 1.6.0
exceptiongroup 1.2.0
executing 2.0.1
fastcore 1.5.29
fastjsonschema 2.19.1
filelock 3.13.1
flax 0.7.4
flow360scripts 0.0.1
fonttools 4.47.2
fqdn 1.5.1
fsspec 2024.2.0
future 0.18.3
gdspy 1.6.13
gdstk 0.9.50
gitdb 4.0.11
GitPython 3.1.41
gltflib 1.0.13
grcwa 0.1.2
h11 0.14.0
h2 4.1.0
h5netcdf 1.0.2
h5py 3.10.0
hpack 4.0.0
httpcore 1.0.3
httpx 0.26.0
hyperframe 6.0.1
identify 2.5.33
idna 3.6
imagesize 1.4.1
importlib-metadata 6.11.0
importlib-resources 6.1.1
iniconfig 2.0.0
ipykernel 6.28.0
ipython 8.21.0
ipywidgets 8.1.1
isoduration 20.11.0
isort 5.13.2
jax 0.4.14
jaxlib 0.4.14
jaxtyping 0.2.25
jedi 0.19.1
Jinja2 3.1.3
jmespath 1.0.1
json5 0.9.14
jsonpointer 2.4
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter_client 8.6.0
jupyter-console 6.6.3
jupyter_core 5.7.1
jupyter-events 0.9.0
jupyter-lsp 2.2.2
jupyter_server 2.12.5
jupyter-server-mathjax 0.2.6
jupyter_server_terminals 0.5.2
jupyterlab 4.0.12
jupyterlab_pygments 0.3.0
jupyterlab_server 2.25.2
jupyterlab-widgets 3.0.9
kiwisolver 1.4.5
klujax 0.2.4
locket 1.0.0
markdown-it-py 3.0.0
MarkupSafe 2.1.3
marshmallow 3.20.2
matplotlib 3.8.2
matplotlib-inline 0.1.6
mccabe 0.7.0
mdit-py-plugins 0.4.0
mdurl 0.1.2
memory-profiler 0.61.0
mistune 3.0.2
ml-dtypes 0.3.2
mpmath 1.3.0
msgpack 1.0.7
multiprocess 0.70.16
mypy-extensions 1.0.0
myst-parser 2.0.0
natsort 8.4.0
nbclient 0.8.0
nbconvert 7.16.1
nbdime 4.0.1
nbformat 5.9.2
nbsphinx 0.9.3
nest_asyncio 1.6.0
networkx 2.8.8
nodeenv 1.8.0
notebook 7.0.8
notebook_shim 0.2.3
numpy 1.26.3
opt-einsum 3.3.0
optax 0.1.9
orbax-checkpoint 0.5.3
orjson 3.9.13
overrides 7.7.0
packaging 23.2
pandas 2.2.0
pandocfilters 1.5.0
parso 0.8.3
partd 1.4.1
pathspec 0.12.1
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
pkgutil_resolve_name 1.3.10
platformdirs 4.2.0
pluggy 1.4.0
ply 3.11
pre-commit 3.6.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 4.25.2
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
pybind11 2.11.1
pycparser 2.21
pydantic 2.6.3
pydantic_core 2.16.3
pydata-sphinx-theme 0.15.2
Pygments 2.17.2
PyJWT 2.8.0
pylint 3.0.3
PyMieScatt 1.8.1.1
pyparsing 3.1.1
pyproject-api 1.6.1
pyroots 0.5.0
pyrsistent 0.20.0
pyswarms 1.3.0
pytest 8.0.0
pytest-timeout 2.2.0
python-dateutil 2.8.2
python-json-logger 2.0.7
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
qtconsole 5.5.1
QtPy 2.4.1
referencing 0.33.0
requests 2.28.2
responses 0.24.1
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 12.5.1
rpds-py 0.17.1
Rtree 1.0.1
ruff 0.2.1
s3transfer 0.5.2
sax 0.12.1
scipy 1.12.0
Send2Trash 1.8.2
setuptools 68.2.2
shapely 2.0.2
signac 2.1.0
six 1.16.0
smmap 5.0.1
sniffio 1.3.0
snowballstemmer 2.2.0
soupsieve 2.5
Sphinx 7.2.6
sphinx-book-theme 1.1.0
sphinx-copybutton 0.5.2
sphinx-sitemap 2.5.1
sphinx-tabs 3.4.5
sphinxcontrib-applehelp 1.0.8
sphinxcontrib-devhelp 1.0.6
sphinxcontrib-htmlhelp 2.0.5
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.7
sphinxcontrib-serializinghtml 1.1.10
sphinxemoji 0.3.1
stack-data 0.6.2
sympy 1.12
synced-collections 1.0.0
tensorstore 0.1.53
terminado 0.18.0
tidy3d 2.6.0rc1 /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_frontend
tidy3d-beta 1.9.0
tidy3d-denormalizer 0.1.0 /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d-denormalizer
tidy3d_pipeline 2.6.0rc1 /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_pipeline
tinycss2 1.2.1
tmm 0.1.8
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.3
toolz 0.12.1
tornado 6.4
tox 4.12.1
tqdm 4.66.1
traitlets 5.14.1
trimesh 3.20.0
typeguard 2.13.3
types-python-dateutil 2.8.19.20240106
typing_extensions 4.9.0
typing-inspect 0.9.0
typing-utils 0.1.0
tzdata 2023.4
uri-template 1.3.0
urllib3 1.26.18
virtualenv 20.25.0
vtk 9.2.6
wcwidth 0.2.13
webcolors 1.13
webencodings 0.5.1
websocket-client 1.7.0
wheel 0.41.2
widgetsnbextension 4.0.9
xarray 2023.12.0
zipp 3.17.0
I was able to reproduce the issue. The good news is that I've seen the error before. The bad news is that I haven't been able to fully understand it yet. I'll keep looking.
Ok, I figured out the issue.
One of the optimizations SAX makes is to assume that the shape (i.e. the set of ports) of the S-matrix returned by a model function never changes. However, your `component` function takes a `shape` argument, so the shape of its S-matrix can vary between calls. Unfortunately, wrapping it in a `partial` won't save you here, because SAX does some additional introspection on the model functions when it constructs the circuit.
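To make the failure mode in the traceback concrete, here is a minimal JAX sketch (not SAX's actual code) of a boolean mask precomputed for one port layout being applied to S-matrix values from a model with a different layout. The sizes 18 and 14 are taken from the error message; the helper name `masked_take` is hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def masked_take(values, mask):
    """Select entries along the last axis with a precomputed boolean mask,
    mirroring `Sx[..., mask]` in the traceback (hypothetical helper)."""
    return values[..., mask]


# Shapes agree: the mask matches the layout the model actually returned.
ok = masked_take(jnp.arange(4.0)[None, :], np.array([True, False, True, False]))

# Shapes disagree: a mask built for 18 entries applied to only 14 values.
try:
    masked_take(jnp.zeros((1, 14)), np.ones(18, dtype=bool))
    mismatch_raised = False
except IndexError as err:
    mismatch_raised = True
    print(err)  # boolean index did not match shape of indexed array ...
```

The second call fails with the same kind of `IndexError` as in the notebook, which is consistent with a model returning a different number of S-matrix entries than SAX expected when it set up the circuit.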
An easy workaround is to rewrite your partials as real functions with the shape fixed, whenever different input parameters would otherwise change the output shape. Something like this:
```python
import sax

# `component`, `params0` and `phase_shifter` are defined earlier in the notebook.

# Fixed-shape wrappers: each model function now always returns an
# S-matrix with the same ports, whatever params/beta are passed in.
def component1x2(params=params0, beta=5):
    return component(params=params, beta=beta, shape=(1, 2))


def component2x2(params=params0, beta=5):
    return component(params=params, beta=beta, shape=(2, 2))


circuit_fn, _ = sax.circuit(
    netlist={
        "instances": {
            "splitter": component1x2,
            "phase_shifter": phase_shifter,
            "combiner": component2x2,
        },
        "connections": {
            "splitter,out0": "phase_shifter,in",
            "phase_shifter,out": "combiner,in0",
            "splitter,out1": "combiner,in1",
        },
        "ports": {
            "in": "splitter,in0",
            "out0": "combiner,out0",
            "out1": "combiner,out1",
        },
    }
)
circuit_fn
```
This solves the issue.
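The key property of this workaround is that each model function returns the same set of ports no matter which `params` or `beta` it receives. Below is a self-contained toy sketch of that idea in plain Python. The body of `component` is a hypothetical stand-in (an ideal splitter), since the notebook's real `component` isn't shown in this thread; `params` defaults to `None` instead of `params0` for the same reason, and `ports` is a hypothetical helper.

```python
def component(params=None, beta=5.0, shape=(2, 2)):
    """Hypothetical stand-in: an ideal n_in x n_out splitter whose port
    names depend on `shape` (params and beta are ignored here)."""
    n_in, n_out = shape
    amp = n_out ** -0.5  # equal power split between outputs
    return {
        (f"in{i}", f"out{j}"): amp
        for i in range(n_in)
        for j in range(n_out)
    }


# Fixed-shape wrappers, following the same pattern as the workaround above.
def component1x2(params=None, beta=5.0):
    return component(params=params, beta=beta, shape=(1, 2))


def component2x2(params=None, beta=5.0):
    return component(params=params, beta=beta, shape=(2, 2))


def ports(sdict):
    """Collect every port name that appears in an S-dict's keys."""
    return {port for pair in sdict for port in pair}


# Each wrapper's port set does not depend on its other arguments.
print(sorted(ports(component1x2())))  # ['in0', 'out0', 'out1']
print(sorted(ports(component2x2())))  # ['in0', 'in1', 'out0', 'out1']
```

Unlike `component` with its free `shape` argument, nothing a caller passes to `component1x2` or `component2x2` can change their ports.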
I will work on better error messages to handle this case in a future release.
Thanks! Really appreciate you looking into that. Out of curiosity, why do you think it works in some environments and not others, even when the sax and jax versions are the same?
Not really sure... I seem to have the problem in all my SAX environments... Maybe at some point you were optimizing a 2x2x2x2 instead of a 1x2x2x2?
One thing I just realized is that it might also be related to whether you have `klujax` installed or not. When `klujax` is installed, SAX defaults to the KLU backend (`backend='klu'` in `sax.circuit`), which is significantly faster, at least for large circuits, rather than the alternative approach by Gunnar Filipsson (`backend='fg'` in `sax.circuit`). The latter backend has less strict requirements on the shapes of the models, but its implementation in SAX is generally slower (although I doubt that's the case for the circuit in your notebook, as that one is very small).
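Since the behaviour may depend on whether `klujax` is present, a quick way to compare environments is to print the relevant package versions and the backend SAX would be expected to default to. This is a sketch using only the standard library; `installed_version` and `expected_default_backend` are hypothetical helpers, and the latter merely encodes the rule described above (KLU when `klujax` is importable, otherwise fg). It is not a SAX API.

```python
import importlib.util
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def expected_default_backend():
    """Assumption: SAX picks 'klu' when klujax is importable, else 'fg'."""
    return "klu" if importlib.util.find_spec("klujax") is not None else "fg"


for dist in ("sax", "jax", "jaxlib", "klujax"):
    print(f"{dist:8s} {installed_version(dist)}")
print("expected default backend:", expected_default_backend())
```

Running this in both the failing and the working environment would confirm or rule out a `klujax` difference. Passing `backend='fg'` or `backend='klu'` to `sax.circuit` explicitly in both environments removes the dependence on the default altogether.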
Interesting. I'm pretty sure klujax was installed in both environments with the same version number, but I'm not 100% positive.