fix-develop-build.sh
Bai Yifan 7 years ago committed by qingqing01
parent ae67dcea09
commit e69d9c845b

@ -31,7 +31,8 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
i += blockDim.x * gridDim.x) {
int idx = i * class_num + labels[i];
logit_grad[idx] -= static_cast<T>(1.);
logit_grad[idx] -=
ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.);
}
}

Loading…
Cancel
Save