|
|
@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel {
|
|
|
|
if (framework::product(in_dim) == 0) {
|
|
|
|
if (framework::product(in_dim) == 0) {
|
|
|
|
in_dim = x_dim;
|
|
|
|
in_dim = x_dim;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape");
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dim, x_dim,
|
|
|
|
|
|
|
|
"Input tensors must have same shape");
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
|
|
|
|
|
|
|
|
"Input tensors must have same shape size");
|
|
|
|
|
|
|
|
// if in_dim or x_dim has -1, not check equal
|
|
|
|
|
|
|
|
for (int i = 0; i < x_dim.size(); ++i) {
|
|
|
|
|
|
|
|
if (x_dim[i] == -1 || in_dim[i] == -1) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
|
|
|
|
|
|
|
|
"Input tensors must have same shape if not -1");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ctx->SetOutputDim("Out", in_dim);
|
|
|
|
ctx->SetOutputDim("Out", in_dim);
|
|
|
|