Units on Q matrix for CtHMM
cmaclell opened this issue · 6 comments
I'm trying to do a form of HMM clustering where I fit an HMM to each group of data points, then construct a big HMM that handles the cluster assignment for me. I've done this with regular HMMs in the past and am trying to port the idea over to work with your continuous time HMMs, but I'm a bit unclear on the units of the Q matrix that is passed into the CtHMM constructor.
Your IPython notebook says: # Q is the matrix of transition rates from state [row] to state [column].
Digging into your code, it looks like running scipy.linalg.expm(chmm.q) gives me something akin to transition probabilities, with each row summing to 1, and I can translate probabilities back into transition rates by doing the opposite, scipy.linalg.logm(Q_prob). Does this sound right to you?
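The relationship described here can be checked numerically. The following is a minimal sketch that uses only NumPy and SciPy rather than the hmms API. It assumes Q is a rate (generator) matrix whose rows sum to 0, so that expm(Q * t) holds the transition probabilities over an interval of length t and the rates are in units of 1/time. The helper name transition_probabilities is hypothetical.

```python
import numpy as np
from scipy.linalg import expm, logm

# Rate (generator) matrix: off-diagonal entries are rates per unit time,
# and each diagonal entry makes its row sum to zero.
Q = np.array([[-1.0, 1.0],
              [1.0, -1.0]])


def transition_probabilities(Q, t=1.0):
    """Transition probabilities over an interval of length t: expm(Q * t)."""
    return expm(Q * t)


P = transition_probabilities(Q)   # probabilities for one time unit
Q_back = np.real(logm(P))         # inverse direction: probabilities -> rates

print(np.round(P, 6))
print(P.sum(axis=1))              # each row sums to 1
print(np.round(Q_back, 6))
```

Going in the logm direction only yields a valid rate matrix when the probability matrix really is the exponential of some rate matrix; an arbitrary stochastic matrix can produce negative off-diagonal or complex entries.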
What I'm trying to do is construct an HMM with groups of states that are connected within each group but not between groups, plus a start probability that specifies the probability that a sequence is in any one group. So my parameters might look something like this:
Initial probabilities (π) :
0
0 0.633333
1 0.000000
2 0.000000
3 0.066667
4 0.300000
5 0.000000
Transition rate matrix (Q):
0 1 2 3 4 5
0 -0.391591 0.313380 0.078211 0.000000 0.000000 0.000000
1 0.949768 -0.960014 0.010246 0.000000 0.000000 0.000000
2 5.652769 21.046363 -26.699132 0.000000 0.000000 0.000000
3 0.000000 0.000000 0.000000 -0.282953 0.006097 0.276856
4 0.000000 0.000000 0.000000 13.135779 -14.030817 0.895038
5 0.000000 0.000000 0.000000 0.661941 0.003459 -0.665400
Transition probabilities for one time unit :
0 1 2 3 4 5
0 0.790908 0.206685 0.002406 0.000000 0.000000 0.000000
1 0.526236 0.472067 0.001697 0.000000 0.000000 0.000000
2 0.575279 0.422892 0.001829 0.000000 0.000000 0.000000
3 0.000000 0.000000 0.000000 0.820134 0.000402 0.179464
4 0.000000 0.000000 0.000000 0.801739 0.000399 0.197862
5 0.000000 0.000000 0.000000 0.430510 0.000324 0.569167
Emission probabilities matrix (B):
0 1 2
0 0.359413 0.000000 0.640587
1 0.000000 1.000000 0.000000
2 0.333333 0.333333 0.333333
3 0.000000 0.583333 0.416667
4 0.000000 1.000000 0.000000
5 1.000000 0.000000 0.000000
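The block structure in the parameters above (two groups of three states with zero rates between them) could be assembled from per-cluster models roughly as follows. This is a sketch under stated assumptions: the function combine_clusters, the cluster parameters, and the weights are all hypothetical, and the code only builds NumPy arrays without calling hmms.

```python
import numpy as np


def combine_clusters(qs, bs, pis, weights):
    """Stack per-cluster parameters into one block-diagonal model.

    qs, bs, pis: per-cluster rate matrices, emission matrices and
                 initial distributions (hypothetical inputs)
    weights:     probability that a sequence belongs to each cluster
    """
    n = sum(q.shape[0] for q in qs)
    Q = np.zeros((n, n))
    start = 0
    for q in qs:
        k = q.shape[0]
        Q[start:start + k, start:start + k] = q  # no rates between clusters
        start += k
    B = np.vstack(bs)  # all clusters must share the same emission alphabet
    pi = np.concatenate([w * p for w, p in zip(weights, pis)])
    return Q, B, pi


# Two made-up 2-state clusters over a 2-symbol alphabet.
q_a = np.array([[-0.5, 0.5], [2.0, -2.0]])
q_b = np.array([[-1.0, 1.0], [0.3, -0.3]])
b_a = np.array([[0.9, 0.1], [0.2, 0.8]])
b_b = np.array([[0.5, 0.5], [0.1, 0.9]])
pi_a = np.array([1.0, 0.0])
pi_b = np.array([0.25, 0.75])

Q, B, pi = combine_clusters([q_a, q_b], [b_a, b_b], [pi_a, pi_b],
                            weights=[0.7, 0.3])
print(Q)
print(pi)
```

The resulting arrays have the same layout as the Q, B and π shown above and could then be handed to the CtHMM constructor.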
I ask about the units of Q because when I try and fit this CtHMM to the data I get the following error:
File "hmms/cthmm.pyx", line 907, in hmms.cthmm.CtHMM.baum_welch
File "hmms/cthmm.pyx", line 1060, in hmms.cthmm.CtHMM._baum_welch
ValueError: Parameter error! Matrix Q can't contain unreachable states.
I'm not sure if it is because I specified the units wrong, or if maybe your formulation doesn't like the somewhat strange Q matrix I provided.
I'm curious to hear your thoughts and, if possible, to get a brief description of the units for Q.
Thanks!
Chris
Also, let me say, I love the library! I was excited to see a continuous time HMM library on pypi.
So just following up a bit, it looks like there was a runtime warning much further up that I missed and that seems to be the problem: RuntimeWarning: invalid value encountered in true_divide
It looks like the tau parameter is full of NaNs, which causes the generation of an updated Q matrix in Baum-Welch to fail, so Q just keeps its initialized values (all zeros), hence the error.
I'm trying to debug this now, but I don't work in Cython much and I can't seem to get a stack trace for the errors very easily.
So here is a simple example that reproduces the error:
>>> Q = np.array([[1, 0], [0, 1]])
>>> Q2 = scipy.linalg.logm(Q)
>>> B = np.array([[1, 0], [0, 1]])
>>> P = np.array([0.5, 0.5])
>>> chmm = hmms.CtHMM(Q2, B, P)
>>> hmms.print_parameters(chmm)
Initial probabilities (π) :
0
0 0.5
1 0.5
Transition rate matrix (Q):
0 1
0 0.0 0.0
1 0.0 0.0
Transition probabilities for one time unit :
0 1
0 1.0 0.0
1 0.0 1.0
Emission probabilities matrix (B):
0 1
0 1.0 0.0
1 0.0 1.0
>>> seqs = np.array([[0, 0, 0, 0], [1, 1, 1, 1]])
>>> times = np.array([[1, 3, 4, 8], [2, 6, 8, 12]])
>>> chmm.baum_welch(times, seqs)
iteration 1 / 10
__main__:1: RuntimeWarning: invalid value encountered in true_divide
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "hmms/cthmm.pyx", line 907, in hmms.cthmm.CtHMM.baum_welch
return self._baum_welch( t_seqs, e_seqs, kvargs['est'], kvargs['fast'], iterations, method )
File "hmms/cthmm.pyx", line 1060, in hmms.cthmm.CtHMM._baum_welch
raise ValueError("Parameter error! Matrix Q can't contain unreachable states.")
ValueError: Parameter error! Matrix Q can't contain unreachable states.
Here is the result of a bit more searching. It would appear strange things happen when 0's (or near 0's) appear in the Q matrix. If I randomly initialize, but then provide the above seqs and times and run it long enough, then it converges to the above solution and fails with the same error.
Here is an example. I modified cthmm.pyx to print self._q at the beginning of each iteration:
>>> chmm = hmms.CtHMM.random(2, 2)
>>> hmms.print_parameters(chmm)
Initial probabilities (π) :
0
0 0.399793
1 0.600207
Transition rate matrix (Q):
0 1
0 -1.0 1.0
1 1.0 -1.0
Transition probabilities for one time unit :
0 1
0 0.567668 0.432332
1 0.432332 0.567668
Emission probabilities matrix (B):
0 1
0 0.184445 0.815555
1 0.608621 0.391379
>>> seqs = np.array([[0, 0, 0, 0], [1, 1, 1, 1]])
>>> times = np.array([[1, 3, 4, 8], [2, 6, 8, 12]])
>>> chmm.baum_welch(times, seqs, 100)
iteration 1 / 100
[[-1. 1.]
[ 1. -1.]]
iteration 2 / 100
[[-1.02341455 1.02341455]
[ 0.96538453 -0.96538453]]
iteration 3 / 100
[[-1.02321609 1.02321609]
[ 0.95381649 -0.95381649]]
iteration 4 / 100
[[-1.01979822 1.01979822]
[ 0.94472226 -0.94472226]]
iteration 5 / 100
[[-1.0151953 1.0151953 ]
[ 0.93600009 -0.93600009]]
iteration 6 / 100
[[-1.00958299 1.00958299]
[ 0.92737133 -0.92737133]]
iteration 7 / 100
[[-1.00294069 1.00294069]
[ 0.91874687 -0.91874687]]
iteration 8 / 100
[[-0.99522704 0.99522704]
[ 0.91005417 -0.91005417]]
iteration 9 / 100
[[-0.98639911 0.98639911]
[ 0.90121709 -0.90121709]]
iteration 10 / 100
[[-0.97641369 0.97641369]
[ 0.89215316 -0.89215316]]
iteration 11 / 100
[[-0.96522586 0.96522586]
[ 0.88277334 -0.88277334]]
iteration 12 / 100
[[-0.95278692 0.95278692]
[ 0.87298241 -0.87298241]]
iteration 13 / 100
[[-0.939042 0.939042 ]
[ 0.86267944 -0.86267944]]
iteration 14 / 100
[[-0.92392767 0.92392767]
[ 0.85175822 -0.85175822]]
iteration 15 / 100
[[-0.90736967 0.90736967]
[ 0.84010718 -0.84010718]]
iteration 16 / 100
[[-0.88928103 0.88928103]
[ 0.82760883 -0.82760883]]
iteration 17 / 100
[[-0.8695604 0.8695604 ]
[ 0.81413832 -0.81413832]]
iteration 18 / 100
[[-0.84809083 0.84809083]
[ 0.79956103 -0.79956103]]
iteration 19 / 100
[[-0.82473872 0.82473872]
[ 0.78372914 -0.78372914]]
iteration 20 / 100
[[-0.79935295 0.79935295]
[ 0.76647666 -0.76647666]]
iteration 21 / 100
[[-0.77176394 0.77176394]
[ 0.74761291 -0.74761291]]
iteration 22 / 100
[[-0.7417823 0.7417823 ]
[ 0.72691388 -0.72691388]]
iteration 23 / 100
[[-0.70919694 0.70919694]
[ 0.70411067 -0.70411067]]
iteration 24 / 100
[[-0.67377187 0.67377187]
[ 0.67887422 -0.67887422]]
iteration 25 / 100
[[-0.63524134 0.63524134]
[ 0.65079457 -0.65079457]]
iteration 26 / 100
[[-0.59330238 0.59330238]
[ 0.6193528 -0.6193528 ]]
iteration 27 / 100
[[-0.54760375 0.54760375]
[ 0.58388239 -0.58388239]]
iteration 28 / 100
[[-0.49773124 0.49773124]
[ 0.54351673 -0.54351673]]
iteration 29 / 100
[[-0.44319131 0.44319131]
[ 0.49711996 -0.49711996]]
iteration 30 / 100
[[-0.38340577 0.38340577]
[ 0.44320793 -0.44320793]]
iteration 31 / 100
[[-0.31776517 0.31776517]
[ 0.37990856 -0.37990856]]
iteration 32 / 100
[[-0.24590917 0.24590917]
[ 0.30518283 -0.30518283]]
iteration 33 / 100
[[-0.16878943 0.16878943]
[ 0.21814379 -0.21814379]]
iteration 34 / 100
[[-0.09193163 0.09193163]
[ 0.12388825 -0.12388825]]
iteration 35 / 100
[[-0.03097504 0.03097504]
[ 0.04330831 -0.04330831]]
iteration 36 / 100
[[-0.00379829 0.00379829]
[ 0.00541473 -0.00541473]]
iteration 37 / 100
[[-5.86637267e-05 5.86637267e-05]
[ 8.38605549e-05 -8.38605549e-05]]
iteration 38 / 100
[[-1.40328819e-08 1.40328819e-08]
[ 2.00553149e-08 -2.00553149e-08]]
iteration 39 / 100
[[-8.02500517e-16 8.02500517e-16]
[ 1.14670585e-15 -1.14670585e-15]]
iteration 40 / 100
[[-2.62346389e-30 2.62346389e-30]
[ 3.74833250e-30 -3.74833250e-30]]
iteration 41 / 100
[[-2.80308943e-59 2.80308943e-59]
[ 4.00474314e-59 -4.00474314e-59]]
iteration 42 / 100
[[-3.19965658e-117 3.19965658e-117]
[ 4.57115913e-117 -4.57115913e-117]]
iteration 43 / 100
[[-4.16871739e-233 4.16871739e-233]
[ 5.95548040e-233 -5.95548040e-233]]
Traceback (most recent call last):
File "test_hmms.py", line 14, in <module>
chmm.baum_welch(times, seqs, 100)
File "hmms/cthmm.pyx", line 909, in hmms.cthmm.CtHMM.baum_welch
File "hmms/cthmm.pyx", line 1063, in hmms.cthmm.CtHMM._baum_welch
ValueError: Parameter error! Matrix Q can't contain unreachable states.
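One way to read the printed iterations above: from about iteration 38 onward the exponent of each rate roughly doubles per step (1e-8, 1e-16, 1e-30, 1e-59, 1e-117, 1e-233), so the next value would fall below what float64 can represent. The sketch below only illustrates that arithmetic. The quadratic shrinkage is an assumption read off the printed numbers, not something taken from the library's code.

```python
import sys

# Last off-diagonal rate printed above (iteration 43).
rate = 4.16871739e-233

# Assumption for illustration: the next update shrinks the rate roughly
# quadratically, as the exponents in the printed iterations suggest.
next_rate = rate * rate

print(sys.float_info.min)   # smallest positive normal float64
print(next_rate)            # the product underflows to zero
print(next_rate == 0.0)
```

Once every rate is exactly 0.0, Q is an all-zero matrix, which is the situation the "unreachable states" check rejects.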
So digging into cthmm.pyx a bit, I think what is going on is essentially an underflow error.
You precompute the _pt values, which converts them into probabilities. Unfortunately, this is a problem because the probabilities are very, very small (it might be better to keep them as log probabilities, if possible).
Then you have code that converts them back into log probabilities, e.g.,
tau[i] += numpy.exp( self.log_sum( (ksi_sum[ix] + numpy.log( tA ) ).flatten() ) ) #tau is not in log prob anymore.
Notice the numpy.log(tA). Just above that line, tA is computed with tA /= self._pt[ ix ], where very, very small values of _pt will produce NaN values, which I think is causing the problems.
So I found the error, and it wasn't actually due to using probabilities instead of log probabilities. Instead, it was due to the way you were calculating tA. In particular, you have multiple lines of the form tA /= self._pt[ ix ] (e.g., line 817). In cases where tA is 0, self._pt[ ix ] will also be 0, so the division is 0/0 and you get NaN values. The next line,
tau[i] += numpy.exp( self.log_sum( (ksi_sum[ix] + numpy.log( tA ) ).flatten() ) )
does not handle the NaNs properly.
To fix this, I only perform the division where tA is non-zero, so I replace tA /= self._pt[ ix ] with
tA = numpy.divide(tA, self._pt[ ix ], out=numpy.zeros_like(tA), where=tA!=0)
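The effect of the guarded division can be seen in isolation. This is a small sketch with made-up arrays standing in for tA and self._pt[ ix ]; the helper name safe_divide is hypothetical and is not part of the library.

```python
import numpy as np


def safe_divide(numer, denom):
    """Divide elementwise, leaving 0 wherever the numerator is 0."""
    numer = np.asarray(numer, dtype=float)
    denom = np.asarray(denom, dtype=float)
    return np.divide(numer, denom, out=np.zeros_like(numer), where=numer != 0)


# Made-up stand-ins: zeros appear in the same positions in both arrays.
tA = np.array([[0.6, 0.0], [0.0, 0.3]])
pt = np.array([[0.8, 0.0], [0.0, 0.5]])

with np.errstate(divide="ignore", invalid="ignore"):
    naive = tA / pt            # 0/0 entries become NaN

guarded = safe_divide(tA, pt)  # 0/0 entries stay 0

print(naive)
print(guarded)
```

Note that the mask tests the numerator, so an entry with a non-zero numerator and a zero denominator would still produce inf; the fix relies on tA and _pt being zero in the same places.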
This fixes my original example. However, if we have a case where there are two states and no transitions between them (only initial probabilities that specify the likelihood of each state), we sometimes end up with Q matrices that are all zeros.
You have code that throws an error when this happens (but it is a valid case, especially now that I've fixed the above error):
if sum( self._q.flatten() ) == 0:
    raise ValueError("Parameter error! Matrix Q can't contain unreachable states.")
If we remove this code, then my second and third examples work.
After making these changes (and fixing one other line so that it prints the correct baum_welch iteration), the code passes all your tests as well as my examples above. I just sent you a pull request with these changes (#8).