ParquetDataset returns an error shape.
zhaozheng09 opened this issue · 1 comments
zhaozheng09 commented
error log:
Traceback (most recent call last):
File "train.py", line 832, in <module>
main()
File "train.py", line 544, in main
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
AttributeError: 'PrefetchDataset' object has no attribute 'output_types'
modified: /root/DeepRec/modelzoo/dlrm
diff --git a/modelzoo/dlrm/train.py b/modelzoo/dlrm/train.py
index 1cd0e7915e..5fbc5ee4f2 100644
--- a/modelzoo/dlrm/train.py
+++ b/modelzoo/dlrm/train.py
@@ -24,6 +24,7 @@ from tensorflow.python.client import timeline
import json
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.data.experimental.ops import parquet_dataset_ops
# Set to INFO for tracking training, default is WARN. ERROR for least messages
tf.logging.set_verbosity(tf.logging.INFO)
@@ -300,6 +301,22 @@ def build_model_input(filename, batch_size, num_epochs):
features = all_columns
return features, labels
+ def parse_parquet(value):
+ cont_defaults = [[0.0] for i in range(1, 14)]
+ cate_defaults = [[' '] for i in range(1, 27)]
+ label_defaults = [[0]]
+ column_headers = TRAIN_DATA_COLUMNS
+ record_defaults = label_defaults + cont_defaults + cate_defaults
+ columns = value
+ vs = []
+ for k,v in columns.items():
+ vs.append(v)
+ all_columns = collections.OrderedDict(zip(column_headers, vs))
+ labels = all_columns.pop(LABEL_COLUMN[0])
+ features = all_columns
+ return features, labels
+
+
'''Work Queue Feature'''
if args.workqueue and not args.tf:
from tensorflow.python.ops.work_queue import WorkQueue
@@ -311,12 +328,8 @@ def build_model_input(filename, batch_size, num_epochs):
JackMoriarty commented
ParquetDataset usage examples have been added to the DeepRec modelzoo; please refer to PR #684.