sum_op selectedRows dim bug fix

revert-12469-sum_op_dim_fix
tangwei12 7 years ago
parent baff71d504
commit c4c8f60bec

@ -105,8 +105,15 @@ class SumKernel : public framework::OpKernel<T> {
auto &sel_row = get_selected_row(i);
first_dim += sel_row.rows().size();
}
auto in_dim =
framework::vectorize(get_selected_row(N - 1).value().dims());
std::vector<int64_t> in_dim;
for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() > 0) {
in_dim = framework::vectorize(sel_row.value().dims());
break;
}
}
in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim));

Loading…
Cancel
Save