|
|
@ -80,7 +80,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
|
|
|
|
auto &send_slr = send_var->Get<framework::SelectedRows>();
|
|
|
|
auto &send_slr = send_var->Get<framework::SelectedRows>();
|
|
|
|
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
|
|
|
|
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
|
|
|
|
|
|
|
|
|
|
|
|
auto send_rows = send_slr.rows();
|
|
|
|
auto &send_rows = send_slr.rows();
|
|
|
|
std::vector<std::vector<int>> outs_rows_idx;
|
|
|
|
std::vector<std::vector<int>> outs_rows_idx;
|
|
|
|
std::vector<std::vector<int>> outs_dense_idx;
|
|
|
|
std::vector<std::vector<int>> outs_dense_idx;
|
|
|
|
|
|
|
|
|
|
|
@ -88,7 +88,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
|
|
|
|
outs_dense_idx.resize(out_num);
|
|
|
|
outs_dense_idx.resize(out_num);
|
|
|
|
|
|
|
|
|
|
|
|
auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
|
|
|
|
auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
|
|
|
|
auto src = send_slr.value().data<T>();
|
|
|
|
auto *src = send_slr.value().data<T>();
|
|
|
|
|
|
|
|
|
|
|
|
// create output var in local scope
|
|
|
|
// create output var in local scope
|
|
|
|
std::vector<framework::SelectedRows *> outs;
|
|
|
|
std::vector<framework::SelectedRows *> outs;
|
|
|
@ -110,8 +110,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
|
|
|
|
outs[i]->set_height(rpc_ctx.height_sections[i]);
|
|
|
|
outs[i]->set_height(rpc_ctx.height_sections[i]);
|
|
|
|
auto dims = send_slr.GetCompleteDims();
|
|
|
|
auto dims = send_slr.GetCompleteDims();
|
|
|
|
dims[0] = rows_idx.size();
|
|
|
|
dims[0] = rows_idx.size();
|
|
|
|
outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
|
|
|
|
|
|
|
|
outs[i]->mutable_rows()->clear();
|
|
|
|
outs[i]->mutable_rows()->clear();
|
|
|
|
|
|
|
|
outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
|
|
|
|
if (rows_idx.size() > 0) {
|
|
|
|
if (rows_idx.size() > 0) {
|
|
|
|
for (auto idx : rows_idx) {
|
|
|
|
for (auto idx : rows_idx) {
|
|
|
|
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
|
|
|
|
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
|
|
|
|