|
|
@ -398,7 +398,12 @@ static void HardLabelSoftmaxWithCrossEntropy(
|
|
|
|
const platform::CUDADeviceContext& ctx, const T* logits_data,
|
|
|
|
const platform::CUDADeviceContext& ctx, const T* logits_data,
|
|
|
|
const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
|
|
|
|
const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
|
|
|
|
int64_t d, int axis_dim, int ignore_idx) {
|
|
|
|
int64_t d, int axis_dim, int ignore_idx) {
|
|
|
|
|
|
|
|
#ifdef __HIPCC__
|
|
|
|
|
|
|
|
// HIP platform will have loss nan if dim size > 256
|
|
|
|
|
|
|
|
constexpr int kMaxBlockDim = 256;
|
|
|
|
|
|
|
|
#else
|
|
|
|
constexpr int kMaxBlockDim = 512;
|
|
|
|
constexpr int kMaxBlockDim = 512;
|
|
|
|
|
|
|
|
#endif
|
|
|
|
int64_t block_dim = axis_dim >= kMaxBlockDim
|
|
|
|
int64_t block_dim = axis_dim >= kMaxBlockDim
|
|
|
|
? kMaxBlockDim
|
|
|
|
? kMaxBlockDim
|
|
|
|
: (1 << static_cast<int>(std::log2(axis_dim)));
|
|
|
|
: (1 << static_cast<int>(std::log2(axis_dim)));
|
|
|
|