|
|
@ -24,7 +24,7 @@ class GradOpDescMakerBase {
|
|
|
|
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
|
|
|
|
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
|
|
|
|
|
|
|
|
|
|
|
|
virtual ~GradOpDescMakerBase() = default;
|
|
|
|
virtual ~GradOpDescMakerBase() = default;
|
|
|
|
virtual std::vector<OpDescBind> operator()() const = 0;
|
|
|
|
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
static std::vector<std::string> ToGradNames(
|
|
|
|
static std::vector<std::string> ToGradNames(
|
|
|
@ -81,10 +81,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using GradOpDescMakerBase::GradOpDescMakerBase;
|
|
|
|
using GradOpDescMakerBase::GradOpDescMakerBase;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<OpDescBind> operator()() const { return {this->Apply()}; }
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
|
|
|
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> retv;
|
|
|
|
|
|
|
|
retv.emplace_back(this->Apply());
|
|
|
|
|
|
|
|
return retv;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
virtual OpDescBind Apply() const = 0;
|
|
|
|
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
|
|
|
|
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
|
|
|
@ -92,23 +96,23 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
|
|
|
|
using SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
using SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
virtual OpDescBind Apply() const {
|
|
|
|
virtual std::unique_ptr<OpDescBind> Apply() const {
|
|
|
|
OpDescBind grad;
|
|
|
|
auto* grad = new OpDescBind();
|
|
|
|
grad.SetType(this->GradOpType());
|
|
|
|
grad->SetType(this->GradOpType());
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& input_param : this->InputNames()) {
|
|
|
|
for (auto& input_param : this->InputNames()) {
|
|
|
|
grad.SetInput(input_param, this->Input(input_param));
|
|
|
|
grad->SetInput(input_param, this->Input(input_param));
|
|
|
|
grad.SetOutput(GradVarName(input_param), this->InputGrad(input_param));
|
|
|
|
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& output_param : this->OutputNames()) {
|
|
|
|
for (auto& output_param : this->OutputNames()) {
|
|
|
|
grad.SetInput(output_param, this->Output(output_param));
|
|
|
|
grad->SetInput(output_param, this->Output(output_param));
|
|
|
|
grad.SetInput(GradVarName(output_param), this->OutputGrad(output_param));
|
|
|
|
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
grad.SetAttrMap(this->Attrs());
|
|
|
|
grad->SetAttrMap(this->Attrs());
|
|
|
|
|
|
|
|
|
|
|
|
return grad;
|
|
|
|
return std::unique_ptr<OpDescBind>(grad);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
virtual std::string GradOpType() const {
|
|
|
|
virtual std::string GradOpType() const {
|
|
|
|