|
|
|
@ -58,7 +58,7 @@ class ClipGradFunctor {
|
|
|
|
|
T max_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class ClipKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
@ -69,12 +69,13 @@ class ClipKernel : public framework::OpKernel {
|
|
|
|
|
T* out_data = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
Transform(context.device_context(), x_data, x_data + numel, out_data,
|
|
|
|
|
ClipFunctor<T>(min, max));
|
|
|
|
|
Transform<Place> trans;
|
|
|
|
|
trans(context.device_context(), x_data, x_data + numel, out_data,
|
|
|
|
|
ClipFunctor<T>(min, max));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class ClipGradKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
@ -88,8 +89,9 @@ class ClipGradKernel : public framework::OpKernel {
|
|
|
|
|
auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* d_out_data = d_out->data<T>();
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
Transform(context.device_context(), d_out_data, d_out_data + numel,
|
|
|
|
|
x_data, d_x_data, ClipGradFunctor<T>(min, max));
|
|
|
|
|
Transform<Place> trans;
|
|
|
|
|
trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
|
|
|
|
|
d_x_data, ClipGradFunctor<T>(min, max));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|