!10017 Add Dyanmic Shape Support For Equal Operator on GPU

From: @huangxinjing
Reviewed-by: @liangchenghui,@stsuteng
Signed-off-by: @stsuteng
pull/10017/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 4a6e696b38

@ -31,6 +31,25 @@ struct LessFunc {
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; } __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
}; };
template <typename T>
struct EqualFunc {
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs == rhs ? true : false; }
};
template <>
struct EqualFunc <half> {
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? true : false;
}
};
template <>
struct EqualFunc <float> {
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
return std::abs(lhs - rhs) < 1e-9 ? true : false;
}
};
template <typename T> template <typename T>
struct MinimumFunc { struct MinimumFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
@ -188,6 +207,8 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *
return ElewiseCmpKernel<T, GreaterFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); return ElewiseCmpKernel<T, GreaterFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_LESS: case BROADCAST_TYPE_LESS:
return ElewiseCmpKernel<T, LessFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y); return ElewiseCmpKernel<T, LessFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_EQUAL:
return ElewiseCmpKernel<T, EqualFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
default: default:
break; break;
} }
@ -331,6 +352,11 @@ void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t>
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1], x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3], x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y); y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_EQUAL:
return BroadcastCmpKernel<T, EqualFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
default: default:
break; break;
} }

@ -36,6 +36,7 @@ enum BroadcastOpType {
BROADCAST_TYPE_ABSGRAD = 10, BROADCAST_TYPE_ABSGRAD = 10,
BROADCAST_TYPE_DIV = 11, BROADCAST_TYPE_DIV = 11,
BROADCAST_TYPE_DIVNONAN = 12, BROADCAST_TYPE_DIVNONAN = 12,
BROADCAST_TYPE_EQUAL = 13,
BROADCAST_TYPE_INVALID = 0xffffffff, BROADCAST_TYPE_INVALID = 0xffffffff,
}; };

@ -26,6 +26,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float) BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Maximum, Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@ -75,6 +78,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, half) BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Maximum, Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@ -123,6 +129,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int) BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int) BroadcastOpGpuKernel, int)
@ -135,6 +144,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int) BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int) BroadcastOpGpuKernel, int)
@ -156,6 +168,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int64_t) BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t) BroadcastOpGpuKernel, int64_t)
@ -168,6 +183,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t) BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t) BroadcastOpGpuKernel, int64_t)

@ -132,6 +132,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
static std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = { static std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER}, {"Greater", BROADCAST_TYPE_GREATER},
{"Less", BROADCAST_TYPE_LESS}, {"Less", BROADCAST_TYPE_LESS},
{"Equal", BROADCAST_TYPE_EQUAL},
}; };
auto iter = kBroadcastCmpTypeMap.find(kernel_name); auto iter = kBroadcastCmpTypeMap.find(kernel_name);

Loading…
Cancel
Save