Change Interface to unique_ptr

tonyyang-svail-feed-op-desgin
Yu Yang 7 years ago
parent 495a80a736
commit b2806135a5

@ -42,7 +42,7 @@ The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_`
```cpp
struct OpInfo {
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
...
};
```
@ -55,11 +55,11 @@ We propose a base class called `GradOpDescMakerBase` to let operator developers
class GradOpDescMakerBase {
public:
GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<OpDescBind> operator()()const = 0;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
};
```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<OpDescBind>(const OpDescBind&)>` by
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp
using GradOpMaker = ...;

@ -24,7 +24,7 @@ class GradOpDescMakerBase {
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()() const = 0;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected:
static std::vector<std::string> ToGradNames(
@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<OpDescBind> operator()() const { return {this->Apply()}; }
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected:
virtual OpDescBind Apply() const = 0;
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual OpDescBind Apply() const {
OpDescBind grad;
grad.SetType(this->GradOpType());
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) {
grad.SetInput(input_param, this->Input(input_param));
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param));
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
}
for (auto& output_param : this->OutputNames()) {
grad.SetInput(output_param, this->Output(output_param));
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param));
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
}
grad.SetAttrMap(this->Attrs());
grad->SetAttrMap(this->Attrs());
return grad;
return std::unique_ptr<OpDescBind>(grad);
}
virtual std::string GradOpType() const {

@ -28,7 +28,7 @@ namespace framework {
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
std::function<std::vector<OpDescBind>(const OpDescBind&)> grad_op_maker_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};

@ -20,6 +20,7 @@
namespace paddle {
namespace framework {
class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework
} // namespace paddle

Loading…
Cancel
Save