Duplicate Results When Predicting for Classification
ZhaoyangChen opened this issue · 2 comments
Following the demo script, I trained and evaluated a model with:
./starspace train \
-trainFile "${DATADIR}"/ag_news.train \
-model "${MODELDIR}"/ag_news \
-initRandSd 0.01 \
-adagrad false \
-ngrams 1 \
-lr 0.01 \
-epoch 5 \
-thread 20 \
-dim 10 \
-negSearchLimit 5 \
-trainMode 0 \
-label "__label__" \
-similarity "dot" \
-verbose true
echo "Start to evaluate trained model:"
./starspace test \
-model "${MODELDIR}"/ag_news \
-testFile "${DATADIR}"/ag_news.test \
-ngrams 1 \
-dim 10 \
-label "__label__" \
-thread 10 \
-similarity "dot" \
-trainMode 0 \
-verbose true \
-predictionFile '/tmp/starspace/pred'
I got duplicate results in '/tmp/starspace/pred':
Example 0:
LHS:
, movie studios launch legal offensive against online pirates , los angeles - hollywood studios said thursday they will file hundreds of lawsuits later this month against individuals who swap pirated copies of movies over the internet .
RHS:
__label__4
Predictions:
(--) [0.164135] __label__4
(++) [0.164135] __label__4
(--) [0.0509765] __label__3
(--) [0.0509765] __label__3
(--) [-0.0315386] __label__1
Example 1:
LHS:
, super ant colony hits australia , a giant 100km colony of ants which has been discovered in melbourne , australia , could threaten local insect species .
RHS:
__label__4
Predictions:
(--) [0.0536099] __label__4
(++) [0.0536099] __label__4
Is this a bug, or am I misunderstanding something?
Hey @ZhaoyangChen, did you find out why this happens, or how to fix it?
I think the reason is that the loadBaseDocs function is called twice:
https://github.com/facebookresearch/StarSpace/blob/master/src/apps/query_predict.cpp#L33 -> https://github.com/facebookresearch/StarSpace/blob/master/src/starspace.cpp#L130
https://github.com/facebookresearch/StarSpace/blob/master/src/apps/query_predict.cpp#L41
This results in duplicate entries in baseDocVectors_, which makes the predictOne function generate duplicate results:
https://github.com/facebookresearch/StarSpace/blob/master/src/starspace.cpp#L323