!12221 GPU AbsGrad calculat error while input is 0.0

From: @caojian05
Reviewed-by: @kisnwang,@wuxuejian
Signed-off-by: @wuxuejian
pull/12221/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b83478201f

@ -180,7 +180,7 @@ template <typename T>
struct AbsGradFunc { struct AbsGradFunc {
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) { __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
T zero = 0.0; T zero = 0.0;
return lhs < zero ? -rhs : rhs; return lhs < zero ? -rhs : lhs > zero ? rhs : zero;
} }
}; };
@ -188,7 +188,7 @@ template <>
struct AbsGradFunc<half2> { struct AbsGradFunc<half2> {
__device__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) { __device__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
half2 zero(0.0, 0.0); half2 zero(0.0, 0.0);
return lhs < zero ? -rhs : rhs; return lhs < zero ? -rhs : lhs > zero ? rhs : zero;
} }
}; };
@ -200,7 +200,7 @@ struct SquaredDifferenceFunc {
} }
}; };
// Element-wise Comparation // Element-wise Comparison
template <typename T, typename Func> template <typename T, typename Func>
__global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) { __global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
@ -305,7 +305,7 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y, template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y,
cudaStream_t stream); cudaStream_t stream);
// Broadcast comparation // Broadcast comparison
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; } __device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
template <typename T, typename Func> template <typename T, typename Func>

Loading…
Cancel
Save