|
|
|
@ -54,12 +54,11 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
int D = d_x->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (N * D + block - 1) / block;
|
|
|
|
|
auto* device_context =
|
|
|
|
|
const_cast<platform::DeviceContext*>(context.device_context_);
|
|
|
|
|
ClipGradientKernel<
|
|
|
|
|
T><<<grid, block, 0,
|
|
|
|
|
reinterpret_cast<platform::CUDADeviceContext*>(device_context)
|
|
|
|
|
->stream()>>>(count, min, max, x_data, d_out_data, d_x_data);
|
|
|
|
|
ClipGradientKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
context.device_context())
|
|
|
|
|
.stream()>>>(count, min, max, x_data, d_out_data,
|
|
|
|
|
d_x_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|