smazzanti/mrmr

Redundancy Matrix is asymmetric

Opened this issue · 0 comments

I would naively expect the redundancy matrix to be symmetric, because pairwise mutual information (MI) is symmetric in its arguments: MI(A, B) = MI(B, A) for any two features A and B.
However, the matrix I observe is not symmetric, and the differences are not merely small numerical discrepancies; they are substantial. Am I missing something?

I have attached a minimal working example that reproduces the issue (if it is one). Each pixel shows the pairwise redundancy between two features, and the matrix is clearly asymmetric:

[Figure: heatmap of the redundancy matrix (image not preserved)]

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mrmr import mrmr_classif
from sklearn.datasets import make_classification


def generate_synthetic_data(n_samples=1_000, n_features=25, n_informative=10, n_redundant=5, seed=42,
                            n_repeated=5):
    """Build a synthetic classification dataset via sklearn's ``make_classification``.

    Args:
        n_samples: Number of rows to generate.
        n_features: Total number of feature columns.
        n_informative: Number of informative features.
        n_redundant: Number of redundant features.
        seed: Value passed as ``random_state`` so the data is reproducible.
        n_repeated: Number of repeated features. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        A ``(X, y)`` tuple. ``X`` is a DataFrame whose columns are named
        ``feature_0`` ... ``feature_{n_features - 1}``. ``y`` is a Series
        wrapping the labels returned by ``make_classification``.
    """
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=n_redundant,
        n_repeated=n_repeated,
        random_state=seed
    )
    return pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)]), pd.Series(y)


def check_symmetry(matrix, num_threshold=1e-3):
    """Report how far a square matrix is from being symmetric.

    Counts the entries ``(i, j)`` whose absolute difference from the mirrored
    entry ``(j, i)`` exceeds ``num_threshold``. The count is printed as a
    percentage of all entries, and the percentage is also returned.

    Args:
        matrix: Square 2-D array-like, e.g. a NumPy array or a DataFrame.
        num_threshold: Absolute tolerance below which ``matrix[i, j]`` and
            ``matrix[j, i]`` are treated as equal.

    Returns:
        float: Percentage (0-100) of entries that break symmetry. Each
        asymmetric pair contributes two entries, one per side of the diagonal.

    Raises:
        ValueError: If ``matrix`` is not 2-D and square.
    """
    values = np.asarray(matrix)
    # Check the shape explicitly. For a (1, n) input, ``values - values.T``
    # would silently broadcast to (n, n) instead of failing.
    if values.ndim != 2 or values.shape[0] != values.shape[1]:
        raise ValueError(f"expected a square 2-D matrix, got shape {values.shape}")

    symmetry_diff = values - values.T
    # NaN entries compare False here, so they are not counted as asymmetric.
    non_symmetric_indices = int(np.count_nonzero(np.abs(symmetry_diff) > num_threshold))
    # An empty matrix is trivially symmetric; this also avoids dividing by zero.
    percentage = 100 * non_symmetric_indices / values.size if values.size else 0.0
    print(f"non-symmetric indices: {percentage:.2f} %")
    return percentage


def main():
    """Plot mrmr's redundancy scores on synthetic data and report their asymmetry.

    Runs ``mrmr_classif`` with ``K`` equal to the number of columns and shows
    the returned redundancy scores as a heatmap. The figure is then saved to
    ``test.pdf``. Finally, the asymmetry percentage is printed for the raw
    matrix and for its average with its transpose.
    """
    features, target = generate_synthetic_data()
    n_columns = features.shape[1]
    _, _, redundancy = mrmr_classif(X=features, y=target, K=n_columns, return_scores=True)
    scores = np.array(redundancy)

    # NOTE(review): the asymmetry may come from mrmr scoring redundancy only
    # against features already selected, leaving the mirrored cells at a
    # floor value, rather than from the redundancy measure itself. Confirm
    # against the mrmr source.
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    plt.subplots_adjust(left=.1, right=.98, bottom=.1, top=.98)
    heatmap = ax.imshow(scores)
    plt.colorbar(heatmap, shrink=.8, ax=ax)
    plt.show()
    fig.savefig('test.pdf', bbox_inches='tight')

    # The second candidate averages the matrix with its transpose, which is
    # symmetric by construction. It should therefore always report 0.00 %.
    for candidate in (scores, (scores + scores.T) / 2):
        check_symmetry(candidate)


if __name__ == '__main__':
    # Run the reproduction only when executed as a script, not on import.
    main()