probml/dynamax

Fit CategoricalHMM with available data?


I have walked through the example usage in: https://probml.github.io/dynamax/notebooks/hmm/casino_hmm_learning.html

However, the params and props in the example are all generated by the initialize function. If I already have ready-to-use lists, one list of inputs (X in general ML) and one list of labels (y in general ML), how can I use the fit function?

I know this question may be naive, but I'm relatively new to Python. I would greatly appreciate it if someone could help.

Hi @AndyWeasley2004,

I'm not sure I have fully understood your question, but perhaps the following will help.

The demo starts with an example of setting the parameter values by hand:

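import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import CategoricalHMM

num_states = 2      # two types of dice (fair and loaded)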
num_emissions = 1   # only one die is rolled at a time
num_classes = 6     # each die has six faces

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05], 
                               [0.10, 0.90]])
emission_probs = jnp.array([[1/6,  1/6,  1/6,  1/6,  1/6,  1/6],    # fair die
                            [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]])  # loaded die


# Construct the HMM
hmm = CategoricalHMM(num_states, num_emissions, num_classes)

# Initialize the parameters struct with known values
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))

In this example the values of the parameters are determined by the values in the arrays initial_probs, transition_matrix, and emission_probs.

What the hmm.initialize method is doing in this example is taking the arrays we have defined and converting them into the appropriate parameter objects (an instance of ParamsCategoricalHMM).

If you want to use your own parameter values, you can convert them into JAX arrays (jnp.array(param_list)) and pass them to the initialize method as above.
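As a rough sketch of that conversion, the snippet below starts from plain Python lists and checks that they form valid probability distributions before they are handed to initialize. The list names and the check_distributions helper are made up for illustration and are not part of dynamax; NumPy is used here only for the preparation step.

```python
import numpy as np

# Hypothetical user-supplied values as plain Python lists
# (same numbers as the casino example above).
initial_probs_list = [0.5, 0.5]
transition_list = [[0.95, 0.05],
                   [0.10, 0.90]]
emission_list = [[1/6] * 6,             # fair die
                 [1/10] * 5 + [5/10]]   # loaded die

num_states, num_emissions, num_classes = 2, 1, 6

initial_probs = np.asarray(initial_probs_list, dtype=float)
transition_matrix = np.asarray(transition_list, dtype=float)
# Same reshape as in the example above: one emission dimension per time step.
emission_probs = np.asarray(emission_list, dtype=float).reshape(
    num_states, num_emissions, num_classes)


def check_distributions(arr, name):
    """Illustrative helper: every slice along the last axis must be a distribution."""
    if (arr < 0).any() or not np.allclose(arr.sum(axis=-1), 1.0):
        raise ValueError(f"{name} must be non-negative and sum to 1 along the last axis")


check_distributions(initial_probs, "initial_probs")
check_distributions(transition_matrix, "transition_matrix")
check_distributions(emission_probs, "emission_probs")

# The checked arrays can then be wrapped with jnp.array(...) and passed to
# hmm.initialize(initial_probs=..., transition_matrix=..., emission_probs=...)
# exactly as in the example above.
```

Checking the rows first tends to give a clearer error message than a failure deep inside the model code.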

The initialize method can also sample random values for the parameters if you pass it a key. The demo uses this approach later on, for example:

key = jr.PRNGKey(0)
em_params, em_param_props = hmm.initialize(key)
em_params, log_probs = hmm.fit_em(em_params, 
                                  em_param_props, 
                                  batch_emissions, 
                                  num_iters=400)

It sounds like your "label list" might correspond to the HMM emissions, in which case you can pass it to the fit function as the emissions argument (in the example above, the array batch_emissions is passed as the value for that argument).
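If that is the case, a minimal sketch of turning a list of observations into an array shaped like batch_emissions could look like the following. It assumes, as in the demo, that the emissions are integer class indices starting from 0 and that the array is laid out as (num_batches, num_timesteps, num_emissions). The rolls list is invented for illustration.

```python
import numpy as np

# Hypothetical observed die rolls, recorded as faces 1..6.
rolls = [6, 6, 1, 3, 6, 2, 5, 6, 4, 6]
num_classes = 6

# Shift to zero-based class indices (0..5).
emissions = np.asarray(rolls, dtype=np.int32) - 1
if emissions.min() < 0 or emissions.max() >= num_classes:
    raise ValueError("every observation must be a class index in [0, num_classes)")

# One die per time step: shape (num_timesteps, num_emissions).
emissions = emissions.reshape(-1, 1)

# A single sequence is treated as a batch of size one:
# shape (num_batches, num_timesteps, num_emissions).
batch_emissions = emissions[None, ...]

print(batch_emissions.shape)

# batch_emissions can then be converted with jnp.array(...) and passed as the
# emissions argument of hmm.fit_em, as in the example above.
```

If you have several sequences of equal length, stacking them along the first axis would give a larger batch in the same layout.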

I am not entirely sure what you mean by your "input lists", but if you can provide some more details I am happy to see if I can help further.