[ROCM] fix softmax with loss nan in HIP platform, test=develop (#31491)

fix_imperative_dygraph_error
Qi Li 4 years ago committed by GitHub
parent f57739be35
commit 416e47edef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -398,7 +398,12 @@ static void HardLabelSoftmaxWithCrossEntropy(
const platform::CUDADeviceContext& ctx, const T* logits_data,
const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
int64_t d, int axis_dim, int ignore_idx) {
#ifdef __HIPCC__
// HIP platform will have loss nan if dim size > 256
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
int64_t block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));

Loading…
Cancel
Save