This code provides a scikit-learn
class that extends the sklearn.discriminant_analysis.LinearDiscriminantAnalysis
analysis to allow for weighting of each sample. For general use of the analysis, please refer to the scikit-learn
webpage:
WeightedLDA
provides the same basic utilities of the scikit-learn
class except for the following limitations:
- Only
solve="svd"
is supported. - The priors cannot be defined a priori.
- The
feature_names_in_
attribute is not included.
The following code shows standard usage given sample data X
, cluster identification y
, and weights wgt
:
from WeightedLDA import WeightedLDA
lda = WeightedLDA()
lda.fit(X, y, wgts=wgts)
Calling fit
without the wgts
variable produces a fit the same as the scikit-learn
class.
I tried to replicate the results numerically with the scikit-learn
class, and was able to do so with all examples except when the sample data has a nullspace. Even then, the resulting predict
method produces consistent results and the transform
method is numerically consistent when the data does not extend into the nullspace of the sampled data.
Write the fit
method using PyTorch
.