SamsungLabs/eagle

How to obtain final NAS results


The result is a "search trace", which we later feed into our generic NAS toolkit; the toolkit selects models according to the provided search trace, querying a standard NAS-Bench-201 dataset to obtain accuracy values.

As the README mentions above, the trace is fed into a generic NAS toolkit, but I couldn't find that toolkit, so I can't reproduce the accuracy results in Figures 2, 3, 5, and 6 and Table 3 of the paper. Could you please tell me how to reproduce the NAS accuracy results in those figures and that table?

vaenyr commented

The trace gives you a sequence of points that you should use to query the relevant NAS benchmark. There are many toolkits that let you do so (e.g., NASLib), or you can simply write this piece of code yourself; it is rather simple and shouldn't take much time.
You could even use the code provided with our repo, but it does not account for multiple training seeds and, if I recall correctly, we do not report test accuracy anywhere.
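Writing this lookup yourself can be as small as the sketch below. It assumes a hypothetical in-memory table (`BENCH`) that maps an architecture string to one metrics dict per training seed, filled with toy numbers; real NAS-Bench-201 data and APIs are organized differently. Averaging over seeds is one possible way to account for multiple training seeds, which the repo code does not do.

```python
from statistics import mean

# Hypothetical benchmark table with toy numbers: architecture -> one metrics dict per training seed.
# Real NAS-Bench-201 releases and wrappers (e.g. NASLib) store and expose this data differently.
BENCH = {
    "arch_a": [{"val_acc": 70.0, "test_acc": 69.0}, {"val_acc": 71.0, "test_acc": 70.0}],
    "arch_b": [{"val_acc": 72.0, "test_acc": 71.5}, {"val_acc": 73.0, "test_acc": 72.5}],
}


def query(arch, seed=None):
    """Return the metrics of one seed, or their mean over all seeds when seed is None."""
    runs = BENCH[arch]
    if seed is not None:
        return runs[seed]
    return {key: mean(run[key] for run in runs) for key in runs[0]}


def query_trace(trace, seed=None):
    """Look up every architecture in a search trace, preserving the trace order."""
    return [query(arch, seed) for arch in trace]
```

Swapping `BENCH` for a real benchmark only requires changing `query`; the rest of the evaluation works on the list of metric dicts.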

Also note that when plotting results we were taking the max over test accuracies up to point x (so all plots never decrease), i.e., y(x) = max(models[:x], key='test_acc').test_acc. The results might be different if you report the test accuracy of the model that achieves the highest validation accuracy up to point x, i.e., y(x) = max(models[:x], key='val_acc').test_acc. In hindsight, the latter is probably more realistic and we should have used it rather than the former, but I didn't think about it when we were writing the paper. Bear in mind, though, that the predictor is still trained exclusively on validation accuracy.

[Image omitted]

Thanks for your reply! If I understand correctly, taking the left plot of this figure as an example: I trained 100 models, so I can simply sort the trace by predicted_accuracy, take maybe the top 10 rows, pick the model with the highest GT_accuracy, and report that GT_accuracy as the final NAS accuracy (around 73.7% in the figure), instead of training the top 10 models myself?
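The procedure described in this question can be sketched as follows. The column names `predicted_accuracy` and `GT_accuracy` come from the question; representing trace rows as dicts and the helper name `pick_from_top_k` are assumptions for illustration. Because a tabular benchmark such as NAS-Bench-201 already stores results for every architecture in its search space, looking up `GT_accuracy` stands in for training those models.

```python
def pick_from_top_k(rows, k=10):
    """Among the k rows with the highest predicted_accuracy, return the one with the highest GT_accuracy."""
    # Rank candidates by the predictor's score (highest first) and keep the top k.
    top_k = sorted(rows, key=lambda r: r["predicted_accuracy"], reverse=True)[:k]
    # Among those, report the one whose benchmark (ground-truth) accuracy is highest.
    return max(top_k, key=lambda r: r["GT_accuracy"])
```

A row with a high `GT_accuracy` but a low predicted score is never considered, which is what makes the result depend on the predictor's ranking.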

Regarding "You could even use the code provided with our repo": I couldn't find the code related to the final NAS results in the repo.

vaenyr commented

This is the function we use to obtain a trace from a log file:

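def read_results(logfile):
    """Build a search trace from a search log.

    The log is expected to contain a ranking section (lines of the form
    "<ignored> <predicted_score> <point>"), a '---' separator, the points that
    were trained while fitting the predictor (one per line), and optionally
    another '---', after which reading stops. Returns the trained points in log
    order, followed by the remaining ranked points sorted by predicted score
    (highest first).
    """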
    ret = []
    all_points = []
    memo = set()
    ranking = True
    with open(logfile, 'r') as f:
        for line in f:
            line = line.strip()
            if line == '---':
                if not ranking:
                    break
                else:
                    ranking = False
                    continue
            if not line:
                continue

            if ranking:
                # ranking line: "<ignored> <predicted_score> <point>"
                raw_values = line.split(' ', maxsplit=2)
                assert len(raw_values) == 3
                pt = eval(raw_values[2]) # you might need to change this, depending on the search space
                predicted = float(raw_values[1])
                all_points.append((pt, predicted))
            else:
                # trained point: one per line, in the order it was trained
                pt = eval(line) # you might need to change this, depending on the search space
                ret.append(pt)
                memo.add(pt)
                memo.add(pt)

    sorted_all_points = sorted(all_points, key=lambda p: p[1], reverse=True)
    for p in sorted_all_points:
        if p[0] in memo:
            continue

        ret.append(p[0])

    return ret

The code produces a sequence of models that would have been trained during predictor training (the parts of the plots before the vertical dashed line), followed by the models that would be trained by following the predictions of the trained predictor (the parts of the plots after the line). Once you have the trace, you can use it as in the pseudocode below:

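# load_nasbench and nb201.query are placeholders: substitute whichever NAS-Bench-201 API you use
nb201 = load_nasbench('/some/path/to/nb201_data')
search_trace = read_results('/path/to/the/log/file')
models = [nb201.query(pt) for pt in search_trace]  # assume .query returns a dict containing 'val_acc' and 'test_acc'
xs = list(range(min(300, len(models))))  # first 300 points of the trace (fewer if the trace is shorter)
ys = [max(models[:x+1], key=lambda p: p['test_acc'])['test_acc'] for x in xs]  # best test accuracy seen so far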
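To see where the vertical dashed line falls, the ordering step at the end of read_results can be separated from the file parsing, as in this sketch. `build_trace` is a hypothetical helper, not part of the repo: it takes the trained points and the (point, predicted_score) pairs directly and also returns the number of predictor-training points, which should correspond to the x position of the dashed line. Like read_results, it skips ranked points that were already trained but does not deduplicate the ranked list itself.

```python
def build_trace(trained, ranked):
    """Order a search trace the same way read_results does.

    trained: points in the order they were trained while fitting the predictor.
    ranked:  (point, predicted_score) pairs for the candidate points.
    Returns (trace, split): trace[:split] are the predictor-training points
    (before the dashed line), trace[split:] follow the predictor's ranking.
    """
    seen = set(trained)
    trace = list(trained)
    # Highest predicted score first; points already trained are not repeated.
    for pt, _score in sorted(ranked, key=lambda p: p[1], reverse=True):
        if pt not in seen:
            trace.append(pt)
    return trace, len(trained)
```

For example, `build_trace(["arch_1", "arch_2"], [("arch_3", 0.5), ("arch_1", 0.9), ("arch_4", 0.7)])` yields the trace `["arch_1", "arch_2", "arch_4", "arch_3"]` with the split at 2.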

Hope that helps.

It really helps, thank you so much!