SPFlow/SPFlow

ValueError: setting an array element with a sequence

janakact opened this issue · 0 comments

What did you expect to happen?

I expected the following code to run without errors.

def learn_mspn_example():
    """Learn an MSPN on a small synthetic dataset with mixed discrete/real columns.

    Columns of the training matrix:
      * ``a`` -- ``np.random.randint(2)`` draws, declared DISCRETE;
      * ``b`` -- ``np.random.randint(3)`` draws, declared DISCRETE;
      * ``c`` -- 4 draws from N(10, 5) followed by 6 draws from N(20, 10), declared REAL;
      * ``d`` -- ``5 * a + 3 * b + c``, declared REAL.

    The DISCRETE columns typically yield domains of different lengths (``a`` and
    ``b`` have different value ranges). This ragged case is the one that makes
    ``Context.add_domains`` fail under NumPy >= 1.24.

    Returns:
        Whatever ``learn_mspn`` returns for this data, so the caller can use the
        learned model instead of it being discarded.
    """
    a = np.random.randint(2, size=10).reshape(-1, 1)
    b = np.random.randint(3, size=10).reshape(-1, 1)
    c = np.r_[np.random.normal(10, 5, (4, 1)), np.random.normal(20, 10, (6, 1))]
    d = 5 * a + 3 * b + c
    train_data = np.c_[a, b, c, d]

    ds_context = Context(meta_types=[MetaType.DISCRETE, MetaType.DISCRETE, MetaType.REAL, MetaType.REAL])
    ds_context.add_domains(train_data)

    mspn = learn_mspn(train_data, ds_context, min_instances_slice=4)
    return mspn

What actually happened?

The following error is raised from the `add_domains` method:

Traceback (most recent call last):
  File "./sum_product.py", line 62, in <module>
    main()
  File "./sum_product.py", line 59, in main
    learn_mspn_example()
  File "./sum_product.py", line 51, in learn_mspn_example
    ds_context.add_domains(train_data)
  File "..../python3.8/site-packages/spn/structure/Base.py", line 157, in add_domains
    self.domains = np.asanyarray(domain)
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (4,) + inhomogeneous part.

Describe your attempts to resolve the issue

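The cause is that NumPy 1.24 no longer creates ragged arrays implicitly. The per-column domains have different lengths: `[min, max]` for REAL columns and a full `np.arange(min, max + 1)` for DISCRETE columns. As a result, `np.asanyarray(domain)` now raises `ValueError` instead of returning an object array. It can be worked around by removing the `np.asanyarray()` call in the `add_domains` method, although this stores `self.domains` as a plain list rather than a NumPy array: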

    def add_domains(self, data):
        """Compute the domain of every column of ``data`` and store it in ``self.domains``.

        Domains are chosen by the column's entry in ``self.meta_types``:
          * REAL / BINARY -- ``[min, max]`` of the column, ignoring NaNs
            (``np.nanmin`` / ``np.nanmax``);
          * DISCRETE -- ``np.arange(min, max + 1, 1)``, i.e. every unit step
            between the observed extremes.

        When all per-column domains have the same length, ``self.domains`` is the
        regular array ``np.asanyarray`` builds from them. When the lengths differ
        (ragged), it is a 1-D ``dtype=object`` array with one domain per column.
        Older NumPy produced this object array implicitly, but NumPy >= 1.24
        raises ``ValueError`` instead, so it is now built explicitly.

        Args:
            data: 2-D array with one column per entry of ``self.meta_types``.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            AssertionError: if ``data`` is not 2-D or its column count differs
                from ``len(self.meta_types)``.
            Exception: if a column's meta type is not REAL, BINARY or DISCRETE.
        """
        assert len(data.shape) == 2, "data is not 2D?"
        assert data.shape[1] == len(self.meta_types), "Data columns and metatype size doesn't match"

        from spn.structure.StatisticalTypes import MetaType

        domain = []

        for col in range(data.shape[1]):
            feature_meta_type = self.meta_types[col]
            min_val = np.nanmin(data[:, col])
            max_val = np.nanmax(data[:, col])

            if feature_meta_type == MetaType.REAL or feature_meta_type == MetaType.BINARY:
                domain.append([min_val, max_val])
            elif feature_meta_type == MetaType.DISCRETE:
                domain.append(np.arange(min_val, max_val + 1, 1))
            else:
                raise Exception("Unkown MetaType " + str(feature_meta_type))

        if len({len(values) for values in domain}) <= 1:
            # Equal-length domains (or no columns at all): keep the regular array
            # produced by earlier versions.
            self.domains = np.asanyarray(domain)
        else:
            # Ragged domains: fill an object array element by element so NumPy
            # never tries (and, on >= 1.24, fails) to infer a rectangular shape.
            domains = np.empty(len(domain), dtype=object)
            for idx, values in enumerate(domain):
                domains[idx] = values
            self.domains = domains

        return self

Steps to reproduce

Run the `learn_mspn_example()` function above.

System Information

Python: 3.8.10
SPflow: 0.0.41
OS: Ubuntu 20.04

Installed Python Packages

<details>
absl-py==1.3.0
alabaster==0.7.12
ale-py==0.7.4
alembic==1.9.0
anyio==3.6.2
arch==5.3.0
arff==0.9
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asgiref==3.6.0
asttokens==2.2.1
astunparse==1.6.3
atari-py==0.2.9
attrs==22.1.0
autopage==0.5.1
AutoROM==0.4.2
AutoROM.accept-rom-license==0.5.0
Babel==2.11.0
backcall==0.2.0
backports.zoneinfo==0.2.1
beautifulsoup4==4.11.1
black==22.12.0
blackhc.mdp==1.0.4
bleach==5.0.1
box2d-py==2.3.8
cachetools==5.2.0
certifi==2022.12.7
cffi==1.15.1
chardet==5.1.0
charset-normalizer==2.1.1
click==8.1.3
cliff==4.1.0
cloudpickle==2.2.0
cmaes==0.9.0
cmd2==2.4.2
colorama==0.4.6
colorlog==6.7.0
comm==0.1.2
commonmark==0.9.1
contourpy==1.0.6
coverage==7.0.0
cvxopt==1.3.0
cycler==0.11.0
Cython==0.29.32
dataclasses==0.6
DataProperty==0.55.0
debugpy==1.6.4
decorator==4.4.2
defusedxml==0.7.1
Django==4.1.4
docker-pycreds==0.4.0
docutils==0.17.1
entrypoints==0.4
ete3==3.1.2
exceptiongroup==1.0.4
execnet==1.9.0
executing==1.2.0
fasteners==0.18
fastjsonschema==2.16.2
filelock==3.8.2
flake8==6.0.0
flake8-bugbear==22.12.6
Flask==2.2.2
flatbuffers==22.12.6
fonttools==4.38.0
fqdn==1.5.1
gast==0.4.0
gitdb==4.0.10
GitPython==3.1.29
glfw==2.5.5
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
greenlet==2.0.1
grpcio==1.51.1
gym==0.21.0
gym-minigrid==1.0.3
gym-notices==0.0.8
h5py==3.7.0
huggingface-hub==0.11.1
huggingface-sb3==2.2.4
idna==3.4
image==1.5.33
imageio==2.23.0
imageio-ffmpeg==0.4.7
imagesize==1.4.1
importlab==0.8
importlib-metadata==4.13.0
importlib-resources==5.10.1
iniconfig==1.1.1
ipdb==0.13.11
ipykernel==6.19.4
ipython==8.7.0
ipython-genutils==0.2.0
ipythonblocks==1.9.0
ipywidgets==8.0.3
isoduration==20.11.0
isort==5.11.3
itsdangerous==2.1.2
jedi==0.18.2
Jinja2==3.1.2
joblib==1.2.0
jsonpointer==2.3
jsonschema==4.17.3
jupyter==1.0.0
jupyter-console==6.4.4
jupyter-events==0.5.0
jupyter_client==7.4.8
jupyter_core==5.1.0
jupyter_server==2.0.2
jupyter_server_terminals==0.4.3
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.4
keras==2.11.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
lark-parser==0.12.0
libclang==14.0.6
libcst==0.4.9
libtorrent==2.0.7
livereload==2.6.3
lxml==4.9.1
lz4==4.0.2
Mako==1.2.4
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.2
matplotlib-inline==0.1.6
mbstrdecoder==1.1.1
mccabe==0.7.0
mistune==2.0.4
moviepy==1.0.3
mpmath==1.2.1
mujoco==2.2.0
mujoco-py==2.1.2.14
mypy-extensions==0.4.3
nbclassic==0.4.8
nbclient==0.7.2
nbconvert==7.2.7
nbformat==5.7.1
nest-asyncio==1.5.6
networkx==2.8.8
ninja==1.11.1
notebook==6.5.2
notebook_shim==0.2.2
numpy==1.24.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opencv-python==4.6.0.66
opt-einsum==3.3.0
optuna==3.0.5
packaging==22.0
panda-gym==1.1.1
pandas==1.5.2
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.10.3
pathtools==0.1.2
pathvalidate==2.5.2
patsy==0.5.3
pbr==5.11.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
pkgutil_resolve_name==1.3.10
platformdirs==2.6.0
plotly==5.11.0
pluggy==1.0.0
prettytable==3.5.0
proglog==0.1.10
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==3.0.36
property-cached==1.6.4
protobuf==3.19.6
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
py==1.11.0
pyaml==21.10.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybullet==3.2.5
pycodestyle==2.10.0
pycparser==2.21
pydot==1.4.2
pydotplus==2.0.2
pyenchant==3.2.2
pyflakes==3.0.1
pygame==2.1.0
pyglet==2.0.2.1
Pygments==2.13.0
PyOpenGL==3.1.6
pyparsing==3.0.9
pyperclip==1.8.2
PyQt5==5.15.7
PyQt5-Qt5==5.15.2
PyQt5-sip==12.11.0
pyrsistent==0.19.2
pytablewriter==0.64.2
pytest==7.2.0
pytest-cov==4.0.0
pytest-env==0.8.1
pytest-xdist==3.1.0
python-dateutil==2.8.2
python-json-logger==2.0.4
pytype==2022.6.6
pytz==2022.7
PyVirtualDisplay==3.0
PyYAML==6.0
pyzmq==24.0.1
qpsolvers==2.7.0
qtconsole==5.4.0
QtPy==2.3.0
render-browser==0.5
requests==2.28.1
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==12.6.0
rl-zoo3==1.6.3
rliable==1.0.8
rsa==4.9
sb3-contrib==1.6.2
scikit-learn==1.2.0
scikit-optimize==0.9.0
scipy==1.8.1
seaborn==0.12.1
Send2Trash==1.8.0
sentry-sdk==1.12.1
setproctitle==1.3.2
shortuuid==1.0.11
six==1.16.0
sklearn==0.0.post1
smmap==5.0.0
sniffio==1.3.0
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soupsieve==2.3.2.post1
spflow==0.0.41
Sphinx==5.3.0
sphinx-autobuild==2021.3.14
sphinx-copybutton==0.5.1
sphinx-rtd-theme==1.1.1
sphinx_autodoc_typehints==1.19.5
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib-spelling==7.7.0
SQLAlchemy==1.4.45
sqlparse==0.4.3
stable-baselines3==1.6.2
stack-data==0.6.2
statsmodels==0.13.5
stevedore==4.1.1
swig==4.1.0
sympy==1.11.1
tabledata==1.3.0
tabulate==0.9.0
tcolorpy==0.1.2
tenacity==8.1.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.29.0
termcolor==2.1.1
terminado==0.17.1
threadpoolctl==3.1.0
tinycss2==1.2.1
toml==0.10.2
tomli==2.0.1
torch==1.13.1
torchvision==0.14.1
tornado==6.2
tqdm==4.64.1
traitlets==5.8.0
typed-ast==1.5.4
typepy==1.3.0
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.4.0
uri-template==1.2.0
urllib3==1.26.13
wandb==0.13.7
wasabi==1.1.0
wcwidth==0.2.5
webcolors==1.12
webencodings==0.5.1
websocket-client==1.4.2
Werkzeug==2.2.2
widgetsnbextension==4.0.4
wrapt==1.14.1
xvfbwrapper==0.2.9
zipp==3.11.0
</details>