|
|
|
@ -105,8 +105,15 @@ class SumKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto &sel_row = get_selected_row(i);
|
|
|
|
|
first_dim += sel_row.rows().size();
|
|
|
|
|
}
|
|
|
|
|
auto in_dim =
|
|
|
|
|
framework::vectorize(get_selected_row(N - 1).value().dims());
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> in_dim;
|
|
|
|
|
for (int i = 0; i < N; i++) {
|
|
|
|
|
auto &sel_row = get_selected_row(i);
|
|
|
|
|
if (sel_row.rows().size() > 0) {
|
|
|
|
|
in_dim = framework::vectorize(sel_row.value().dims());
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
in_dim[0] = static_cast<int64_t>(first_dim);
|
|
|
|
|
|
|
|
|
|
out_value->Resize(framework::make_ddim(in_dim));
|
|
|
|
|