eric-moreno/IN

How is the JSD calculated in the final paper?

Justin-Tan opened this issue · 3 comments

Hi authors,

How do you reliably compute the Jensen-Shannon divergence between the continuous mass distributions passing/failing the threshold? As far as I understand, the KL divergence, and by extension the JS divergence, between continuous distributions is intractable and hard to estimate reliably.

I looked through the code base but found it mildly confusing. From what I can gather, you discretize the distributions somehow and use an estimator for the entropy of the discretized distribution? But how do you calculate the cross entropy as well? A quick rundown would be very helpful, as this would be a very useful metric to quantify the extent of decorrelation to a given pivot variable.

Cheers,
Justin

Hi @Justin-Tan,

The simplest computation is here:

IN/make_good_plots.py

Lines 600 to 617 in 0ef7e34

# digitize into bins
spec_pass = np.digitize(mass_pass, bins=np.linspace(mmin,mmax,nbins+1), right=False)-1
spec_fail = np.digitize(mass_fail, bins=np.linspace(mmin,mmax,nbins+1), right=False)-1
# one hot encoding
spec_ohe_pass = np.zeros((spec_pass.shape[0],nbins))
spec_ohe_pass[np.arange(spec_pass.shape[0]),spec_pass] = 1
spec_ohe_pass_sum = np.sum(spec_ohe_pass,axis=0)/spec_ohe_pass.shape[0]
spec_ohe_fail = np.zeros((spec_fail.shape[0],nbins))
spec_ohe_fail[np.arange(spec_fail.shape[0]),spec_fail] = 1
spec_ohe_fail_sum = np.sum(spec_ohe_fail,axis=0)/spec_ohe_fail.shape[0]
M = 0.5*spec_ohe_pass_sum+0.5*spec_ohe_fail_sum
kld_pass = scipy.stats.entropy(spec_ohe_pass_sum,M,base=2)
kld_fail = scipy.stats.entropy(spec_ohe_fail_sum,M,base=2)
jsd = 0.5*kld_pass+0.5*kld_fail

Basically, we first get the mass_pass and mass_fail NumPy arrays of mass values. These are turned into binned, normalized mass distributions, spec_ohe_pass_sum and spec_ohe_fail_sum.

We take the average M of these two distributions. Then we compute the KL divergence of each distribution from M and average the two to get the JS divergence.
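
For reference, the same steps can be condensed into a short function. This is only a sketch, not the code in the repository: the name `binned_jsd` is hypothetical, and it swaps the `np.digitize` plus one-hot step for `np.histogram`, which silently drops values outside `[mmin, mmax]` and puts values equal to `mmax` in the last bin. It also assumes each sample has at least one in-range entry.

```python
import numpy as np
from scipy.stats import entropy


def binned_jsd(mass_pass, mass_fail, mmin, mmax, nbins):
    """Jensen-Shannon divergence (in bits) between two binned mass spectra.

    Hypothetical helper: a sketch of the binned estimate described above.
    """
    edges = np.linspace(mmin, mmax, nbins + 1)

    # Bin both samples on the same edges; out-of-range values are dropped.
    counts_pass, _ = np.histogram(mass_pass, bins=edges)
    counts_fail, _ = np.histogram(mass_fail, bins=edges)

    # Normalize counts to probabilities (assumes non-empty in-range samples).
    p = counts_pass / counts_pass.sum()
    q = counts_fail / counts_fail.sum()

    # Mixture distribution M.
    m = 0.5 * (p + q)

    # With two arguments, scipy.stats.entropy returns KL(pk || qk).
    # Bins where pk == 0 contribute zero to the sum.
    return 0.5 * entropy(p, m, base=2) + 0.5 * entropy(q, m, base=2)
```

With `base=2` the result lies between 0 (identical binned spectra) and 1 (no bin populated by both samples).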

Thanks for the response! That makes sense. I didn't realize the scipy.stats.entropy function calculates the relative entropy, though it's obvious in hindsight ... :\

By the way, did you find the metric sensitive to the choice of binning at all?
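
One way to probe this is to recompute the binned JSD for several bin counts on the same pair of samples. The sketch below does that on synthetic data: the helper names, the 40 to 200 mass range, and the uniform toy samples are all assumptions for illustration, not anything from the repository. Both toy samples are drawn from the same distribution, so any nonzero value comes from finite-sample fluctuations alone, an effect that tends to grow as the bins get finer.

```python
import numpy as np
from scipy.stats import entropy


def binned_jsd(mass_pass, mass_fail, mmin, mmax, nbins):
    """Binned Jensen-Shannon divergence in bits (hypothetical helper)."""
    edges = np.linspace(mmin, mmax, nbins + 1)
    counts_pass, _ = np.histogram(mass_pass, bins=edges)
    counts_fail, _ = np.histogram(mass_fail, bins=edges)
    p = counts_pass / counts_pass.sum()
    q = counts_fail / counts_fail.sum()
    m = 0.5 * (p + q)
    # entropy(pk, qk) is the KL divergence KL(pk || qk).
    return 0.5 * entropy(p, m, base=2) + 0.5 * entropy(q, m, base=2)


def jsd_vs_binning(mass_pass, mass_fail, mmin, mmax, nbins_list):
    """Return {nbins: JSD} for each requested bin count (hypothetical helper)."""
    return {
        n: binned_jsd(mass_pass, mass_fail, mmin, mmax, n)
        for n in nbins_list
    }


# Toy data: two samples from the SAME uniform distribution (assumed range).
rng = np.random.default_rng(0)
mass_pass = rng.uniform(40.0, 200.0, size=5000)
mass_fail = rng.uniform(40.0, 200.0, size=5000)

scan = jsd_vs_binning(mass_pass, mass_fail, 40.0, 200.0, [10, 20, 40, 80, 160])
for n, value in scan.items():
    print(f"nbins={n:4d}  JSD={value:.5f}")
```

Running the same scan on the real mass_pass and mass_fail arrays would show how much of the quoted JSD depends on the binning choice.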