Fix sum infershape issue

if dim is -1, compile time check fails.

test=develop

Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
revert-16839-cmakelist_change
zhaoyuchen 6 years ago
parent 82cff5ec42
commit aeddb14148

@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel {
if (framework::product(in_dim) == 0) {
in_dim = x_dim;
} else {
PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(in_dim, x_dim,
"Input tensors must have same shape");
} else {
PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
"Input tensors must have same shape size");
// if in_dim or x_dim has -1, not check equal
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] == -1 || in_dim[i] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
"Input tensors must have same shape if not -1");
}
}
}
}
ctx->SetOutputDim("Out", in_dim);

Loading…
Cancel
Save