|
|
|
@ -109,6 +109,29 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
|
|
|
|
|
{framework::GradVarName("Out"),
|
|
|
|
|
framework::GradVarName("X")});
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void Apply(GradOpPtr<T> op) const override {
|
|
|
|
|
op->SetType("clip_grad");
|
|
|
|
|
op->SetInput("X", this->Input("X"));
|
|
|
|
|
if (this->HasInput("Min")) {
|
|
|
|
|
op->SetInput("Min", this->Input("Min"));
|
|
|
|
|
}
|
|
|
|
|
if (this->HasInput("Max")) {
|
|
|
|
|
op->SetInput("Max", this->Input("Max"));
|
|
|
|
|
}
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"),
|
|
|
|
|
this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"),
|
|
|
|
|
this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
op->SetAttrMap(this->Attrs());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -117,7 +140,9 @@ REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
|
|
|
|
|
ops::ClipGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::ClipGradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::ClipInplaceInferer);
|
|
|
|
|
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer);
|
|
|
|
|
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
|
|
|
|
|
ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::ClipKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|