@@ -21,20 +21,20 @@ namespace operators {
 
 enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
 
-#define CHECK_CASE(i, flags, kernel_name, args...) \
+#define CHECK_CASE(i, flags, kernel_name, ...) \
   if (i == flags) { \
-    kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(args); \
+    kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
   }
 
 // 0 for no scale, no bias
 // 1 for has scale, no bias
 // 2 for no scale, has bias
 // 3 for has scale, has bias
-#define UNROLL_ALL_CASES(flags, kernel_name, args...) \
-  CHECK_CASE(0, flags, kernel_name, args) \
-  CHECK_CASE(1, flags, kernel_name, args) \
-  CHECK_CASE(2, flags, kernel_name, args) \
-  CHECK_CASE(3, flags, kernel_name, args)
+#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
+  CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
+  CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
+  CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
+  CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
 
 template <typename T>
 __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
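For context: the named variadic parameter `args...` is a GNU preprocessor extension, while the replacement form, `...` in the parameter list plus `__VA_ARGS__` in the body, is standard C99/C++11 and is accepted by compilers (e.g. MSVC) that reject the GNU spelling. Below is a minimal host-side sketch of the same dispatch pattern, written only to illustrate the standard variadic-macro form; the `Demo`-prefixed names are hypothetical stand-ins, and an ordinary function call replaces the real `<<<grid, threads, 0, dev_ctx.stream()>>>` launch so the snippet compiles without CUDA.

// Sketch only: standard variadic macros (`...` / __VA_ARGS__) instead of the
// GNU `args...` extension. Names prefixed with Demo are illustrative, not
// taken from the patched file.
#include <cstdio>

enum DemoFlags { kHasScale = 1, kHasBias = 2 };  // same bit encoding as above

template <typename T, int Flags>
void DemoKernel(T x) {  // host stand-in for the real __global__ kernel
  std::printf("flags=%d scale=%d bias=%d x=%f\n", Flags,
              int((Flags & kHasScale) != 0), int((Flags & kHasBias) != 0),
              double(x));
}

// Mirrors CHECK_CASE / UNROLL_ALL_CASES, minus the <<<...>>> launch syntax.
#define DEMO_CHECK_CASE(i, flags, kernel_name, ...) \
  if ((i) == (flags)) {                             \
    kernel_name<T, i>(__VA_ARGS__);                 \
  }

#define DEMO_UNROLL_ALL_CASES(flags, kernel_name, ...) \
  DEMO_CHECK_CASE(0, flags, kernel_name, __VA_ARGS__)  \
  DEMO_CHECK_CASE(1, flags, kernel_name, __VA_ARGS__)  \
  DEMO_CHECK_CASE(2, flags, kernel_name, __VA_ARGS__)  \
  DEMO_CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)

int main() {
  using T = float;                         // the macros expect a type named T in scope
  const int flags = kHasScale | kHasBias;  // == 3: has scale, has bias
  DEMO_UNROLL_ALL_CASES(flags, DemoKernel, 1.5f);  // only the i == 3 branch runs
  return 0;
}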