Crash during training/inference with fp16 & factors-combine=concat
arturnn opened this issue · 1 comment
Bug description
Marian crashes during training/inference if --fp16 and --factors-combine=concat are provided at the same time.
The problem comes from the embedWithConcat method. Removing the explicit Type::float32 type argument from the graph->constant call at layers/embedding.cpp:61 seems to fix the issue. Could that be a sufficient fix, or will it break something in the long term?
/*private*/ Expr Embedding::embedWithConcat(const Words& data) const {
  auto graph = E_->graph();
  std::vector<IndexType> lemmaIndices;
  std::vector<float> factorIndices;
  factoredVocab_->lemmaAndFactorsIndexes(data, lemmaIndices, factorIndices);
  auto lemmaEmbs = rows(E_, lemmaIndices);
  int dimFactors = FactorEmbMatrix_->shape()[0];
  auto factEmbs
      = dot(graph->constant(
                {(int)data.size(), dimFactors}, inits::fromVector(factorIndices), Type::float32),
            FactorEmbMatrix_);
  return concatenate({lemmaEmbs, factEmbs}, -1);
}
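To illustrate why the hard-coded type matters, here is a minimal, self-contained sketch of the failure mode. It does not use Marian at all: Type, Node, Graph, commonType, dot, dotThrows and demo are toy stand-ins written for this example. The sketch assumes that the graph has a default element type that follows --fp16 and that a constant created without an explicit type inherits it; that assumption is what the proposed fix relies on and should be checked against the real code.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Toy stand-ins for illustration only; these are NOT Marian's real classes.
enum class Type { float16, float32 };

struct Node {
  Type type;
};

// Imitates the check from the error log: every child must share the first child's type.
Type commonType(const std::vector<Node>& children) {
  for(size_t i = 1; i < children.size(); ++i)
    if(children[i].type != children[0].type)
      throw std::runtime_error("Child " + std::to_string(i) + " has different type");
  return children.at(0).type;
}

struct Graph {
  Type defaultType;  // assumed to be float16 when --fp16 is given

  // Constant with a hard-coded element type (the situation in the issue).
  Node constant(Type explicitType) const { return Node{explicitType}; }

  // Constant that inherits the graph's default element type (the proposed fix).
  Node constant() const { return Node{defaultType}; }
};

// A binary op that requires both operands to have the same type.
Node dot(const Node& a, const Node& b) {
  return Node{commonType({a, b})};
}

// Returns true if dot(a, b) is rejected because of a type mismatch.
bool dotThrows(const Node& a, const Node& b) {
  try {
    dot(a, b);
    return false;
  } catch(const std::runtime_error&) {
    return true;
  }
}

void demo() {
  Graph graph{Type::float16};           // as if --fp16 were set
  Node factorEmbMatrix{Type::float16};  // parameter stored in half precision

  // Hard-coded float32 constant against a float16 matrix: rejected.
  assert(dotThrows(graph.constant(Type::float32), factorEmbMatrix));

  // Constant that follows the graph default: accepted.
  assert(!dotThrows(graph.constant(), factorEmbMatrix));
}
```

In this toy model the mismatch disappears as soon as the constant stops pinning its own type, which matches the behaviour described above. Whether the real graph->constant overload without a type argument behaves the same way is the open question in this report.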
How to reproduce
Train or decode a factored model with both the --factors-combine=concat and --fp16 options provided at the same time.
Context
- Marian version: v1.11.3 b8bf086 2022-02-11 06:04:38 -0800
[2022-02-11 21:58:10] Error: Child 1 has different type (first: float32 != child: float16)
[2022-02-11 21:58:10] Error: Aborted from static marian::Type marian::NaryNodeOp::commonType(const std::vector<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase> > > >&) in /home/anowakowski/MTExperiments/marian/tools/marian-dev/src/graph/node.h:207
[CALL STACK]
[0x55685e45c1f0] marian::NaryNodeOp:: commonType (std::vector<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>,std::allocator<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>>> const&) + 0x2b0
[0x55685e46377f] IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>> marian:: Expression <marian::DotNodeOp,IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&,IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&,bool&,bool&,float&>(IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&, IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&, bool&, bool&, float&) + 0x12f
[0x55685e39b635] marian:: dot (IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>, IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>, bool, bool, float) + 0x3a5
[0x55685e7b3206] marian::Embedding:: embedWithConcat (std::vector<marian::Word,std::allocator<marian::Word>> const&) const + 0x266
Hi Artur, thanks for reporting this. I think your solution should work. Would you mind opening a PR?