|
|
|
@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
@ -47,9 +46,9 @@ Scale operator
|
|
|
|
|
|
|
|
|
|
$$Out = scale*X$$
|
|
|
|
|
)DOC");
|
|
|
|
|
AddAttr<AttrType>("scale",
|
|
|
|
|
"(float, default 1.0)"
|
|
|
|
|
"The scaling factor of the scale operator.")
|
|
|
|
|
AddAttr<float>("scale",
|
|
|
|
|
"(float, default 1.0)"
|
|
|
|
|
"The scaling factor of the scale operator.")
|
|
|
|
|
.SetDefault(1.0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
|
|
|
|
|
ops::ScaleGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|