@@ -14,6 +14,7 @@
 
 #include <stdint.h>
 #include <sys/stat.h>
+#include <iostream>
 #include <ostream>
 #include <thread>
 
@@ -63,14 +64,32 @@ class RecvOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
-    // blocking get one var from client.
-    const framework::LoDTensor &t = rpc_service_->Get();
     framework::Scope &recv_scope = scope.NewScope();
+    // blocking get one var from client.
+    const detail::TensorWithName &v = rpc_service_->Get();
+    auto grad_var_name = v.first;
+
+    // framework::Scope &recv_scope = scope.NewScope();
+    auto param_list = Attr<std::vector<std::string>>("ParamList");
+    auto grad_list = Attr<std::vector<std::string>>("GradList");
+    auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+    std::string param_var_name;
+    if (it != grad_list.end()) {
+      param_var_name = param_list[it - grad_list.begin()];
+    }
     // set graph input var
-    auto *var = recv_scope.Var(Input("RX"));
+    auto input_grad = Input("RX");
+
+    // FIXME(typhoonzero): Find the parameter name from input grad name
+    // rename X -> Param
+    // rename RX -> Grad
+    auto *var = recv_scope.FindVar(input_grad);
     auto *tensor = var->GetMutable<framework::LoDTensor>();
+    recv_scope.Rename(param_var_name, "Param");
+    recv_scope.Rename("RX", "Grad");
+
     // FIXME(typhoonzero): do not copy
-    framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor);
+    framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
 
     std::string program_str = Attr<std::string>("OptimizeProgram");
     framework::ProgramDesc program_desc;
@@ -81,9 +100,14 @@ class RecvOp : public framework::OperatorBase {
     executor.Run(program, &recv_scope, 0, /*global_block*/
                  false /*create_local_scope*/);
 
-    auto *out_var = recv_scope.FindVar("Out");
-    // push back
-    rpc_service_->Push(out_var->Get<framework::LoDTensor>());
+    auto *out_var = recv_scope.FindVar("Param");
+    detail::TensorWithName out;
+    out.first = param_var_name;
+    out.second = out_var->Get<framework::LoDTensor>();
+    rpc_service_->Push(out);
+    // rename back the params
+    recv_scope.Rename("Param", param_var_name);
+    recv_scope.Rename("Grad", "RX");
   }
 
  protected:
@@ -93,13 +117,14 @@ class RecvOp : public framework::OperatorBase {
   // grpc send/recv service implement to register.
   std::shared_ptr<detail::SendRecvServerImpl> rpc_service_;
   std::shared_ptr<std::thread> server_thread_;
+  framework::Scope const *recv_scope_{nullptr};
 };
 
 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   RecvOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("RX", "(Tensor) Input tensor to be saved");
+    AddInput("RX", "(Tensor) Input tensor to be optimized").AsDuplicable();
     AddComment(R"DOC(
 Recv operator
 
@@ -112,6 +137,12 @@ This operator will recv tensor from send_op
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
     AddAttr<std::string>("OptimizeProgram", "type string",
                          "Serialized ProgramDesc string for recv to run.");
+    AddAttr<std::vector<std::string>>(
+        "ParamList", "type list of string",
+        "grad->param name mapping to find which param to optimize.");
+    AddAttr<std::vector<std::string>>(
+        "GradList", "type list of string",
+        "grad->param name mapping to find which param to optimize.");
   }
 };
 
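Note (not part of the patch): a minimal standalone sketch of the grad->param lookup the second hunk introduces. The received gradient name is searched in the "GradList" attribute, and the parameter at the same index in "ParamList" is the one to optimize. The helper name LookupParamName is illustrative only and does not exist in the source tree.

#include <algorithm>
#include <string>
#include <vector>

// Mirrors the std::find + positional-index logic in RecvOp::Run above.
// Returns an empty string when the gradient name is not listed.
std::string LookupParamName(const std::vector<std::string> &param_list,
                            const std::vector<std::string> &grad_list,
                            const std::string &grad_var_name) {
  auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
  if (it == grad_list.end()) return "";
  return param_list[it - grad_list.begin()];
}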