flaport/sax

Issues with Models with complex arguments


I'm trying to create a simple MZI circuit where I can pass complex parameters to each coupler. For example:

import jax.numpy as jnp
import sax


def coupler(S31=1/jnp.sqrt(2), S41=1j/jnp.sqrt(2)) -> sax.SDict:
    coupler_dict = sax.reciprocal(
        {
            ("in0", "out0"): S31,
            ("in0", "out1"): S41,
            ("in1", "out0"): S41,
            ("in1", "out1"): S31,
        }
    )
    return coupler_dict


def MZI_arms(phase_top=0., phase_bottom=0) -> sax.SDict:
    _sdict = sax.reciprocal(
        {
            ("in0", "out0"): jnp.exp(1j*phase_top),
            ("in1", "out1"): jnp.exp(1j*phase_bottom),
        }
    )
    return _sdict

mzi, info = sax.circuit(
    netlist={
        "instances": {
            "BS1": "coupler",
            "PS1": "phase_shifter",
            "BS2": "coupler",
            "PS2": "phase_shifter",
        },
        "connections": {
            "BS1,out0": "PS1,in0",
            "BS1,out1": "PS1,in1",
            "PS1,out0": "BS2,in0",
            "PS1,out1": "BS2,in1",
            "BS2,out0": "PS2,in0",
            "BS2,out1": "PS2,in1",
        },
        "ports": {
            "in0": "BS1,in0",
            "in1": "BS1,in1",
            "out0": "PS2,out0",
            "out1": "PS2,out1",
        },
    },
    models={
        "coupler": coupler,
        "phase_shifter": MZI_arms,
    }
)

Running this gives me the following warning:

[.../lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2089]: ComplexWarning: Casting complex values to real discards the imaginary part
  out_array: Array = lax_internal._convert_element_type(
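The warning text itself says what is going on: a complex value is being converted to a real dtype somewhere (the exact spot inside sax is not visible from this traceback). As a minimal sketch using only `jax.numpy`, independent of sax, this is what such a cast does to the default cross-coupling `S41=1j/jnp.sqrt(2)`:

```python
import warnings

import jax.numpy as jnp

# Default cross-coupling from the coupler model above: purely imaginary.
S41 = 1j / jnp.sqrt(2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Requesting a real dtype forces a complex -> real conversion,
    # which goes through the same conversion path shown in the traceback.
    S41_real = jnp.asarray(S41, dtype=jnp.float32)

print(S41_real)  # only the real part survives, which is zero here
print([str(w.message) for w in caught])
```

Because `S41` is purely imaginary, its real part is zero, so a cast like this would silently switch off the coupler's cross terms rather than just perturbing them.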

Is this a problem? I'm not sure whether I can trust the final result.
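One way to decide whether to trust the result is to compute the expected transfer matrix by hand and compare. The sketch below does not use sax at all: it rebuilds the same coupler and phase-shifter models as plain 2x2 matrices and chains them in the order of the netlist above. The `*_matrix` helpers are hypothetical names, and "rows = outputs, columns = inputs" is an assumed convention (the coupler matrix is symmetric, so it does not change the numbers here).

```python
import jax.numpy as jnp


def coupler_matrix(S31=1 / jnp.sqrt(2), S41=1j / jnp.sqrt(2)):
    # Hypothetical helper mirroring the reciprocal `coupler` S-dict:
    # rows are (out0, out1), columns are (in0, in1).
    return jnp.array([[S31, S41], [S41, S31]])


def arms_matrix(phase_top=0.0, phase_bottom=0.0):
    # Hypothetical helper mirroring `MZI_arms`: a diagonal phase shift per arm.
    return jnp.diag(jnp.exp(1j * jnp.array([phase_top, phase_bottom])))


def mzi_matrix(S31=1 / jnp.sqrt(2), S41=1j / jnp.sqrt(2)):
    C = coupler_matrix(S31, S41)
    P = arms_matrix()
    # Light passes BS1 -> PS1 -> BS2 -> PS2, so matrices compose right to left.
    return P @ C @ P @ C


T = mzi_matrix()
print(jnp.abs(T) ** 2)  # expected power transmission with the default settings

# Same circuit if the imaginary part of S41 were discarded (S41 -> 0).
T_real = mzi_matrix(S41=0.0)
print(jnp.abs(T_real) ** 2)
```

With the default parameters, all power should cross over (|T|² ≈ [[0, 1], [1, 0]]). If `S41` lost its imaginary part, the second case shows what to expect instead: 0.25 on each bar port and nothing crossing, which also violates energy conservation and is easy to spot.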

Hi @thomaslima, thanks for the bug report. This issue will be fixed in sax>=0.10.4.