|
|
|
@ -419,6 +419,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
|
|
|
|
|
int axis_dim = logits->dims()[axis];
|
|
|
|
|
|
|
|
|
|
const int n = SizeToAxis(axis, logits->dims());
|
|
|
|
|
const int d = SizeFromAxis(axis, logits->dims());
|
|
|
|
|
|
|
|
|
|
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
if (axis_dim == 1) {
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
|
|
|
|
|
set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
|
|
|
|
@ -426,12 +432,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int n = SizeToAxis(axis, logits->dims());
|
|
|
|
|
const int d = SizeFromAxis(axis, logits->dims());
|
|
|
|
|
|
|
|
|
|
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto soft_label = context.Attr<bool>("soft_label");
|
|
|
|
|
auto ignore_index = context.Attr<int>("ignore_index");
|
|
|
|
|
|
|
|
|
|