mckinziebrandon/DeepChatModels

AttentionModel problem

pazocode opened this issue · 5 comments

```
$ ./main.py --config=configs/example_cornell.yml
Setting up Cornell dataset.
Creating DynamicBot . . .
Traceback (most recent call last):
  File "./main.py", line 148, in <module>
    tf.app.run()
  File "/home/paz/DEEPCHATMODELTEST/deepchatenv/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "./main.py", line 139, in main
    bot = bot_class(dataset, config)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/dynamic_models.py", line 46, in __init__
    self.build_computation_graph(dataset)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/dynamic_models.py", line 133, in build_computation_graph
    loop_embedder=self.embedder)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/components/decoders.py", line 317, in __call__
    cell = self.get_cell('attn_cell', initial_state)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/components/decoders.py", line 332, in get_cell
    initial_cell_state=initial_state)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/components/base/_rnn.py", line 186, in __init__
    super(SimpleAttentionWrapper, self).__init__(name=name)
TypeError: object.__init__() takes no parameters
```
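For context, a manual workaround for a TypeError like this might look as follows. This is an illustrative sketch, not the repository's actual code: `Base` stands in for `SimpleAttentionWrapper`'s real parent class, whose `__init__` signature varies across TF versions (when the parent chain falls through to `object.__init__`, no keyword arguments are accepted).

```python
# Illustrative sketch of a version-tolerant __init__. `Base` is a stand-in
# for the real parent class of SimpleAttentionWrapper.

class Base:
    pass

class SimpleAttentionWrapper(Base):
    def __init__(self, name=None):
        try:
            # Works when the parent class accepts a `name` keyword.
            super(SimpleAttentionWrapper, self).__init__(name=name)
        except TypeError:
            # Parent resolves to object.__init__, which takes no parameters.
            super(SimpleAttentionWrapper, self).__init__()
        self._name = name
```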
After fixing that error manually in the code, another error pops up.

```
Traceback (most recent call last):
  File "./main.py", line 148, in <module>
    tf.app.run()
  File "/home/paz/DEEPCHATMODELTEST/deepchatenv/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "./main.py", line 139, in main
    bot = bot_class(dataset, config)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/dynamic_models.py", line 46, in __init__
    self.build_computation_graph(dataset)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/dynamic_models.py", line 133, in build_computation_graph
    loop_embedder=self.embedder)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/components/decoders.py", line 317, in __call__
    cell = self.get_cell('attn_cell', initial_state)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/components/decoders.py", line 332, in get_cell
    initial_cell_state=initial_state)
  File "/home/paz/DEEPCHATMODELTEST/DeepChatModels/chatbot/components/base/_rnn.py", line 215, in __init__
    self._attention_mechanism.batch_size,
AttributeError: 'LuongAttention' object has no attribute 'batch_size'
```
I looked into the TensorFlow source code and found that the `batch_size` attribute was introduced in version 1.2. Long story short, could you please share which TensorFlow version and dependency configuration you are using for training models with attention?
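A hedged sketch of how a 1.1-compatible fallback for the missing attribute could look. `get_attention_batch_size` is a hypothetical helper, not part of DeepChatModels or TensorFlow; it relies on the attention mechanism's memory ("values") tensor being shaped `[batch_size, max_time, depth]`, which is how TF 1.2 itself derives `batch_size`.

```python
# Hypothetical helper for TF versions whose attention mechanisms lack
# a `batch_size` property (e.g. 1.1).

def get_attention_batch_size(attention_mechanism):
    """Return the batch size, tolerating versions without `.batch_size`."""
    if hasattr(attention_mechanism, 'batch_size'):  # TF >= 1.2
        return attention_mechanism.batch_size
    # TF 1.1 fallback: read the leading dimension of the memory tensor.
    # (With a fully dynamic batch you would use tf.shape(values)[0] instead.)
    return attention_mechanism.values.shape[0]
```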

Odd. My hope was that everything would work for versions >= 1.0, so I'll make some changes and push them soon (hopefully within the next 24 hours). Thanks for bringing this to my attention!

@pazocode Are these the only errors you've encountered? I'm running a test container now with tensorflow version r1.2.0-rc2 and I'm getting a few import errors that you did not mention.

I'm pretty surprised by how many imports/APIs have changed just from 1.1 to 1.2 (see changes here). It might take a bit longer to get things working on 1.2 (you can always use 1.1) since I'm busy with work, but it should be up soon. I'd happily take a PR for this if you're interested.

The implementation of the wrappers in version 1.1 does not provide `batch_size` for the attention mechanisms. Also, the `AttentionWrapperState` from 1.1 doesn't have attributes that you used in your code, like `alignment_history`, etc. I am using this version of tf1.1 for testing. Could you specify a configuration for tf1.1, if everything works on your machine with the 1.1 version?

I overcame the problems by modifying your code and switching to tf1.2, but I still need to test the correctness of these changes.
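One generic pattern for straddling the two versions is to gate 1.2-only attributes (such as `alignment_history` on the wrapper state, or `batch_size` on the mechanisms) behind a version check. The sketch below is illustrative; `parse_version` and `supports_attention_batch_size` are hypothetical helper names, not code from the repo.

```python
def parse_version(version_string):
    """Parse the leading numeric dotted part, e.g. '1.2.0-rc2' -> (1, 2, 0)."""
    nums = []
    for part in version_string.split('.'):
        digits = ''
        for ch in part:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        nums.append(int(digits))
    return tuple(nums)

def supports_attention_batch_size(tf_version):
    """True for TF versions whose attention mechanisms expose `batch_size`
    (introduced in the 1.2 line, per the discussion above)."""
    return parse_version(tf_version) >= (1, 2)
```

In practice you would call this once with `tf.__version__` and branch on the result, rather than sprinkling `hasattr` checks throughout the decoder code.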

@pazocode Finally got around to looking at this. I pushed what works for me in the branch tf1.2-fixes-issue-6. Does this work for you? I also updated the requirements.txt file in that branch to be the exact TensorFlow version I used to test it.

Edit: Forgot to mention the TF version I was using before. I regularly rebuild from source, so I was hesitant to mention it. I've primarily been working in one of the 1.1 release branches, but I don't have the exact commit hash handy. I've now updated to a "stable" version of 1.2 to avoid any further oddities like this. (Minor release updates are not supposed to break libraries like this.)