|
|
|
@ -88,7 +88,13 @@ void CropCUDAFunctoin(const framework::ExecutionContext& context) {
|
|
|
|
|
int d = out_dims[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (n * d + block - 1) / block;
|
|
|
|
|
CropKernel<T, D><<<grid, block>>>(out_count, out_shape_gpu, x_shape_gpu,
|
|
|
|
|
|
|
|
|
|
auto* device_context =
|
|
|
|
|
const_cast<platform::DeviceContext*>(context.device_context_);
|
|
|
|
|
CropKernel<T,
|
|
|
|
|
D><<<grid, block, 0,
|
|
|
|
|
reinterpret_cast<platform::CUDADeviceContext*>(device_context)
|
|
|
|
|
->stream()>>>(out_count, out_shape_gpu, x_shape_gpu,
|
|
|
|
|
crop_rules_gpu, x_data, out_data);
|
|
|
|
|
cudaFree(crop_rules_gpu);
|
|
|
|
|
cudaFree(x_shape_gpu);
|
|
|
|
|