Change Interface to unique_ptr

tonyyang-svail-feed-op-desgin
Yu Yang 7 years ago
parent 495a80a736
commit b2806135a5

@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_`
```cpp ```cpp
struct OpInfo { struct OpInfo {
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_; std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
... ...
}; };
``` ```
@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers
class GradOpDescMakerBase { class GradOpDescMakerBase {
public: public:
GradOpDescMakerBase(const OpDescBind& ); GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<OpDescBind> operator()()const = 0; virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
}; };
``` ```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<OpDescBind>(const OpDescBind&)>` by We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp ```cpp
using GradOpMaker = ...; using GradOpMaker = ...;

@ -24,7 +24,7 @@ class GradOpDescMakerBase {
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {} explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default; virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()() const = 0; virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected: protected:
static std::vector<std::string> ToGradNames( static std::vector<std::string> ToGradNames(
@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public: public:
using GradOpDescMakerBase::GradOpDescMakerBase; using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<OpDescBind> operator()() const { return {this->Apply()}; } std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected: protected:
virtual OpDescBind Apply() const = 0; virtual std::unique_ptr<OpDescBind> Apply() const = 0;
}; };
class DefaultGradOpDescMaker : public SingleGradOpDescMaker { class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using SingleGradOpDescMaker::SingleGradOpDescMaker; using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
virtual OpDescBind Apply() const { virtual std::unique_ptr<OpDescBind> Apply() const {
OpDescBind grad; auto* grad = new OpDescBind();
grad.SetType(this->GradOpType()); grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) { for (auto& input_param : this->InputNames()) {
grad.SetInput(input_param, this->Input(input_param)); grad->SetInput(input_param, this->Input(input_param));
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param)); grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
} }
for (auto& output_param : this->OutputNames()) { for (auto& output_param : this->OutputNames()) {
grad.SetInput(output_param, this->Output(output_param)); grad->SetInput(output_param, this->Output(output_param));
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param)); grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
} }
grad.SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
return grad; return std::unique_ptr<OpDescBind>(grad);
} }
virtual std::string GradOpType() const { virtual std::string GradOpType() const {

@ -28,7 +28,7 @@ namespace framework {
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
std::string grad_op_type_; std::string grad_op_type_;
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};

@ -20,6 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/, const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

Loading…
Cancel
Save