|
|
|
@ -116,8 +116,22 @@ class SumKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *out = context.Output<SelectedRows>("Out");
|
|
|
|
|
out->mutable_rows()->clear();
|
|
|
|
|
|
|
|
|
|
math::scatter::MergeAdd<DeviceContext, T> merge_add;
|
|
|
|
|
merge_add(context.template device_context<DeviceContext>(), inputs, out);
|
|
|
|
|
bool has_data = false;
|
|
|
|
|
for (auto &in : inputs) {
|
|
|
|
|
if (in->rows().size() > 0) {
|
|
|
|
|
has_data = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (has_data) {
|
|
|
|
|
math::scatter::MergeAdd<DeviceContext, T> merge_add;
|
|
|
|
|
merge_add(context.template device_context<DeviceContext>(), inputs,
|
|
|
|
|
out);
|
|
|
|
|
} else {
|
|
|
|
|
// no data, just set a empty out tensor.
|
|
|
|
|
out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
} else if (out_var->IsType<framework::LoDTensorArray>()) {
|
|
|
|
|
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
|
|
|
|
|
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
|
|
|
|
|