google-research/nasbench

Not compatible with tensorflow 2.0

CreeperLin opened this issue · 3 comments

The code works well with tensorflow 1.14 but throws the following error when using tf 2.0:

File ".../nasbench/lib/training_time.py", line 130, in <module>
class _TimingRunHook(tf.train.SessionRunHook):
AttributeError: module 'tensorflow_core._api.v2.train' has no attribute 'SessionRunHook'

I think I've implemented a version in tf2.0: https://github.com/ultmaster/nasbench/tree/tf2.

It's not a full version (channel compute is not included) but should be a good start. I've trained a few architectures and sometimes achieved even better results than reported.

That's very helpful, thanks a lot! I noticed that the results reported in the paper are evaluated across only 500 individual trials for each search algorithm.

FWIW, I've managed to run example.py under TensorFlow 2.3.1 with the following tf -> tf.compat.v1 modifications:

diff --git a/nasbench/api.py b/nasbench/api.py
index 236897f..97173f2 100644
--- a/nasbench/api.py
+++ b/nasbench/api.py
@@ -143,7 +143,7 @@ class NASBench(object):
     # {108} for the smaller dataset with only the 108 epochs.
     self.valid_epochs = set()
 
-    for serialized_row in tf.python_io.tf_record_iterator(dataset_file):
+    for serialized_row in tf.compat.v1.python_io.tf_record_iterator(dataset_file):
       # Parse the data from the data file.
       module_hash, epochs, raw_adjacency, raw_operations, raw_metrics = (
           json.loads(serialized_row.decode('utf-8')))
diff --git a/nasbench/lib/evaluate.py b/nasbench/lib/evaluate.py
index b8cbf2c..3c38e82 100644
--- a/nasbench/lib/evaluate.py
+++ b/nasbench/lib/evaluate.py
@@ -27,7 +27,7 @@ import numpy as np
 import tensorflow as tf
 
 VALID_EXCEPTIONS = (
-    tf.train.NanLossDuringTrainingError,  # NaN loss
+    tf.compat.v1.train.NanLossDuringTrainingError,  # NaN loss
     tf.errors.ResourceExhaustedError,     # OOM
     tf.errors.InvalidArgumentError,       # NaN gradient
     tf.errors.DeadlineExceededError,      # Timed out
diff --git a/nasbench/lib/training_time.py b/nasbench/lib/training_time.py
index 691d4ec..56dd1da 100644
--- a/nasbench/lib/training_time.py
+++ b/nasbench/lib/training_time.py
@@ -127,7 +127,7 @@ _TimingVars = collections.namedtuple(  # pylint: disable=g-bad-name
     ])
 
 
-class _TimingRunHook(tf.train.SessionRunHook):
+class _TimingRunHook(tf.compat.v1.train.SessionRunHook):
   """Hook to stop the training after a certain amount of time."""
 
   def __init__(self, max_train_secs=None):
@@ -171,7 +171,7 @@ class _TimingRunHook(tf.train.SessionRunHook):
       run_context.request_stop()
 
 
-class _TimingSaverListener(tf.train.CheckpointSaverListener):
+class _TimingSaverListener(tf.compat.v1.train.CheckpointSaverListener):
   """Saving listener to store the train time up to the last checkpoint save."""
 
   def begin(self):

$ git clone https://github.com/google-research/nasbench
$ cd nasbench
$ virtualenv venv
created virtual environment CPython3.8.5.final.0-64 in 179ms
  creator CPython3Posix(dest=/home/anton/projects/nasbench/venv, clear=False, global=False)
  seeder FromAppData(download=False, ipaddr=latest, progress=latest, urllib3=latest, wheel=latest, distro=latest, pkg_resources=latest, retrying=latest, setuptools=latest, chardet=latest, lockfile=latest, pytoml=latest, colorama=latest, pep517=latest, contextlib2=latest, six=latest, packaging=latest, certifi=latest, webencodings=latest, requests=latest, appdirs=latest, pip=latest, pyparsing=latest, msgpack=latest, idna=latest, html5lib=latest, distlib=latest, CacheControl=latest, via=copy, app_data_dir=/home/anton/.local/share/virtualenv/seed-app-data/v1.0.1.debian)
  activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator
$ source venv/bin/activate
(venv) $ pip install -e .
  Running setup.py develop for nasbench
Successfully installed nasbench tensorboard-2.4.0 tensorflow-2.3.1 tensorflow-estimator-2.3.0
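
As an alternative to editing each call site as in the diff above, the same fallback could be written once as a small lookup helper. The sketch below is only an illustration: `resolve_tf_symbol` is a hypothetical name, not part of nasbench or TensorFlow, and `fake_tf1` / `fake_tf2` are stand-in objects that mimic the relevant attribute layout so the example runs without TensorFlow installed.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_tf_symbol(tf_module, dotted_name):
    """Return tf_module.<dotted_name>, else tf_module.compat.v1.<dotted_name>.

    Hypothetical helper: tries the TF 1.x location first and falls back to
    the compat.v1 namespace, mirroring the tf -> tf.compat.v1 edits above.
    """
    def walk(root, path):
        # Follow each attribute in turn; raises AttributeError if one is missing.
        return reduce(getattr, path.split("."), root)

    try:
        return walk(tf_module, dotted_name)
    except AttributeError:
        return walk(tf_module, "compat.v1." + dotted_name)


class SessionRunHook:
    """Stand-in for the real hook base class (assumption for the demo)."""


# Stand-in for TF 1.x: the symbol lives directly under tf.train.
fake_tf1 = SimpleNamespace(train=SimpleNamespace(SessionRunHook=SessionRunHook))

# Stand-in for TF 2.x: tf.train lacks the symbol, tf.compat.v1.train has it.
fake_tf2 = SimpleNamespace(
    train=SimpleNamespace(),
    compat=SimpleNamespace(
        v1=SimpleNamespace(train=SimpleNamespace(SessionRunHook=SessionRunHook))
    ),
)

hook_cls_tf1 = resolve_tf_symbol(fake_tf1, "train.SessionRunHook")
hook_cls_tf2 = resolve_tf_symbol(fake_tf2, "train.SessionRunHook")
```

With a real installation, the call would presumably look like `resolve_tf_symbol(tf, "train.SessionRunHook")` after `import tensorflow as tf`. If a symbol is missing from both locations, the helper still raises `AttributeError`.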