|
|
|
@ -34,34 +34,36 @@ class SendOp : public framework::OperatorBase {
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {
|
|
|
|
|
// init client when the operator is created at runtime.
|
|
|
|
|
if (!client_) {
|
|
|
|
|
std::string endpoint = Attr<std::string>("endpoint");
|
|
|
|
|
client_.reset(new detail::RPCClient(
|
|
|
|
|
grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials())));
|
|
|
|
|
// TODO(typhoonzero): how to call InitVariables
|
|
|
|
|
std::vector<std::string> endpoints =
|
|
|
|
|
Attr<std::vector<std::string>>("endpoints");
|
|
|
|
|
for (auto ep : endpoints) {
|
|
|
|
|
client_map_[ep].reset(new detail::RPCClient(
|
|
|
|
|
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials())));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Sends every input variable listed under "X" through the RPC client of the
// endpoint assigned to it, waits on the first mapped endpoint, and then
// calls GetVariable for each input on that same endpoint's client.
//
// scope:   forwarded to the RPC client's SendVariable/GetVariable calls.
// dev_ctx: not used by this operator.
//
// The "epmap" attribute supplies, for the i-th input variable, the endpoint
// string used as the key into client_map_. A failed send or get is logged
// and the remaining variables are still processed.
void Run(const framework::Scope &scope,
         const platform::DeviceContext &dev_ctx) const override {
  // Bind by const reference so no per-run copy is made when the accessors
  // return references (lifetime is extended if they return by value).
  const auto &ins = Inputs("X");
  const auto &epmap = Attr<std::vector<std::string>>("epmap");

  // Every input needs an endpoint; indexing epmap past its end would be
  // undefined behaviour.
  if (epmap.size() < ins.size()) {
    LOG(ERROR) << "epmap has " << epmap.size() << " entries but there are "
               << ins.size() << " input variables";
    return;
  }

  // Looks up the client for an endpoint without inserting a null entry for
  // unknown keys (which operator[] would do and then dereference).
  auto find_client = [this](const std::string &ep) -> detail::RPCClient * {
    auto it = client_map_.find(ep);
    if (it == client_map_.end() || !it->second) {
      LOG(ERROR) << "no RPC client for endpoint: " << ep;
      return nullptr;
    }
    return it->second.get();
  };

  // TODO(typhoonzero): use async calls to send multiple variable asyncly.
  for (size_t i = 0; i < ins.size(); ++i) {
    detail::RPCClient *client = find_client(epmap[i]);
    if (client == nullptr || !client->SendVariable(scope, ins[i])) {
      LOG(ERROR) << "send variable error: " << ins[i];
    }
  }

  // TODO(typhoonzero): support async optimization
  // client_map_ is keyed by endpoint string, so the "first" client has to be
  // addressed through epmap[0]; a literal 0 would be converted to a
  // std::string built from a null pointer.
  if (!ins.empty()) {
    if (detail::RPCClient *first = find_client(epmap[0])) {
      first->Wait();
    }
  }

  for (size_t i = 0; i < ins.size(); ++i) {
    detail::RPCClient *client = find_client(epmap[i]);
    if (client == nullptr || !client->GetVariable(scope, ins[i])) {
      LOG(ERROR) << "GetVariable error: " << ins[i];
    }
  }
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Single RPC client; the constructor creates it from the "endpoint"
// attribute when it is still null.
std::shared_ptr<detail::RPCClient> client_{nullptr};
|
|
|
|
|
// One RPC client per server endpoint, keyed by the endpoint string; the
// constructor fills it from the "endpoints" attribute. Declared mutable so
// that the const Run() method can go through the map's non-const interface.
mutable std::unordered_map<std::string, std::shared_ptr<detail::RPCClient>>
|
|
|
|
|
client_map_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
@ -74,11 +76,13 @@ Recv operator
|
|
|
|
|
|
|
|
|
|
This operator will recv tensor from send_op
|
|
|
|
|
)DOC");
|
|
|
|
|
AddAttr<std::string>("endpoint",
|
|
|
|
|
"(string, default 127.0.0.1:6164)"
|
|
|
|
|
"IP address to listen on.")
|
|
|
|
|
.SetDefault("127.0.0.1:6164")
|
|
|
|
|
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
|
|
|
|
|
AddAttr<std::vector<std::string>>("endpoints",
|
|
|
|
|
"(string vector, default 127.0.0.1:6164)"
|
|
|
|
|
"Server endpoints to send variables to.");
|
|
|
|
|
AddAttr<std::vector<std::string>>("epmap",
|
|
|
|
|
"(string vector, default 127.0.0.1:6164)"
|
|
|
|
|
"Server endpoints in the order of input "
|
|
|
|
|
"variables for mapping");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|