Add support for selective HMM parameter freezing during EM algorithm
eonu opened this issue · 1 comment
In most use cases, it is desirable to allow all parameters (initial state and transition probabilities, mean vectors, covariance matrices and mixture weights) to be updated freely. However, there may be situations in which a user wants to prevent one or more of these parameter types from being updated during the EM algorithm.
`hmmlearn` allows this to be specified in the `GMMHMM` initialization function, which accepts the argument:

> `params` (string, optional) – Controls which parameters are updated in the training process. Can contain any combination of ‘s’ for startprob, ‘t’ for transmat, ‘m’ for means, ‘c’ for covars, and ‘w’ for GMM mixing weights. Defaults to all parameters.
We can add a variable `self._frozen` (initialized in our `GMMHMM.__init__`) which is set to `None` at the start. We can then create a new method `GMMHMM.freeze(params)`, where `params` is a string consisting of any of the characters in `'stmcw'`, specifying which parameters to freeze (reflected by a change to `self._frozen`). If `freeze()` is called with no arguments, all parameters are frozen. This would have to be called before `GMMHMM.fit`, which computes the set difference between `'stmcw'` and `self._frozen` and passes it to `hmmlearn`'s `GMMHMM` initialization function as the `params` argument.
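The set difference described above can be computed in a few lines. This is a minimal sketch, assuming `self._frozen` holds either `None` (nothing frozen) or a string of flags. `trainable_params` is a hypothetical helper name:

```python
from typing import Optional

ALL_PARAMS = 'stmcw'


def trainable_params(frozen: Optional[str]) -> str:
    """Return the flags NOT in `frozen`, in canonical 'stmcw' order.

    Hypothetical helper: the result is what `fit` would pass to
    hmmlearn's GMMHMM as its `params` argument.
    """
    if frozen is None:
        return ALL_PARAMS  # nothing frozen: every parameter is updated
    return ''.join(p for p in ALL_PARAMS if p not in frozen)


print(trainable_params(None))           # stmcw
print(trainable_params('sw'))           # tmc
print(repr(trainable_params('stmcw')))  # '' (nothing is updated)
```

Inside `fit`, the result would then be forwarded as `params=trainable_params(self._frozen)` when constructing `hmmlearn`'s `GMMHMM`. Iterating over `ALL_PARAMS` rather than using Python's `set` difference keeps the output order deterministic.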
We should also add an `unfreeze` method that also accepts `params` and does the opposite of `freeze` (unfreezing everything if no argument is given), and make a property `frozen` for accessing `self._frozen`.
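A minimal sketch of the `freeze`/`unfreeze`/`frozen` bookkeeping follows. It is written as a hypothetical standalone class, `Freezable`, rather than the actual `GMMHMM` wrapper. The issue does not specify several details, so the sketch assumes them: repeated `freeze` calls accumulate rather than replace, unknown flags raise `ValueError`, and an empty frozen set is stored as `None`, matching the proposed initial value.

```python
from typing import Optional

ALL_PARAMS = 'stmcw'


class Freezable:
    """Hypothetical stand-in for the proposed GMMHMM freezing logic."""

    def __init__(self) -> None:
        # None means no parameters are frozen (the proposed initial value).
        self._frozen: Optional[str] = None

    @property
    def frozen(self) -> Optional[str]:
        """Read-only view of the frozen parameter flags."""
        return self._frozen

    @staticmethod
    def _check(params: str) -> None:
        # Reject anything outside 's', 't', 'm', 'c', 'w' (assumed behaviour).
        unknown = set(params) - set(ALL_PARAMS)
        if unknown:
            raise ValueError(f'unknown parameter flags: {sorted(unknown)}')

    def freeze(self, params: Optional[str] = None) -> None:
        """Freeze `params` (all parameters if omitted); calls accumulate."""
        params = ALL_PARAMS if params is None else params
        self._check(params)
        frozen = set(self._frozen or '') | set(params)
        # Store flags in canonical 'stmcw' order; an empty result becomes None.
        self._frozen = ''.join(p for p in ALL_PARAMS if p in frozen) or None

    def unfreeze(self, params: Optional[str] = None) -> None:
        """Unfreeze `params` (all parameters if omitted)."""
        params = ALL_PARAMS if params is None else params
        self._check(params)
        remaining = ''.join(p for p in (self._frozen or '') if p not in params)
        self._frozen = remaining or None


hmm = Freezable()
hmm.freeze('sw')
print(hmm.frozen)  # sw
hmm.unfreeze('s')
print(hmm.frozen)  # w
hmm.freeze()
print(hmm.frozen)  # stmcw
hmm.unfreeze()
print(hmm.frozen)  # None
```

Exposing `frozen` as a read-only property means callers can only change the frozen set through `freeze` and `unfreeze`, so the flags stay validated and canonically ordered.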
Examples:

Allow all parameters to update freely (as before):

```python
# Initialize an HMM
hmm = GMMHMM(label=1, n_states=3, n_components=10)
# Set the initial state distribution and transition matrix
hmm.set_random_initial()
hmm.set_random_transitions()
# Fit the HMM
hmm.fit(X)
```
Freeze the initial state distribution and mixture weights:

```python
# Initialize an HMM
hmm = GMMHMM(label=1, n_states=3, n_components=10)
# Set the initial state distribution and transition matrix
hmm.set_random_initial()
hmm.set_random_transitions()
# Prevent the initial state distribution and mixture weights from being updated
hmm.freeze('sw')
# Fit the HMM
hmm.fit(X)
```
Freeze all parameters:

```python
# Initialize an HMM
hmm = GMMHMM(label=1, n_states=3, n_components=10)
# Set the initial state distribution and transition matrix
hmm.set_random_initial()
hmm.set_random_transitions()
# Prevent all parameters from being updated
hmm.freeze()
# Fit the HMM
hmm.fit(X)
```