|
|
|
@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
|
|
|
|
|
// allocate memory on device.
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dims = X->dims();
|
|
|
|
|
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
|
|
|
|
|
framework::LoDTensor flattened_x;
|
|
|
|
|
framework::LoDTensor flattened_out;
|
|
|
|
|
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
|
|
|
|
|
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
|
|
|
|
|
int rank = X->dims().size();
|
|
|
|
|
Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
|
|
|
|
|
Tensor Out_2d =
|
|
|
|
|
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
|
|
|
|
|
|
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T>()(
|
|
|
|
|
context.template device_context<DeviceContext>(), &flattened_x,
|
|
|
|
|
&flattened_out);
|
|
|
|
|
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// allocate memory on device.
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto dims = Out->dims();
|
|
|
|
|
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
|
|
|
|
|
framework::LoDTensor flattened_out;
|
|
|
|
|
framework::LoDTensor flattened_d_out;
|
|
|
|
|
framework::LoDTensor flattened_d_x;
|
|
|
|
|
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
|
|
|
|
|
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
|
|
|
|
|
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
|
|
|
|
|
int rank = Out->dims().size();
|
|
|
|
|
Tensor Out_2d =
|
|
|
|
|
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
|
|
|
|
|
Tensor dOut_2d =
|
|
|
|
|
rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
|
|
|
|
|
Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
|
|
|
|
|
|
|
|
|
|
math::SoftmaxGradFunctor<DeviceContext, T>()(
|
|
|
|
|
context.template device_context<DeviceContext>(), &flattened_out,
|
|
|
|
|
&flattened_d_out, &flattened_d_x);
|
|
|
|
|
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
|
|
|
|
|
&dX_2d);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|