dianna-ai/dianna

Visualization: `plot_tabular` does not work for x as `np.ndarray`


The input `x` of `plot_tabular` is documented as an "Array of feature importance scores". However, calling it with an `np.ndarray` raises an error, and `test_visualization.py` has no test covering this case. The error can be reproduced with:

```python
import numpy as np
from dianna.visualization import plot_tabular


def test_plot_tabular_with_array():
    """Test plot_tabular with an np.ndarray input."""
    x = np.random.rand(5, 3)
    y = [f"Feature {i}" for i in range(x.shape[1])]
    fig, ax = plot_tabular(x=x, y=y, show_plot=False)
    assert fig is not None
```

The error:

```
>       top_values = [x for _, x in sorted(zip(abs_values, x), reverse=True)][:num_features]
E       ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
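A minimal sketch of why the line above fails, assuming `abs_values` is computed element-wise from `x` (the traceback does not show that line). `np.random.rand(5, 3)` is 2-D, so iterating over it yields rows of shape `(3,)`. When `sorted` compares two `(abs_value, value)` tuples it evaluates `row_a == row_b`, and that array result cannot be turned into a single bool. A 1-D array of scores goes through the same expression without error. The second function is a hypothetical alternative, not DIANNA's code: it converts `x` with `np.asarray`, rejects non-1-D input with an explicit message, and ranks by absolute value with `np.argsort`.

```python
import numpy as np


def top_values_list_sort(x, num_features):
    # Same pattern as the failing line; abs_values is an assumed reconstruction.
    abs_values = [abs(i) for i in x]
    return [v for _, v in sorted(zip(abs_values, x), reverse=True)][:num_features]


def top_values_argsort(x, num_features):
    # Hypothetical alternative: accept any array-like, but require 1-D scores.
    values = np.asarray(x, dtype=float)
    if values.ndim != 1:
        raise ValueError(
            f"x must be a 1-D array of feature importance scores, got shape {values.shape}"
        )
    # Rank by descending absolute value; stable sort keeps input order on ties.
    order = np.argsort(-np.abs(values), kind="stable")
    return values[order][:num_features]


print(top_values_list_sort(np.array([0.1, -0.5, 0.3]), 2))  # 1-D input works
print(top_values_argsort([0.1, -0.5, 0.3], 2))
```

The two functions can order tied absolute values differently (`sorted(..., reverse=True)` breaks ties by the signed value, the stable argsort by position), which only matters for scores such as `0.5` and `-0.5`.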