|
|
|
@ -121,9 +121,27 @@ This operator will send tensor to recv_op at the parameter server.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SendOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
|
auto out_var_name = op_desc.Output("RPCClient").front();
|
|
|
|
|
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
|
|
|
|
|
auto var_type = framework::proto::VarType::RAW;
|
|
|
|
|
out_var.SetType(var_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SendOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext* ctx) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker);
|
|
|
|
|
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::SendOpMaker, ops::SendOpVarTypeInference,
|
|
|
|
|
ops::SendOpShapeInference);
|
|
|
|
|