|
|
|
@ -83,7 +83,7 @@ class ClipGradKernel : public framework::OpKernel {
|
|
|
|
|
if (d_x != nullptr) {
|
|
|
|
|
auto* x = context.Input<Tensor>("X");
|
|
|
|
|
int64_t numel = d_out->numel();
|
|
|
|
|
auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
|
|
|
|
|
const T* d_out_data = d_out->data<T>();
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
Transform<Place> trans;
|
|
|
|
|