|
|
|
|
@ -32,9 +32,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int rank = X->dims().size();
|
|
|
|
|
Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
|
|
|
|
|
Tensor Out_2d =
|
|
|
|
|
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
|
|
|
|
|
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
|
|
|
|
|
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
|
|
|
|
|
|
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T>()(
|
|
|
|
|
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
|
|
|
|
|
@ -53,11 +52,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int rank = Out->dims().size();
|
|
|
|
|
Tensor Out_2d =
|
|
|
|
|
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
|
|
|
|
|
Tensor dOut_2d =
|
|
|
|
|
rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
|
|
|
|
|
Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
|
|
|
|
|
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
|
|
|
|
|
Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
|
|
|
|
|
Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
|
|
|
|
|
|
|
|
|
|
math::SoftmaxGradFunctor<DeviceContext, T>()(
|
|
|
|
|
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
|
|
|
|
|
|