Update transform invocation

update-doc-pybind
wanghaoshuang 8 years ago
parent 3f3848cdf7
commit 1fdad1a60a

@ -80,5 +80,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip, ops::ClipKernel<float>);
REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
REGISTER_OP_CPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::CPUPlace, float>);

@ -15,5 +15,7 @@
#include "paddle/operators/clip_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip, ops::ClipKernel<float>);
REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
REGISTER_OP_GPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::GPUPlace, float>);

@ -58,7 +58,7 @@ class ClipGradFunctor {
T max_;
};
template <typename T>
template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
@ -69,12 +69,13 @@ class ClipKernel : public framework::OpKernel {
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int numel = x->numel();
Transform(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
}
};
template <typename T>
template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
@ -88,8 +89,9 @@ class ClipGradKernel : public framework::OpKernel {
auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform(context.device_context(), d_out_data, d_out_data + numel,
x_data, d_x_data, ClipGradFunctor<T>(min, max));
Transform<Place> trans;
trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
d_x_data, ClipGradFunctor<T>(min, max));
}
}
};

Loading…
Cancel
Save