|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#include <stdint.h>
|
|
|
|
|
#include <sys/stat.h>
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <ostream>
|
|
|
|
|
#include <thread>
|
|
|
|
|
|
|
|
|
@ -81,9 +80,9 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
|
auto trainer_count = Attr<int>("Trainers");
|
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
|
rpc_service_->Start();
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
|
while (true) {
|
|
|
|
|
rpc_service_->Start();
|
|
|
|
|
// Get from multiple trainers, we don't care about order in which
|
|
|
|
|
// the gradient arrives, just add suffix 0~n then average the gradient.
|
|
|
|
|
for (size_t i = 0; i < param_count * trainer_count; ++i) {
|
|
|
|
@ -95,8 +94,8 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
}
|
|
|
|
|
VLOG(10) << "recved grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
VLOG(3) << "recved grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
auto *merged_grad = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (merged_grad == nullptr) {
|
|
|
|
|
// create output of merged var.
|
|
|
|
@ -113,6 +112,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
// FIXME(typhoonzero): do not copy
|
|
|
|
|
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->Start();
|
|
|
|
|
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
|
framework::ProgramDesc program_desc;
|
|
|
|
@ -127,14 +127,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->Done();
|
|
|
|
|
|
|
|
|
|
// for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
|
// auto *out_var = recv_scope.FindVar(param_list[i]);
|
|
|
|
|
// detail::TensorWithName out;
|
|
|
|
|
// out.first = param_list[i];
|
|
|
|
|
// out.second = out_var->Get<framework::LoDTensor>();
|
|
|
|
|
// rpc_service_->Push(out);
|
|
|
|
|
// }
|
|
|
|
|
grads_counter_.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|