Old datasets: Tapback and Recall
New datasets: Recollect 2016 and 2017
Run train_hmm('recollect') to train the HMM on recollect.
Run train_hmm('tapback') to train the HMM on recall.
The following functions are called:
- setup_newDS(): loads the data for recollect
- setup_oldDS(): loads the data for tapback
- compute_emission_prob(): returns the emission matrix B
- baum_welch_cont(): finds the parameters, i.e. the transition matrix A and the initial probabilities Pi
- reestimate_A(): provides the MLE for A and Pi
- single_seq(): computes expected sufficient statistics for each training sequence
- alpha_beta_pass(): carries out the forward-backward procedure
- forward2(): computes the prediction
- compute_rmse(): computes the RMSE and the predictions
- control_fn(): computes the beta distribution used to compute emission probabilities
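The forward-backward procedure carried out by alpha_beta_pass() can be sketched as below. This is a minimal pure-Python illustration for a discrete-state HMM with per-step scaling, not the repository's code. The name forward_backward and the argument layout (A as a row-stochastic list of lists, B_obs[t][i] as the likelihood of the observation at time t in state i) are assumptions made for the example.

```python
import math


def forward_backward(A, B_obs, Pi):
    """Scaled forward-backward pass for a discrete-state HMM (illustrative).

    A[i][j]     : P(state j at t+1 | state i at t)
    B_obs[t][i] : likelihood of the observation at time t in state i
    Pi[i]       : P(state i at t = 0)
    Returns (gamma, log_likelihood), where gamma[t][i] is the posterior
    probability of state i at time t given the whole sequence.
    """
    T, N = len(B_obs), len(Pi)
    alpha = [[0.0] * N for _ in range(T)]
    beta = [[1.0] * N for _ in range(T)]
    scale = [0.0] * T

    # Forward pass: normalise alpha at every step to avoid underflow.
    for t in range(T):
        for j in range(N):
            if t == 0:
                prior = Pi[j]
            else:
                prior = sum(alpha[t - 1][i] * A[i][j] for i in range(N))
            alpha[t][j] = prior * B_obs[t][j]
        scale[t] = sum(alpha[t])
        alpha[t] = [a / scale[t] for a in alpha[t]]

    # Backward pass: reuse the scaling factors from the forward pass.
    for t in range(T - 2, -1, -1):
        for i in range(N):
            beta[t][i] = sum(
                A[i][j] * B_obs[t + 1][j] * beta[t + 1][j] for j in range(N)
            ) / scale[t + 1]

    # Posterior state marginals; with this scaling each row sums to 1.
    gamma = [[alpha[t][i] * beta[t][i] for i in range(N)] for t in range(T)]
    log_likelihood = sum(math.log(c) for c in scale)
    return gamma, log_likelihood
```

The gamma values are the kind of expected sufficient statistics that a Baum-Welch re-estimation step accumulates over training sequences.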
Run test_hmm(A,Pi,accu,nbacks,nTrials). Values returned are: Test_Predicted_Values, Test_Actual_Values, test_rmse, test_B
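For reference, an RMSE such as test_rmse is the square root of the mean squared difference between predicted and actual values. A minimal sketch follows. The function name rmse and the flat-list inputs are assumptions; the repository's compute_rmse() may organise the data differently.

```python
import math


def rmse(predicted, actual):
    """Root-mean-square error between two equal-length sequences."""
    if len(predicted) != len(actual) or len(predicted) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    squared = [(p - a) ** 2 for p, a in zip(predicted, actual)]
    return math.sqrt(sum(squared) / len(squared))


# One prediction is off by 2, the others are exact: sqrt(4 / 3).
example = rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
```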
Run explot3(A,Pi,test_B,test_n_backs_list,Test_Predicted_Values,Test_Actual_Values). The plots are saved into 'User_Skill_Trace.pdf'
explot2() generates plots of the mean and the variance of the error residuals across all subjects.
Run tuning_parameters() to do a grid search. Adjust the ranges for each hyper-parameter if needed.
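A grid search of this kind can be sketched as below. This is a generic illustration, not the code in tuning_parameters(). The hyper-parameter names n_bins and prior_strength and the toy objective are hypothetical stand-ins for "train the model, then compute a validation RMSE".

```python
import itertools


def grid_search(objective, grid):
    """Evaluate `objective` on every combination in `grid`; return the best.

    grid      : dict mapping a hyper-parameter name to a list of candidates
    objective : callable taking keyword arguments, returning a score to minimise
    """
    names = list(grid)
    best_params, best_score = None, float("inf")
    for values in itertools.product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        score = objective(**params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Toy objective (hypothetical): smallest at n_bins = 4, prior_strength = 0.5.
def toy_objective(n_bins, prior_strength):
    return (n_bins - 4) ** 2 + (prior_strength - 0.5) ** 2


best, score = grid_search(
    toy_objective,
    {"n_bins": [2, 3, 4, 5], "prior_strength": [0.1, 0.5, 1.0]},
)
```

Widening a range is then a matter of extending the corresponding list, at the cost of one extra training run per added combination.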
NOTE: Here, theta is 3D. The transition function is theta(X-X', X'-nb_prev, binned(prev_acc)).
Run train_hmm2('recollect') to train the HMM on recollect.
The following functions are called:
- computeA_B(): returns the emission matrix B and the list of transition matrices A for each subject
- baum_welch_cont2(): finds the parameters, i.e. the transition probabilities theta and the initial probabilities Pi
- Estep(): provides the MLE
- single_seq2(): computes expected sufficient statistics for each training sequence
- populateA(): populates the transition matrix A at each time-step from theta
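One way to read the 3D transition function theta(X-X', X'-nb_prev, binned(prev_acc)) is sketched below. Everything here is an assumption made for illustration: X' is taken to be the previous state and X the next state, states and nb_prev are expressed in the same integer index units, theta is stored as a dict of unnormalised weights, and the name populate_A and the max_jump cutoff are hypothetical. The repository's populateA() may differ in all of these.

```python
def populate_A(theta, n_states, nb_prev, acc_bin, max_jump):
    """Build one row-normalised transition matrix from a 3D theta table.

    theta[(dx, d_nb, acc_bin)] is an unnormalised weight for moving by
    dx = x_next - x_prev states, given d_nb = x_prev - nb_prev and the
    bin index of the previous accuracy. Missing keys count as weight 0.
    """
    A = [[0.0] * n_states for _ in range(n_states)]
    for x_prev in range(n_states):
        d_nb = x_prev - nb_prev
        for x_next in range(n_states):
            dx = x_next - x_prev
            if abs(dx) <= max_jump:
                A[x_prev][x_next] = theta.get((dx, d_nb, acc_bin), 0.0)
        total = sum(A[x_prev])
        if total > 0.0:
            A[x_prev] = [w / total for w in A[x_prev]]
        else:
            # No weight for this context: stay in place (arbitrary fallback).
            A[x_prev][x_prev] = 1.0
    return A


# Toy theta: staying put is twice as likely as moving one state up or down.
toy_theta = {
    (dx, d_nb, 0): w
    for dx, w in [(-1, 1.0), (0, 2.0), (1, 1.0)]
    for d_nb in range(-5, 6)
}
A_t = populate_A(toy_theta, n_states=4, nb_prev=2, acc_bin=0, max_jump=1)
```

Because nb_prev and the accuracy bin change from trial to trial, a matrix like this would be rebuilt at each time-step, which matches the description of populateA() above.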
Run write_file()
- Run get_cluster_centers(): the current number of clusters is 3, which can be changed with more data
- pre_train(): initializes the cluster centers
- baum_welch_cont_EM(): performs EM clustering
- reestimate_A_EM(): re-estimates the cluster centers
- assignCluster(): assigns a sequence to a cluster
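The assignment step of EM clustering can be sketched as below, assuming each cluster has its own model and that a log-likelihood of the sequence under each cluster is already available (for example from a forward pass). The function names and the uniform default prior are assumptions; whether assignCluster() uses soft responsibilities or only the hard argmax is not stated here.

```python
import math


def cluster_responsibilities(log_likelihoods, log_priors=None):
    """Posterior P(cluster | sequence) from per-cluster log-likelihoods."""
    k = len(log_likelihoods)
    if log_priors is None:
        log_priors = [-math.log(k)] * k  # uniform prior over clusters
    joint = [ll + lp for ll, lp in zip(log_likelihoods, log_priors)]
    # Log-sum-exp: subtract the maximum before exponentiating for stability.
    m = max(joint)
    log_norm = m + math.log(sum(math.exp(j - m) for j in joint))
    return [math.exp(j - log_norm) for j in joint]


def assign_cluster(log_likelihoods, log_priors=None):
    """Index of the most probable cluster (hard assignment)."""
    resp = cluster_responsibilities(log_likelihoods, log_priors)
    return max(range(len(resp)), key=resp.__getitem__)


# Three clusters; the second explains the sequence best.
resp = cluster_responsibilities([-1050.0, -1000.0, -1002.0])
label = assign_cluster([-1050.0, -1000.0, -1002.0])
```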
Run analysis_subj16() and analysis_subj17() to analyze the clusters formed in Recollect-2016 and Recollect-2017 respectively.
NOTE: The Tapback and Recall datasets test subjects that go up to an n-back of 9. The memory score is discretized in steps of 0.25, hence the state space has 36 states (9 / 0.25), with 3 intermediate states between each level.
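The arithmetic behind the 36 states can be checked with a short sketch. The mapping from state index to memory score used here (state k corresponds to a score of (k + 1) * 0.25, so scores run from 0.25 to 9.0) is an assumption, as is the helper name state_index.

```python
STEP = 0.25      # memory-score resolution from the note
MAX_N_BACK = 9   # highest n-back level in Tapback and Recall

n_states = int(MAX_N_BACK / STEP)

# Assumption: state k (0-based) corresponds to memory score (k + 1) * STEP.
scores = [(k + 1) * STEP for k in range(n_states)]


def state_index(score):
    """Index of the state whose memory score is closest to `score`."""
    k = round(score / STEP) - 1
    return min(max(k, 0), n_states - 1)


# States strictly between levels 1 and 2: the 3 intermediate states.
between_1_and_2 = [s for s in scores if 1.0 < s < 2.0]
```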
Run UKFTrain('recollect') to train the UKF model on recollect.
Run UKFTrain('tapback') to train the UKF model on recall.
The following functions are called:
- setup_newDS(): loads the data for recollect
- setup_oldDS(): loads the data for tapback
- learn_kalman(): finds the parameters A, B, Q, R, X0, V0 via EM
- Estep(): computes the expected sufficient statistics
- ukalman_filter(): computes filtered estimates
- rts_smoother(): computes smoothed estimates
- calc_sigmapoints(): computes the sigma points for the unscented transform
- g(): the non-linear function in the observation model
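For orientation, the standard scaled sigma-point construction for the unscented transform can be sketched as below. This is a textbook version in pure Python, not the code in calc_sigmapoints(). The function names and the default values of alpha, beta and kappa are illustrative assumptions.

```python
import math


def cholesky(M):
    """Lower-triangular L with L * L^T = M, for symmetric positive-definite M."""
    n = len(M)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = M[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def sigma_points(mean, cov, alpha=1.0, beta=2.0, kappa=1.0):
    """Return 2n + 1 sigma points with their mean and covariance weights."""
    n = len(mean)
    lam = alpha ** 2 * (n + kappa) - n
    # Columns of the matrix square root of (n + lam) * cov set the spread.
    L = cholesky([[(n + lam) * c for c in row] for row in cov])
    points = [list(mean)]
    for j in range(n):
        col = [L[i][j] for i in range(n)]
        points.append([m + c for m, c in zip(mean, col)])
        points.append([m - c for m, c in zip(mean, col)])
    w_mean = [lam / (n + lam)] + [0.5 / (n + lam)] * (2 * n)
    w_cov = list(w_mean)
    w_cov[0] += 1.0 - alpha ** 2 + beta
    return points, w_mean, w_cov
```

Passing each sigma point through a non-linear function such as g() and taking the weighted mean and covariance of the results gives the transformed estimate. With the identity function, the weighted statistics reproduce the input mean and covariance.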
Run compute_rmse(), which computes the RMSE and the predictions.
Run explot3(). The plots are saved into 'User_Skill_Trace.pdf'
NOTE: this doesn't converge. It can be tried with larger datasets.
Run UKF_EM() to cluster sequences. The current number of clusters is 2, which can be changed with more data. It calls learn_kalman_EM.