locuslab/SATNet

Unable to replicate the success on N-Queens instances

djin31 opened this issue · 3 comments

I tried adapting the code in sudoku.py to solve the N-Queens problem. I formulated the problem as follows: each input instance consists of N-1 out of N queens placed on the board, with the input mask placed over all positions where queens are placed. The output is the corresponding placement of all N queens.
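Roughly, the encoding looks like the sketch below (the variable names and helper are illustrative, not my actual code; I represent the board as a flattened N*N binary vector):

```python
import torch

def encode_instance(solution_cols, drop_idx, n=10):
    """Build one training instance from an N-Queens solution.

    solution_cols[r] is the column of the queen in row r.
    drop_idx is the row whose queen is removed to form the input.
    """
    board = torch.zeros(n * n)                    # flattened binary board (input)
    mask = torch.zeros(n * n, dtype=torch.int32)  # 1 where a cell is given as input
    label = torch.zeros(n * n)                    # full solution (target)
    for r, c in enumerate(solution_cols):
        label[r * n + c] = 1.0
        if r != drop_idx:
            board[r * n + c] = 1.0                # one of the N-1 given queens
            mask[r * n + c] = 1                   # mask only the occupied cell
    return board, mask, label
```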

I am using the Adam optimizer and binary cross-entropy loss, and I tried several values for the rank and the number of auxiliary variables, but I could not get any good results on the test set (accuracy < 40%) for the 10-queens problem, even though I can reach 100% accuracy on the training set for ranks as low as 10.
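My training setup is essentially the sketch below. I am assuming the same satnet.SATNet interface that sudoku.py uses (a flat probability vector plus an integer is_input mask); the m, aux, batch size, and learning rate here are placeholders for the values I tried:

```python
import torch
import satnet

n_vars = 10 * 10                                  # flattened 10x10 board
model = satnet.SATNet(n_vars, m=200, aux=100)     # m (rank) and aux are placeholder values
opt = torch.optim.Adam(model.parameters(), lr=2e-3)
loss_fn = torch.nn.BCELoss()

# train_data: list of (board, mask, label) tuples produced by encode_instance above
train_loader = torch.utils.data.DataLoader(train_data, batch_size=40, shuffle=True)

for epoch in range(20):
    for boards, masks, labels in train_loader:
        preds = model(boards, masks)              # SATNet outputs probabilities in [0, 1]
        loss = loss_fn(preds, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
```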

I started out with m queens placed, but the algorithm did not perform well, possibly because such instances have multiple possible solutions, so I moved to the N-1 version. I also tried placing the mask over all the rows where queens were already placed, which makes the problem much easier, but even then I could not get better than 70% accuracy on the test set.

As a sanity check of my training code, I trained on the Sudoku dataset with m=600 and aux=300 (the default parameters in sudoku.py) and was able to get 90%+ accuracy within 20 epochs.

This leaves only two things I may be doing wrong: the choice of m and aux, or the training data.

Can you offer any guidance on the choice of m and aux?

Also, N-Queens (when N-1 queens are already placed) is a much easier problem than Sudoku, so shouldn't it be solvable with far fewer data points? 10-queens has 724 solutions, and from each solution I can derive 10 data points, which I believe gives sufficient training data. Any leads here would also be appreciated.

As an addendum, I tried two runs:

  • In the first, I shuffle these 7240 data points and then split them into train and test sets. In this case I get good accuracy, but that might simply be because the network memorizes the solutions themselves.
  • In the second, I split the 724 solutions into train and test sets and only then construct the data points. In this case I get poor accuracy.

By constructing 10 data points from 1 solution, I mean that I remove 1 queen at a time, and each removal gives a new input instance; a sketch of the solution-level split from the second run is below.
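This sketch reuses the encode_instance helper from above; `solutions` stands for the list of 724 solutions, each given as a permutation of column indices:

```python
import random

def make_datapoints(solutions, n=10):
    """Each solution yields n instances: drop one queen at a time."""
    data = []
    for sol in solutions:
        for drop_idx in range(n):
            data.append(encode_instance(sol, drop_idx, n))
    return data

random.shuffle(solutions)
split = int(0.8 * len(solutions))
train_data = make_datapoints(solutions[:split])   # no solution is shared between the
test_data = make_datapoints(solutions[split:])    # train and test sets
```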

Update: I was able to train the model after some search over the hyperparameters m and aux and after changing the data format to a one-hot style encoding, so that it is similar to the Sudoku dataset.
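For anyone who runs into the same issue: by one-hot I mean treating each row as a cell with N possible values (the column of its queen), so a known row has all N of its bits masked as input, much like a given cell's 9-way one-hot in the Sudoku dataset. A rough sketch (illustrative names, not my exact code):

```python
import torch

def encode_onehot(solution_cols, drop_idx, n=10):
    """Row r becomes an n-way one-hot over column positions,
    analogous to a Sudoku cell's one-hot over digits 1-9."""
    x = torch.zeros(n, n)
    mask = torch.zeros(n, n, dtype=torch.int32)
    label = torch.zeros(n, n)
    for r, c in enumerate(solution_cols):
        label[r, c] = 1.0
        if r != drop_idx:
            x[r, c] = 1.0
            mask[r, :] = 1        # the whole row is a known input, like a given Sudoku cell
    return x.flatten(), mask.flatten(), label.flatten()
```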

Thanks for letting us know!