|
|
|
@ -47,10 +47,7 @@ class ClipGradFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
|
|
|
|
|
HOSTDEVICE T operator()(const T& x, const T& y) const {
|
|
|
|
|
if (y > min_ && y < max_)
|
|
|
|
|
return x;
|
|
|
|
|
else
|
|
|
|
|
return 0;
|
|
|
|
|
return (y > min_ && y < max_) ? x : 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
@ -68,7 +65,7 @@ class ClipKernel : public framework::OpKernel {
|
|
|
|
|
auto* out = context.Output<Tensor>("Out");
|
|
|
|
|
T* out_data = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
int64_t numel = x->numel();
|
|
|
|
|
Transform<Place> trans;
|
|
|
|
|
trans(context.device_context(), x_data, x_data + numel, out_data,
|
|
|
|
|
ClipFunctor<T>(min, max));
|
|
|
|
|