facebookresearch/StarSpace

A nice way to obtain predictions?

SkBlaz opened this issue · 1 comment

Hello! First of all, thank you for open sourcing StarSpace, it is a great project.

I was wondering whether there is an elegant way to obtain row-level (per-instance) predictions during testing. For example, if I train the model using:

./starspace train \
  -trainFile "${DATADIR}"/ag_news.train \
  -model "${MODELDIR}"/modelRandom \
  -initRandSd 0.01 \
  -adagrad false \
  -ngrams 1 \
  -lr 0.01 \
  -epoch 5 \
  -thread 20 \
  -dim 10 \
  -negSearchLimit 5 \
  -trainMode 0 \
  -label "__label__" \
  -similarity "dot" \
  -verbose true

I obtain a trained model called modelRandom. Now, when new instances arrive, I would like to compute not only the HITS score but also other metrics, e.g., F1 or AUC (for binary classification).
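For reference, once per-instance predictions are available, F1 can be computed outside StarSpace. The following is a minimal sketch in plain Python, assuming one gold label and one top-1 predicted label per instance; the function name `f1_scores` and the label names in the example are made up for illustration and are not part of StarSpace.

```python
from collections import Counter


def f1_scores(gold, predicted):
    """Return (per_label_f1, macro_f1) for two equal-length lists of labels."""
    if len(gold) != len(predicted):
        raise ValueError("gold and predicted must have the same length")

    tp, fp, fn = Counter(), Counter(), Counter()
    for g, p in zip(gold, predicted):
        if g == p:
            tp[g] += 1
        else:
            fp[p] += 1  # p was predicted, but the instance belongs to g
            fn[g] += 1  # g was the true label and was missed

    per_label = {}
    for label in sorted(set(gold) | set(predicted)):
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0.0 when the denominator is 0
        denom = 2 * tp[label] + fp[label] + fn[label]
        per_label[label] = 2 * tp[label] / denom if denom else 0.0

    macro = sum(per_label.values()) / len(per_label) if per_label else 0.0
    return per_label, macro


# Hypothetical labels, purely for illustration.
gold = ["__label__a", "__label__b", "__label__a", "__label__c"]
pred = ["__label__a", "__label__a", "__label__a", "__label__c"]
per_label, macro = f1_scores(gold, pred)
print(per_label, macro)
```

Macro averaging treats every label equally; for a heavily imbalanced label set, a micro-averaged or per-label report may be more informative.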

To my understanding, running:

./starspace test \
  -model "${MODELDIR}"/modelRandom \
  -testFile "${DATADIR}"/ag_news.test \
  -ngrams 1 \
  -dim 10 \
  -label "__label__" \
  -thread 10 \
  -similarity "dot" \
  -trainMode 0 \
  -verbose true

already requires both the instances and their labels. It outputs HITS scores, but I am wondering whether it is possible to obtain, e.g., the top k labels for a given instance.

I also found the following tool:

make query_predict
./query_predict <model> k [basedocs]

which does exactly what I want, except that it reads instances from interactive user input rather than from a file. Am I simply missing something?

So, to summarize: given an instance consisting of, e.g.,

w_1 w_2 w_53 ...

I would like to use a trained model to obtain a vector of predicted labels, [__label__something ... ].
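When the test file carries its labels inline, the instance text and the gold labels need to be separated before the text can be fed to a predictor on its own. A minimal sketch, assuming whitespace-separated tokens and the `__label__` prefix used in the commands above; the helper name `split_example` is hypothetical:

```python
def split_example(line, label_prefix="__label__"):
    """Split one labeled line into (text_without_labels, list_of_labels)."""
    tokens = line.split()
    labels = [t for t in tokens if t.startswith(label_prefix)]
    words = [t for t in tokens if not t.startswith(label_prefix)]
    return " ".join(words), labels


text, labels = split_example("w_1 w_2 w_53 __label__something")
print(text)    # the instance without its label tokens
print(labels)  # the gold labels found on the line
```

Applying this to every line of a labeled file yields an unlabeled input file for prediction plus a parallel list of gold labels for scoring.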

Thank you!

Currently solved by redirecting the input file to standard input:

./query_predict tmp/storedModel 1 < ./tmp/test_data.txt
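The same redirection can be driven from a script, which makes it easier to capture the output for later scoring. This is only a sketch: the helper `run_with_file_as_stdin` is hypothetical, it works for any command that reads standard input, and it returns the raw text without parsing it, since the exact output format of query_predict is not shown in this thread.

```python
import subprocess


def run_with_file_as_stdin(command, input_path):
    """Run `command` with the contents of `input_path` on stdin; return stdout as text."""
    with open(input_path, "rb") as handle:
        result = subprocess.run(
            command,
            stdin=handle,            # equivalent to `< input_path` in the shell
            stdout=subprocess.PIPE,  # capture whatever the tool prints
            check=True,              # raise CalledProcessError on a non-zero exit code
        )
    return result.stdout.decode("utf-8", errors="replace")


# Example call mirroring the shell command above (paths as used in this thread):
# output = run_with_file_as_stdin(
#     ["./query_predict", "tmp/storedModel", "1"], "./tmp/test_data.txt"
# )
```

The captured text still has to be inspected and parsed by hand before it can be turned into a list of predicted labels.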