|
|
|
@ -30,8 +30,8 @@ template <typename Place, typename T>
|
|
|
|
|
class ClipKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto max = context.op().GetAttr<float>("max");
|
|
|
|
|
auto min = context.op().GetAttr<float>("min");
|
|
|
|
|
auto max = context.op().Attr<float>("max");
|
|
|
|
|
auto min = context.op().Attr<float>("min");
|
|
|
|
|
auto* x = context.Input<Tensor>("X");
|
|
|
|
|
auto* out = context.Output<Tensor>("Out");
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
@ -46,8 +46,8 @@ template <typename T>
|
|
|
|
|
class ClipGradKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto max = context.op().GetAttr<float>("max");
|
|
|
|
|
auto min = context.op().GetAttr<float>("min");
|
|
|
|
|
auto max = context.op().Attr<float>("max");
|
|
|
|
|
auto min = context.op().Attr<float>("min");
|
|
|
|
|
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* x = context.Output<Tensor>("X");
|
|
|
|
|