microsoft/augmented-interpretable-models

Multiclass classification case

gustavecortal opened this issue

Hi, thank you for providing this repo!

I want to use emb-gam for multiclass classification (4 classes, embedding size 768), but I get the following error when the linear coefficients are computed: "matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 4 is different from 768)"

(line 199 in cache_linear_coefs)

The error goes away if I use linear_coef = embs @ coef_embs.T instead of linear_coef = embs @ coef_embs.
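A minimal NumPy reproduction of the mismatch, assuming coef_embs has shape (n_classes, embedding_dim), which is what scikit-learn's LogisticRegression.coef_ looks like with more than two classes. The sizes and random data are illustrative only, not taken from the repo:

```python
import numpy as np

# Illustrative shapes (assumption): 4 classes, 768-dim embeddings,
# coef_embs laid out as (n_classes, n_features) like sklearn's coef_.
n_ngrams, emb_dim, n_classes = 10, 768, 4
rng = np.random.default_rng(0)
embs = rng.normal(size=(n_ngrams, emb_dim))        # one embedding per n-gram
coef_embs = rng.normal(size=(n_classes, emb_dim))  # one weight row per class

# Untransposed: inner dimensions are 768 vs 4, so matmul raises ValueError
try:
    embs @ coef_embs
except ValueError as err:
    print("matmul failed:", err)

# Transposed: (n_ngrams, 768) @ (768, 4) -> one score per n-gram per class
linear_coef = embs @ coef_embs.T
print(linear_coef.shape)  # (10, 4)
```

If the binary path instead stores a 1-D coefficient vector of length 768 (an assumption, not checked against the repo), embs @ coef_embs yields one score per n-gram, which would explain why the untransposed form works there but not with 4 classes.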

Now I'm looking for a way to predict the classes correctly (I will probably modify the _predict_cached function).
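One possible shape for a multiclass prediction step, sketched under stated assumptions: a sample's class scores are the intercept plus the sum of its n-grams' cached rows of linear_coef (the additive form EMB-GAM uses), and probabilities come from a softmax as in multinomial logistic regression. The function predict_from_cached and its ngram_index argument are hypothetical names, not the repo's _predict_cached API:

```python
import numpy as np

def predict_from_cached(ngram_lists, linear_coef, ngram_index, intercept):
    """Hypothetical multiclass predictor over cached per-n-gram scores.

    ngram_lists : list of lists of n-gram strings, one list per sample
    linear_coef : array (n_ngrams, n_classes), e.g. embs @ coef_embs.T
    ngram_index : dict mapping n-gram string -> row index in linear_coef
    intercept   : array (n_classes,), e.g. a fitted model's intercept_
    """
    # Start every sample at the intercept: shape (n_samples, n_classes)
    logits = np.tile(intercept, (len(ngram_lists), 1)).astype(float)
    for i, ngrams in enumerate(ngram_lists):
        for ng in ngrams:
            if ng in ngram_index:  # n-grams missing from the cache are skipped
                logits[i] += linear_coef[ngram_index[ng]]
    # Numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    return logits.argmax(axis=1), probs


# Toy usage with 3 cached n-grams and 4 classes
ngram_index = {"good": 0, "bad": 1, "movie": 2}
linear_coef = np.array([[2.0, 0.0, 0.0, 0.0],
                        [0.0, 2.0, 0.0, 0.0],
                        [0.0, 0.0, 0.5, 0.0]])
intercept = np.zeros(4)
preds, probs = predict_from_cached(
    [["good", "movie"], ["bad", "unseen"]], linear_coef, ngram_index, intercept
)
print(preds)  # [0 1]
```

If the model was fit one-vs-rest rather than multinomially, the argmax of the logits still gives the predicted class, but the softmax would not match scikit-learn's predict_proba exactly.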