covartech/PRT

Find boundaries

anaritam opened this issue · 6 comments

I want to know the values that characterise each class in the classifier.
How can I extract the class boundaries from it?

In low dimensions (3 or fewer features) you can plot the trained classifier.

myClassifier = myClassifier.train(ds);
plot(myClassifier);

In higher dimensions (4 or more features) I am not sure your question has a simple answer...

I know I can plot it and take the information from the plot, but that would not be very accurate, I'm afraid...

If you add a "decider" to the classifier, the boundaries will be sharp, but clicking on points along them by hand still won't be very accurate.

However, I don't think this question is well-posed. The "boundary" is an uncountably infinite set of points that, in general, has no simpler description than the trained classifier itself. For example, after training an SVM it is not generally possible to summarize its output with a few simple rules (although I think this is an active area of research).

I do have a decider in my classifier. But there should be a way of getting the maximum and minimum values of each feature for which the classifier decides each class, right? If it's possible to plot the classes in different colors, there should be a way of getting the boundaries for each class, I believe...

To generate those plots, we make an NxN grid of points and evaluate the classifier at each point. You can do something similar yourself using meshgrid and your classifier's run method.

As I mentioned before, the "Boundaries" you are looking for are not well defined, e.g., what is the maximum and minimum value of each feature for each class in this plot?

http://groups.csail.mit.edu/ddmg/drupal/content/non-linear-svm-separation
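As a rough illustration of why per-feature minimum and maximum values are a poor summary, here is a minimal sketch. It uses a hypothetical stand-in decision rule instead of a trained PRT classifier: grid points inside a circle of radius sqrt(26) get label 1 and all others get label 0. The rule is evaluated on a grid. For each label, the sketch records the smallest and largest value of each feature among the grid points assigned to that label. Only core MATLAB/Octave functions are used. The rule and the variable names (`labels`, `ranges`) are assumptions for illustration.

```matlab
% Evaluation grid over the two features (x varies along rows, y along columns).
x = linspace(-10, 10, 201);
y = linspace(-10, 10, 201);
[xx, yy] = ndgrid(x, y);

% Hypothetical stand-in for a trained classifier with a decider:
% label 1 inside a circle of radius sqrt(26), label 0 outside.
labels = double(xx.^2 + yy.^2 < 26);

% For each label, the min/max of each feature over the grid points that
% receive that label: one row [minX maxX minY maxY] per class.
classes = unique(labels(:));
ranges = zeros(numel(classes), 4);
for k = 1:numel(classes)
    inClass = labels == classes(k);
    ranges(k, :) = [min(xx(inClass)) max(xx(inClass)) ...
                    min(yy(inClass)) max(yy(inClass))];
end
disp([classes ranges])
```

The class-0 row spans the whole grid, even though the class-1 disc sits in the middle of it. The extent of a class along each feature therefore says little about where the boundary actually lies.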

There is, in general, no easy way to get what you want, especially for non-linear classifiers, though you can probably extract something like it using the NxN grid trick. See the help for linspace and ndgrid, then do something like:

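minX = -10; maxX = 10;   % extent of the evaluation grid
minY = -10; maxY = 10;
nX = 100; nY = 100;      % grid resolution

% Train an SVM followed by a minimum-probability-of-error decider
myClass = prtClassSvm + prtDecisionBinaryMinPe;
ds = prtDataGenBimodal;
myClass = myClass.train(ds);

% Evaluate the trained classifier at every point of an nX-by-nY grid
x = linspace(minX,maxX,nX);
y = linspace(minY,maxY,nY);
[xx,yy] = ndgrid(x,y);
dsGrid = prtDataSetClass(cat(2,xx(:),yy(:)));
yOut = myClass.run(dsGrid);
imageVals = reshape(yOut.X,size(xx));

% ndgrid puts x along the rows, so transpose to display x horizontally
imagesc(x,y,imageVals.');
axis xy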
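One way to turn the grid output into an explicit, approximate list of boundary points is to find neighbouring grid points whose decided labels differ and take the midpoint between them. The minimal sketch below assumes a grid built with ndgrid, as in the code above. So that it runs on its own, it uses a hypothetical circular decision rule (`labels`) in place of the output of `myClass.run`. With PRT, you would use `imageVals` instead. The names `labels`, `ptsX`, `ptsY` and `boundaryPts` are illustrative only.

```matlab
% Evaluation grid (x varies along rows, y along columns, as with ndgrid).
x = linspace(-10, 10, 201);
y = linspace(-10, 10, 201);
[xx, yy] = ndgrid(x, y);

% Hypothetical stand-in for the decided classifier output; with PRT this
% would be imageVals = reshape(yOut.X, size(xx)).
labels = double(xx.^2 + yy.^2 < 26);

% Label changes between neighbours in x (adjacent rows) ...
[iX, jX] = find(labels(1:end-1, :) ~= labels(2:end, :));
% ... and between neighbours in y (adjacent columns).
[iY, jY] = find(labels(:, 1:end-1) ~= labels(:, 2:end));

% Midpoint of each pair of neighbours that straddles the boundary.
xs = x(:);
ys = y(:);
ptsX = [(xs(iX) + xs(iX + 1)) / 2, ys(jX)];
ptsY = [xs(iY), (ys(jY) + ys(jY + 1)) / 2];
boundaryPts = [ptsX; ptsY];   % one [x y] row per boundary crossing

fprintf('%d approximate boundary points\n', size(boundaryPts, 1));
```

Each reported point lies within half a grid step of the true boundary. Halving the grid spacing halves that error but, in two dimensions, requires roughly four times as many classifier evaluations.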

I understand what you're saying and it makes sense. Thanks for your help!