davidhallac/TICC

Implement `predict()` method?


Hello, great work on your paper.

Results of the algorithm look promising. Unfortunately, it seems I'll have to call the solve() method on the entire training dataset every time I want to cluster new data, which is not efficient.

I have a few questions:

  1. Is it possible to implement a predict method on an already trained model?
  2. Could we add load_model and save_model methods?
  3. Will you be following scikit-learn interfaces in the future (__init__, fit, predict, etc.)?

I could definitely contribute, but after two hours of attempting to refactor the code, I found many interdependent variables and very long methods. How can I help?

Hi Mohamed,

These are all great suggestions! While we don’t have immediate plans to add the proposed extensions to TICC, I agree that they could be very useful in certain scenarios. Specifically:

  1. One easy way to do so is to allow the user to pass the value of cluster_MRFs as input (this is the second value in the tuple that TICC.solve returns). Then, one could enforce maxIters = 0 and initialize the clusters to have those MRFs, running the dynamic programming step of our algorithm to predict the point assignments.

  2. All you need for LoadModel and SaveModel are the cluster_assignment and cluster_MRFs variables that TICC.solve returns. For LoadModel, one would just need to initialize the clusters and point assignments with that data (instead of using our random initialization scheme).

  3. We do not have any plans to do so at this time, since we expect most use cases of TICC to be with training on a given dataset (rather than using a pre-saved model on a new dataset). However, if there is enough interest in this extension, we could look into converting our solver to this API in the future.
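A minimal sketch of point 2, assuming `cluster_assignment` is a sequence of integer labels and `cluster_MRFs` is a dict mapping each cluster index to a square matrix (nested lists or a 2-D NumPy array). The `save_model`/`load_model` names and the JSON file layout are hypothetical, not part of TICC:

```python
import json
from typing import Dict, List, Sequence, Tuple

Matrix = List[List[float]]


def save_model(path: str, cluster_assignment: Sequence[int], cluster_MRFs: Dict[int, Sequence]) -> None:
    """Persist the two outputs of TICC.solve that are needed to reload a model (hypothetical helper)."""
    payload = {
        # int()/float() also accept NumPy scalars, so NumPy inputs serialize cleanly.
        "cluster_assignment": [int(c) for c in cluster_assignment],
        # JSON object keys must be strings, so cluster ids are stringified here.
        "cluster_MRFs": {
            str(k): [[float(v) for v in row] for row in mat] for k, mat in cluster_MRFs.items()
        },
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def load_model(path: str) -> Tuple[List[int], Dict[int, Matrix]]:
    """Read back (cluster_assignment, cluster_MRFs); matrices come back as nested lists."""
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    mrfs = {int(k): mat for k, mat in payload["cluster_MRFs"].items()}
    return payload["cluster_assignment"], mrfs
```

JSON keeps the file readable and dependency-free; `pickle` or `numpy.savez` would also work if portability matters less. Since loading returns plain nested lists, wrap each matrix in `np.array` before handing it back to the solver.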

It would be great if you contributed to the code base! I'm more than happy to answer any specific questions about the code that you may have, so feel free to reach out if there's anything we can help with!

Thanks David for your answer. Your algorithm is very promising; the code needs a tiny push to get up to speed.

I have put some effort into restructuring TICC_Solver to make it easier to read and debug. Please see the changes in #21. If you accept them, it will be easier to implement the points we discussed earlier: adding predict, loadModel, and saveModel, and making the code more compatible with scikit-learn.

The merge should be straightforward, and the code is up to date with your latest changes.

I'm guessing now is the time to work on adding a 'predict' method. I believe you explained the logic in one of your comments above. Since we have split the fit method into several smaller ones, it would be great if you could point out which of them to use.

David, since TICC is now organized as a class, can you point out how the steps you mentioned in one of your comments map onto the TICC class for implementing "predict"?

"...
One easy way to do so is to allow the user to pass the value of cluster_MRFs as input (this is the second value in the tuple that TICC.solve returns). Then, one could enforce maxIters = 0 and initialize the clusters to have those MRFs, running the dynamic programming step of our algorithm to predict the point assignments.
..."

Thanks

Sure, I'd be happy to point out the steps in more detail! Essentially, to implement predict(), all you need to do is run lines 131-139 of TICC_solver.py on the new dataset.

It's a bit more complicated than that since you'll need to move the parameters from inside fit() to be global variables, but once you have the cluster parameters loaded, you just need to run two steps: 1) smoothen_clusters() to get the LLE of each point belonging to each cluster (line 131), and 2) updateClusters() to assign each point to a given cluster (line 139). The variable that is returned (clustered_points) will be a list showing the cluster that has been assigned to each point.
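The two steps can be illustrated with a minimal, dependency-free sketch. It assumes each trained cluster is summarized by a mean vector and a precision matrix Θ (the inverse covariance, i.e. that cluster's MRF) and that each input point has already been stacked into a window the way TICC does. `point_cost`, `assign_points`, and `predict` are hypothetical names, not TICC functions. The per-point cost (x - μ)ᵀΘ(x - μ) - log det Θ equals (x - μ)ᵀΣ⁻¹(x - μ) + log det Σ with Σ = Θ⁻¹:

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def log_det_spd(a: Matrix) -> float:
    """Log-determinant of a symmetric positive-definite matrix via Cholesky (a = L L^T)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return 2.0 * sum(math.log(L[i][i]) for i in range(n))


def point_cost(x: Sequence[float], mean: Sequence[float], precision: Matrix) -> float:
    """Twice the Gaussian negative log-likelihood of x, up to an additive constant."""
    d = [xi - mi for xi, mi in zip(x, mean)]
    quad = sum(d[i] * precision[i][j] * d[j] for i in range(len(d)) for j in range(len(d)))
    return quad - log_det_spd(precision)


def assign_points(costs: Sequence[Sequence[float]], switch_penalty: float) -> List[int]:
    """Viterbi-style DP: minimize total cost plus switch_penalty per cluster change."""
    if not costs:
        return []
    n_clusters = len(costs[0])
    total = list(costs[0])
    back: List[List[int]] = []
    for t in range(1, len(costs)):
        best = min(range(n_clusters), key=total.__getitem__)
        ptr, new_total = [], []
        for k in range(n_clusters):
            # Either stay in cluster k or switch from the cheapest cluster and pay the penalty.
            prev = k if total[k] <= total[best] + switch_penalty else best
            ptr.append(prev)
            step = 0.0 if prev == k else switch_penalty
            new_total.append(total[prev] + step + costs[t][k])
        total = new_total
        back.append(ptr)
    # Backtrack from the cheapest final cluster.
    k = min(range(n_clusters), key=total.__getitem__)
    path = [k]
    for ptr in reversed(back):
        k = ptr[k]
        path.append(k)
    path.reverse()
    return path


def predict(points, means, precisions, switch_penalty: float) -> List[int]:
    # Step 1 (cf. smoothen_clusters): cost of every point under every cluster.
    costs = [[point_cost(x, m, p) for m, p in zip(means, precisions)] for x in points]
    # Step 2 (cf. updateClusters): penalized assignment over the whole sequence.
    return assign_points(costs, switch_penalty)
```

A larger `switch_penalty` (β in the paper) makes assignments stickier, trading per-point fit for temporal consistency; with a very large value every point in a short series ends up in a single cluster.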

Thanks David, it works well in batch mode, but not in online mode, where the number of rows is small. Do you have any idea how to deal with that? Any suggestion would be appreciated.

Hmm, that’s a good question, and I don’t know if there’s a “correct” answer on how best to do it. My suggestion would be similar to the “Streaming Algorithm” paragraph in Section 3 in this paper: http://stanford.edu/~hallac/TVGL.pdf.

Essentially, whenever a new data point comes in, you go back a fixed length in the past, and you re-apply the TICC cluster_assignment step (lines 131-139 of TICC_Solver.py) to that most recent subset of points, holding the cluster assignments of all points before that as constant.
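A minimal sketch of that windowed re-assignment, assuming per-point costs (for example, negative log-likelihoods under each trained cluster) have already been computed for every point seen so far. `reassign_recent` and `lookback` are hypothetical names; the only link to the frozen history here is a switch penalty charged if the first re-assigned point leaves the last frozen cluster:

```python
from typing import List, Optional, Sequence


def reassign_recent(
    costs: Sequence[Sequence[float]],
    assignments: List[int],
    lookback: int,
    switch_penalty: float,
) -> List[int]:
    """Re-run the switching-penalty DP over only the last `lookback` points.

    costs[t][k] is the cost of putting point t in cluster k (lower is better).
    assignments holds labels for earlier points (it may be shorter than costs);
    labels before the window are frozen and returned unchanged.
    """
    n_points, n_clusters = len(costs), len(costs[0])
    start = max(0, n_points - max(1, lookback))
    # The last frozen label acts as a boundary condition for the window.
    anchor: Optional[int] = assignments[start - 1] if start > 0 else None
    total = [
        costs[start][k] + (switch_penalty if anchor is not None and k != anchor else 0.0)
        for k in range(n_clusters)
    ]
    back: List[List[int]] = []
    for t in range(start + 1, n_points):
        best = min(range(n_clusters), key=total.__getitem__)
        ptr, new_total = [], []
        for k in range(n_clusters):
            # Stay in k, or switch from the cheapest cluster and pay the penalty.
            prev = k if total[k] <= total[best] + switch_penalty else best
            ptr.append(prev)
            step = 0.0 if prev == k else switch_penalty
            new_total.append(total[prev] + step + costs[t][k])
        total = new_total
        back.append(ptr)
    # Backtrack from the cheapest final cluster.
    k = min(range(n_clusters), key=total.__getitem__)
    window = [k]
    for ptr in reversed(back):
        k = ptr[k]
        window.append(k)
    window.reverse()
    return list(assignments[:start]) + window
```

Call it once per incoming point, with `costs` extended by that point's cost row. `lookback` trades latency against how far back a regime change can still revise earlier labels.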

Then, if you run the streaming for a long time, you could occasionally update the cluster parameters (using the ADMM part of the TICC solver), as long as you keep track of the sample mean/covariance along the way as you add additional points to the clusters.
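For the bookkeeping in that last step, one standard option (our choice here, not necessarily what TICC does internally) is Welford's online update, which keeps a cluster's sample mean and covariance current in O(d²) per added point without storing the points. `RunningMeanCov` is a hypothetical helper:

```python
from typing import List, Sequence


class RunningMeanCov:
    """Incrementally tracks one cluster's sample mean and covariance (Welford's algorithm)."""

    def __init__(self, dim: int) -> None:
        self.n = 0
        self.mean: List[float] = [0.0] * dim
        # Running sum of outer products of deviations from the mean.
        self._m2: List[List[float]] = [[0.0] * dim for _ in range(dim)]

    def add(self, x: Sequence[float]) -> None:
        self.n += 1
        delta = [xi - mi for xi, mi in zip(x, self.mean)]  # deviation from the old mean
        self.mean = [mi + d / self.n for mi, d in zip(self.mean, delta)]
        delta_new = [xi - mi for xi, mi in zip(x, self.mean)]  # deviation from the new mean
        for i in range(len(delta)):
            for j in range(len(delta)):
                self._m2[i][j] += delta[i] * delta_new[j]

    def covariance(self) -> List[List[float]]:
        """Unbiased sample covariance (n - 1 denominator); needs at least two points."""
        if self.n < 2:
            raise ValueError("need at least two points")
        return [[v / (self.n - 1) for v in row] for row in self._m2]
```

Periodically passing the tracked covariance of a cluster to the ADMM step would then refresh that cluster's MRF, as described above. Divide by `n` instead of `n - 1` if the solver expects the biased estimator.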

Hi! The newest version of TICC now supports "predict" functionality through the predict_clusters method.

Thanks again for the suggestion!