|
|
|
@ -30,8 +30,16 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
// allocate memory on device.
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dims = X->dims();
|
|
|
|
|
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
|
|
|
|
|
framework::LoDTensor flattened_x;
|
|
|
|
|
framework::LoDTensor flattened_out;
|
|
|
|
|
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
|
|
|
|
|
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
|
|
|
|
|
|
|
|
|
|
math::SoftmaxCUDNNFunctor<T>()(
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(), X, Out);
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(),
|
|
|
|
|
&flattened_x, &flattened_out);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -46,9 +54,18 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
// allocate memory on device.
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dims = Out->dims();
|
|
|
|
|
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
|
|
|
|
|
framework::LoDTensor flattened_out;
|
|
|
|
|
framework::LoDTensor flattened_d_out;
|
|
|
|
|
framework::LoDTensor flattened_d_x;
|
|
|
|
|
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
|
|
|
|
|
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
|
|
|
|
|
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
|
|
|
|
|
|
|
|
|
|
math::SoftmaxGradCUDNNFunctor<T>()(
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(), Out,
|
|
|
|
|
dOut, dX);
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(),
|
|
|
|
|
&flattened_out, &flattened_d_out, &flattened_d_x);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|