[ROCM] fix softmax_with_cross_entropy_op, test=develop (#31629)

pull/1/head
ronnywang 4 years ago committed by GitHub
parent 75433126df
commit da10c5cf8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -452,12 +452,7 @@ struct HardLabelCrossEntropyFunctorWithIgnoreIdx {
// labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain;
if (idx_axis == ignore_idx_) {
loss_[idx_lbl] = 0;
return;
}
if (idx_axis == labels_[idx_lbl]) {
if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
}
}
@ -732,7 +727,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(
template <typename T>
static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data,
T* loss_data, int n, int d, int axis_dim,
cudaStream_t stream) {
gpuStream_t stream) {
constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
@ -792,11 +787,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
if (axis_dim == 1) {
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
set_constant(context.cuda_device_context(), softmax_out,
static_cast<T>(1));
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
return;
}

@ -116,7 +116,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
self.shape = [13, 8]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -129,7 +129,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
self.shape = [13, 8]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -145,7 +145,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D(
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -155,7 +155,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
@ -168,7 +168,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
@ -181,7 +181,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
@ -206,7 +206,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D(
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -216,7 +216,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
@ -229,7 +229,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
@ -242,7 +242,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
@ -267,7 +267,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore(
self.shape = [13, 8]
self.axis = -1
self.ignore_index = 2
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -280,7 +280,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
self.shape = [13, 8]
self.axis = 1
self.ignore_index = 2
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -293,7 +293,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = 2
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.softmax_switch = False #default is true, means "with softmax"
@ -303,7 +303,7 @@ class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.axis = 2
self.ignore_index = 2
self.shape = [3, 5, 7, 11]

Loading…
Cancel
Save