|
|
|
@ -90,7 +90,7 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
|
|
|
|
|
const int output_height = output->dims()[2];
|
|
|
|
|
const int output_width = output->dims()[3];
|
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
|
const int * indices_data = indices.data<int>();
|
|
|
|
|
const int* indices_data = indices.data<int>();
|
|
|
|
|
T* output_data = output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
int threads = 1024;
|
|
|
|
|
int grid = (input.numel() + threads - 1) / threads;
|
|
|
|
@ -121,7 +121,7 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
|
|
|
|
|
const int output_height = output.dims()[2];
|
|
|
|
|
const int output_width = output.dims()[3];
|
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
|
const int * indices_data = indices.data<int>();
|
|
|
|
|
const int* indices_data = indices.data<int>();
|
|
|
|
|
const T* output_data = output.data<T>();
|
|
|
|
|
const T* output_grad_data = output_grad.data<T>();
|
|
|
|
|
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|