|
|
|
@ -41,7 +41,9 @@ class SumOp : public framework::OperatorWithKernel {
|
|
|
|
|
return; // skip runtime infershape when is tensor array;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x_var_types = ctx->GetInputsVarType("X");
|
|
|
|
|
auto x_dims = ctx->GetInputsDim("X");
|
|
|
|
|
|
|
|
|
|
size_t N = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_GT(N, 0, "Input tensors count should > 0.");
|
|
|
|
|
if (N == 1) {
|
|
|
|
@ -49,7 +51,11 @@ class SumOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::DDim in_dim({0});
|
|
|
|
|
for (auto& x_dim : x_dims) {
|
|
|
|
|
for (size_t i = 0; i < x_dims.size(); ++i) {
|
|
|
|
|
if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto& x_dim = x_dims[i];
|
|
|
|
|
if (framework::product(x_dim) == 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|