vferat/pycrostates

[ENH] Auto-label/order the microstates maps

Closed this issue · 3 comments

Is there any way to automatically order and name the microstate maps by comparing them with templates?

Reorder

I have this old smart_reorder method that was implemented some time ago in Pycrostates:

    def smart_reorder(self):
        """Automatically reorder cluster centers.

        Returns
        -------
        self : self
            The modified instance.
        """
        self._check_fit()
        info = self.info
        centers = self.cluster_centers_

        template = np.array([[-0.13234463, -0.19008217, -0.01808156, -0.06665204, -0.18127315,
        -0.25741473, -0.2313206 ,  0.04239534, -0.14411298, -0.25635016,
         0.1831745 ,  0.17520883, -0.06034687, -0.21948988, -0.2057277 ,
         0.27723199,  0.04632557, -0.1383458 ,  0.36954792,  0.33889126,
         0.1425386 , -0.05140216, -0.07532628,  0.32313928,  0.21629226,
         0.11352515],
        [-0.15034466, -0.08511373, -0.19531161, -0.24267313, -0.16871454,
        -0.04761393,  0.02482456, -0.26414511, -0.15066143,  0.04628036,
        -0.1973625 , -0.24065874, -0.08569745,  0.1729162 ,  0.22345117,
        -0.17553494,  0.00688743,  0.25853483, -0.09196588, -0.09478585,
         0.09460047,  0.32742083,  0.4325027 ,  0.09535141,  0.1959104 ,
         0.31190313],
        [0.29388541,  0.2886461 ,  0.27804376,  0.22674127,  0.21938115,
         0.21720292,  0.25153101,  0.12125869,  0.10996983,  0.10638135,
         0.11575272, -0.01388831, -0.04507772, -0.03708886,  0.08203929,
        -0.14818182, -0.20299531, -0.16658826, -0.09488949, -0.23512102,
        -0.30464665, -0.25762648, -0.14058166, -0.22072284, -0.22175042,
        -0.22167467],
       [-0.21660409, -0.22350361, -0.27855619, -0.0097109 ,  0.07119601,
         0.00385336, -0.24792901,  0.08145982,  0.23290418,  0.09985582,
        -0.24242583,  0.13516244,  0.3304661 ,  0.16710186, -0.21832217,
         0.15575575,  0.33346027,  0.18885162, -0.21687347,  0.10926662,
         0.26182733,  0.13760157, -0.19536083, -0.15966419, -0.14684497,
        -0.15296749],
       [-0.12444958, -0.12317709, -0.06189361, -0.20820917, -0.25736043,
        -0.20740485, -0.06941215, -0.18086612, -0.26979589, -0.17602898,
         0.05332203, -0.10101208, -0.20095764, -0.09582802,  0.06883067,
         0.0082463 , -0.07052899,  0.00917889,  0.26984673,  0.13288481,
         0.08062487,  0.13616082,  0.30845643,  0.36843231,  0.35510687,
         0.35583386]])
        ch_names_template =  ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC3', 'FCz',
                            'FC4', 'T3', 'C3', 'Cz', 'C4', 'T4', 'CP3', 'CPz', 'CP4',
                            'T5', 'P3', 'Pz', 'P4','T6', 'O1', 'Oz', 'O2']

        ch_names_template = [name.lower() for name in ch_names_template]
        ch_names_centers = [name.lower() for name in info['ch_names']]
        common_ch_names = list(set(ch_names_centers).intersection(ch_names_template))

        if len(common_ch_names) <= 10:
            warn("Not enough common electrodes with built-in template to automatically reorder maps. "
                 "Order hasn't been changed.")
            return self

        common_names_template = [ch_names_template.index(name) for name in common_ch_names]
        common_names_centers = [ch_names_centers.index(name) for name in common_ch_names]

        reduc_template = template[:, common_names_template]
        reduc_centers = centers[:, common_names_centers]

        mat = np.corrcoef(reduc_template,reduc_centers)[:len(reduc_template), -len(reduc_centers):]
        mat = np.abs(mat)
        mat_ = mat.copy()
        rows = list()
        columns = list()
        while len(columns) < len(template) and len(columns) < len(centers):
            mask_columns = np.ones(mat.shape[1], bool)
            mask_rows = np.ones(mat.shape[0], bool)
            mask_rows[rows] = 0
            mask_columns[columns] = 0
            mat_ = mat[mask_rows,:][:,mask_columns]
            row, column = np.unravel_index(np.where(mat.flatten() == np.max(mat_))[0][0], mat.shape)
            rows.append(row)
            columns.append(column)
            mat[row, column] = -1
        order = [x for _,x in sorted(zip(rows,columns))]
        order = order + [x for x in range(len(centers)) if x not in order]
        self.reorder(order)
        return self

However, I don't think providing a template is a good idea, because similar topographies do not necessarily mean the same sources, and therefore not necessarily the same brain functions.

We could start with a function that takes two fitted _BaseCluster instances and reorders one of them based on the other.
But we have to make sure to deal with:
- inconsistent montages
- inconsistent numbers of cluster_centers
- no one-to-one correspondence between maps (e.g. maps 1 and 2 of instance 1 both have their highest correlation with map 1 of instance 2)
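A minimal sketch of what such a function could look like, working on plain NumPy arrays and channel-name lists rather than on pycrostates objects. The name `match_maps` and its arguments are hypothetical and not part of pycrostates. It assumes that restricting both sets of maps to their common channels is acceptable, and it uses `scipy.optimize.linear_sum_assignment` on the (possibly rectangular) correlation matrix so that each reference map gets at most one partner; unmatched maps are appended at the end.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_maps(maps, ch_names, ref_maps, ref_ch_names, ignore_polarity=True):
    """Return an ordering of `maps` that best follows `ref_maps`.

    Hypothetical helper, not pycrostates API. `maps` and `ref_maps` are
    (n_clusters, n_channels) arrays; the channel lists may differ.
    """
    maps = np.asarray(maps, dtype=float)
    ref_maps = np.asarray(ref_maps, dtype=float)
    names = [name.lower() for name in ch_names]
    ref_names = [name.lower() for name in ref_ch_names]
    # Inconsistent montage: keep only the channels present in both sets.
    common = [name for name in ref_names if name in names]
    if len(common) < 3:
        raise ValueError("Not enough common channels to compare the maps.")
    a = ref_maps[:, [ref_names.index(name) for name in common]]
    b = maps[:, [names.index(name) for name in common]]
    # Spatial correlation: rows are reference maps, columns are maps to reorder.
    corr = np.corrcoef(a, b)[: len(a), len(a):]
    if ignore_polarity:
        corr = np.abs(corr)
    # One-to-one matching. Rectangular matrices are allowed, so a different
    # number of clusters simply leaves some maps unmatched.
    _, cols = linear_sum_assignment(corr, maximize=True)
    order = [int(c) for c in cols]
    # Unmatched maps keep their relative order at the end.
    order += [i for i in range(len(maps)) if i not in order]
    return order


# Made-up data: 4 reference maps on 6 channels.
rng = np.random.default_rng(0)
ref = rng.standard_normal((4, 6))
ref_ch = ["Fz", "Cz", "Pz", "Oz", "C3", "C4"]
# Same maps, shuffled, polarity flipped, channels reversed, plus one extra map.
ch = ["c4", "c3", "OZ", "PZ", "CZ", "FZ"]
maps = np.vstack([rng.standard_normal((1, 6)), -ref[[2, 0, 3, 1]][:, ::-1]])

order = match_maps(maps, ch, ref, ref_ch)
print(order)  # [2, 4, 1, 3, 0]: the extra, unmatched map comes last
```

The threshold of 3 common channels is arbitrary here; a real implementation would probably want a stricter limit and a warning instead of an exception.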

Plot

I also have code to produce this kind of figure, which may be useful for comparing clustering results:
image

We could implement it in pycrostates.viz if we accept a dependency on seaborn.

Templates

We could also consider providing a dataset of cluster_centers corresponding to published studies:

pycrostates.datasets.clusters_centers.get(study="ferat2022", condition="ADHD")

However, we would also need to return the corresponding montage, maybe by creating a new class ImportedCluster(_BaseCluster).

I haven't looked at the code in detail, but for the auto-ordering/naming based on a template, my idea was that the topomap always looks the same for a given state (apart from the sign), right? And since a topomap is represented as values interpolated on a grid, we don't even need to bother with montages: we just have to match the interpolated grids?
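A rough sketch of the grid idea, assuming 2D electrode positions are available for each set of maps (the coordinates below are made up; in practice they would have to come from the montage, so positions are still needed to build the grid). `maps_to_grid` and `grid_correlation` are hypothetical names, not pycrostates functions. The sketch relies on `scipy.interpolate.griddata`, which returns NaN outside the convex hull of the electrodes, so only grid points valid for both sets are compared.

```python
import numpy as np
from scipy.interpolate import griddata


def maps_to_grid(maps, pos, n_grid=32):
    """Interpolate (n_maps, n_channels) maps onto a regular 2D grid.

    Hypothetical helper. `pos` is an (n_channels, 2) array of 2D electrode
    positions scaled to lie within [-1, 1].
    """
    axis = np.linspace(-1, 1, n_grid)
    gx, gy = np.meshgrid(axis, axis)
    grid = np.column_stack([gx.ravel(), gy.ravel()])
    # One row per map, one column per grid point; NaN outside the convex hull.
    return np.array([griddata(pos, m, grid, method="linear") for m in maps])


def grid_correlation(grid_a, grid_b, ignore_polarity=True):
    """Correlate two sets of gridded maps on the grid points valid for both."""
    valid = ~np.isnan(grid_a).any(axis=0) & ~np.isnan(grid_b).any(axis=0)
    a, b = grid_a[:, valid], grid_b[:, valid]
    # Rows are maps of set A, columns are maps of set B.
    corr = np.corrcoef(a, b)[: len(a), len(a):]
    return np.abs(corr) if ignore_polarity else corr


rng = np.random.default_rng(0)
pos_a = rng.uniform(-1, 1, size=(32, 2))  # made-up montage A
pos_b = rng.uniform(-1, 1, size=(64, 2))  # made-up montage B
# Two simple topographies sampled by each montage: gradients along x and y.
maps_a = np.array([pos_a[:, 0], pos_a[:, 1]])
# Montage B sees the same topographies in swapped order, one with flipped sign.
maps_b = np.array([-pos_b[:, 1], pos_b[:, 0]])

corr = grid_correlation(maps_to_grid(maps_a, pos_a), maps_to_grid(maps_b, pos_b))
print(corr.argmax(axis=1))  # map 0 of A pairs with map 1 of B, and map 1 with map 0
```

Linear interpolation is used only to keep the example simple; it is not how topomaps are usually drawn, and the choice of interpolation would affect the correlations on real data.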

That plot looks amazing, I'm fine with depending on seaborn for stuff like that!

Hey,

I worked a bit on the automatic ordering tonight.
One can consider it as an assignment problem where the cost matrix is the negative of the spatial correlation between maps (negating turns the "maximize correlation" problem into the "minimize cost" problem that the solver expects).
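To illustrate on a toy case why this formulation can matter: the correlation values below are made up, and the greedy loop is a simplified stand-in for the "pick the largest remaining correlation" strategy, not a copy of any existing implementation. `scipy.optimize.linear_sum_assignment` minimizes a total cost, hence the negated matrix.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Made-up absolute correlations: rows = template maps, columns = current maps.
M = np.array([
    [0.875, 0.75],
    [0.8125, 0.125],
])

# Greedy matching: take the largest remaining entry, then drop its row and column.
remaining = M.copy()
greedy = {}
for _ in range(len(M)):
    row, col = np.unravel_index(np.argmax(remaining), remaining.shape)
    greedy[int(row)] = int(col)
    remaining[row, :] = -np.inf
    remaining[:, col] = -np.inf
greedy_total = sum(M[r, c] for r, c in greedy.items())

# Optimal matching: minimize the negated correlations.
rows, cols = linear_sum_assignment(-M)
optimal_total = M[rows, cols].sum()

print(greedy, greedy_total)          # template 0 -> map 0, template 1 -> map 1; total 1.0
print(cols.tolist(), optimal_total)  # template 0 -> map 1, template 1 -> map 0; total 1.5625
```

The greedy choice locks in the single best pair and is then left with a poor match for the second template, whereas the assignment solver maximizes the sum over all pairs.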

I don't have time to open a PR right now, so I put the code here in the meantime.

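import numpy as np
import scipy.optimize

def _reorder_template(current, template, ignore_polarity=True):
    n_states = len(template)
    # Correlation of each template map (rows) with each current map (columns)
    M = np.corrcoef(template, current)[:n_states, n_states:]
    if ignore_polarity:
        M = np.abs(M)
    # Returns (row_ind, col_ind); negate M to maximize the total correlation
    _, order = scipy.optimize.linear_sum_assignment(-M)
    return order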

along with some tests:

n_states = 5
n_electrodes = 3
# Random template
template = np.random.randint(-10,10, (n_states,n_electrodes))
# Shuffle template
arr = np.arange(n_states)
np.random.shuffle(arr)
random_template = template[arr]
# invert polarity
polarities = np.random.choice([-1, 1], n_states)
random_pol_template = polarities[:, np.newaxis] * random_template 

# No shuffle
current = template
ignore_polarity = True
order = _reorder_template(current, template, ignore_polarity=ignore_polarity)
assert np.all(order == np.arange(n_states))

# Shuffle
current = random_template
ignore_polarity = False
order = _reorder_template(current, template, ignore_polarity=ignore_polarity)
assert np.allclose(current[order], template)

# Shuffle + ignore_polarity 
current = random_template
ignore_polarity = True
order = _reorder_template(current, template, ignore_polarity=ignore_polarity)
assert np.allclose(current[order], template)

# Shuffle + sign + ignore_polarity 
current = random_pol_template
ignore_polarity = True
order_ = _reorder_template(current, template, ignore_polarity=ignore_polarity)
assert np.all(order == order_)

# Shuffle + sign
current = random_pol_template
ignore_polarity = False
order = _reorder_template(current, template, ignore_polarity=ignore_polarity)
# The signed assignment maximizes the summed signed correlation, so it can
# only match or beat the polarity-agnostic ordering `order_` on that score.
corr = np.corrcoef(template, current[order])[n_states:, :n_states]
corr_order = np.corrcoef(template, current[order_])[n_states:, :n_states]
assert np.trace(corr) >= np.trace(corr_order) - 1e-9