covartech/PRT

Pass userData to output data set in cross-validation or binaryToMaryOneVsAll

cratto opened this issue · 6 comments

I wrote a new prtClass in which metadata is written out to DataSet.userData in the runAction() method.

However, when I call crossValidate() on that prtClass or use it as a base classifier in prtClassBinaryToMaryOneVsAll, the userData is not saved to the output.

I would like userData to be saved to the DataSet returned by either prtClass.crossValidate() or prtClassBinaryToMaryOneVsAll.runAction(). Is this an easy internal PRT fix?

For both of these, it's not obvious what the desired result is. crossValidate() calls runAction() for each fold, but currently returns only the userData produced by the first fold. prtClassBinaryToMaryOneVsAll calls runAction() once per class and discards the userData entirely. We think the best solution is to produce a struct vector in which each element comes from one cross-validation fold or from one of the binary classification runs. This was implemented in the latest commit, 85b6087.
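A minimal plain-MATLAB sketch of the struct-vector idea, assuming every fold writes the same field names to userData. This is not PRT's internal code: `perFold`, `nFolds`, and the values stored in field `a` are hypothetical stand-ins for what each fold's runAction() would produce.

```matlab
% Hypothetical sketch (not PRT internals): gather the userData from each
% cross-validation fold into one struct vector, one element per fold.
nFolds = 2;
perFold = cell(nFolds, 1);   % stand-in for the userData each fold's runAction() returns

for iFold = 1:nFolds
    % Each fold writes the same field, 'a'. MATLAB only concatenates
    % structs whose field names match, so every fold must agree on them.
    perFold{iFold} = struct('a', 10 * iFold);
end

% Stack the per-fold structs into an nFolds-by-1 struct array.
userData = vertcat(perFold{:});

% userData(k) now holds the metadata written during fold k.
fprintf('userData is %d-by-%d; fold 2 wrote a = %d\n', ...
    size(userData, 1), size(userData, 2), userData(2).a);
```

The same layout would apply to prtClassBinaryToMaryOneVsAll, with one element per binary classifier instead of one per fold. If different runs wrote different field names, `vertcat` would raise an error, so a struct vector only works when the base classifier writes a consistent set of fields.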

Here's my example script. You'll need to add the attached class definition to your path. Change the extension to .m (apparently GitHub doesn't like m-files).

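```matlab
% Cross-validation: does userData survive crossValidate()?
ds = prtDataGenUnimodal;
clsUserdata = prtClassAddUserdata;            % attached class; writes to userData in runAction()
keys = round(rand(ds.nObservations,1));       % random 0/1 fold assignment (two folds)
results = clsUserdata.crossValidate(ds,keys);
results.userData

%% One-vs-all: does userData survive prtClassBinaryToMaryOneVsAll?
dsMulti = prtDataGenMary;
clsUserdata = prtClassAddUserdata;
multiCls = prtClassBinaryToMaryOneVsAll('baseClassifier',clsUserdata);
results = multiCls.rt(dsMulti);               % train and run in one call
results.userData
```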

prtClassAddUserdata.txt

Thanks Patrick, this works fine for my purposes. Will this be merged into master at some point?

Revisiting this because it's no longer working for me.

It appears the struct written to userData on the first fold is dropped at line 242 of prtAction.run:

```matlab
dsOut = postRunProcessing(self, dsIn, dsOut);
```

Before this line runs, dsOut.userData is a struct with one field, 'a'. Afterward, dsOut.userData is an empty struct.

If you let cross-validation continue until it reaches line 380 of prtAction.crossValidate:

```matlab
dsOut = dsOut.acquireNonDataAttributesFrom(dsIn);
```

Before this line runs, dsOut.userData is a 2x1 struct array with no fields; afterward, it is 1x1.

I did a git reset --soft 85b6087 on my local copy of devel and still had the same problem.

The script works after re-cloning the devel branch and doing a git reset --hard 85b6087. Are there changes made in later commits that may have undone this?

690bdee

You're right. We had offsetting bugs. One of them was fixed in 8d68581, breaking our userData handling. I have fixed the other.

See #66 (@peterTorrione).
@cratto, does it make sense to use observationInfo instead of userData for your application?