|
|
|
@ -78,44 +78,38 @@ class GpuKernelRegister {
|
|
|
|
|
// variable has been created.
|
|
|
|
|
// Single-token spelling of `unsigned char`.
// NOTE(review): presumably exists so the type can be passed as one macro argument and pasted
// into an identifier with ## by the registration macros — confirm against call sites.
#define uchar unsigned char
|
|
|
|
|
|
|
|
|
|
// Produces an identifier made of `kernel` followed by the current value of __COUNTER__.
// The work is delegated to KERNEL_NAME so that __COUNTER__ is expanded to its number before
// it is pasted onto `kernel`.
#define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(kernel, __COUNTER__)
|
|
|
|
|
// Produces the identifier g_<kernel>_gpu_kernel_reg<N>, where <N> is the current value of
// __COUNTER__. `kernel` is pasted with ## and is therefore used as written, without macro
// expansion; KERNEL_NAME then expands __COUNTER__ before appending it.
#define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(g_##kernel##_gpu_kernel_reg, __COUNTER__)
|
|
|
|
|
// Indirection layer in front of MERGE: because the arguments are not operands of ## here,
// they are fully macro-expanded (e.g. __COUNTER__ becomes a number) before MERGE pastes them.
#define KERNEL_NAME(kernel, cnt) MERGE(kernel, cnt)
|
|
|
|
|
// Pastes `kernel` and `cnt` into a single token. Arguments are used as written (no expansion),
// so callers that need expansion first should go through KERNEL_NAME.
#define MERGE(kernel, cnt) kernel##cnt
|
|
|
|
|
|
|
|
|
|
// Registers OPCLASS as the GPU kernel for operator OPNAME with a default-constructed KernelAttr.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and pasted into the variable name.
//   OPCLASS - kernel class; the static_assert rejects any type not derived from GpuKernel.
// The lambda passed to GpuKernelRegister is the factory; it returns a raw pointer from `new`.
// NOTE(review): ownership of that pointer is decided by GpuKernelRegister/its consumer — confirm.
#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_gpu_kernel_reg)(#OPNAME, KernelAttr(), \
|
|
|
|
|
[]() { return new OPCLASS(); });
|
|
|
|
|
// Registers OPCLASS as the GPU kernel for operator OPNAME with a default-constructed KernelAttr.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and handed to
//             UNIQUE_KERNEL_NAME to name the static registrar variable.
//   OPCLASS - kernel class; the static_assert rejects any type not derived from GpuKernel.
// The lambda passed to GpuKernelRegister is the factory; it returns a raw pointer from `new`.
// NOTE(review): ownership of that pointer is decided by GpuKernelRegister/its consumer — confirm.
#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, KernelAttr(), []() { return new OPCLASS(); });
|
|
|
|
|
|
|
|
|
|
// Regular registration of fixed-accuracy kernels.
|
|
|
|
|
// Registers the non-template class OPCLASS as a GPU kernel for operator OPNAME.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and pasted into the variable name.
//   ATTR    - expression forwarded unchanged as the second GpuKernelRegister constructor argument.
//   OPCLASS - kernel class; the static_assert rejects any type not derived from GpuKernel.
// The lambda is the factory and returns a raw pointer from `new`.
#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_gpu_kernel_reg)(#OPNAME, ATTR, \
|
|
|
|
|
[]() { return new OPCLASS(); });
|
|
|
|
|
// Registers the non-template class OPCLASS as a GPU kernel for operator OPNAME.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and handed to
//             UNIQUE_KERNEL_NAME to name the static registrar variable.
//   ATTR    - expression forwarded unchanged as the second GpuKernelRegister constructor argument.
//   OPCLASS - kernel class; the static_assert rejects any type not derived from GpuKernel.
// The lambda is the factory and returns a raw pointer from `new`.
#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS(); });
|
|
|
|
|
|
|
|
|
|
// Registration of mixed-accuracy kernels that are class templates with one type parameter, ignoring the number of inputs.
|
|
|
|
|
// Registers the instantiation OPCLASS<T> as a GPU kernel for operator OPNAME.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and used in the variable name.
//   ATTR    - expression forwarded unchanged as the second GpuKernelRegister constructor argument.
//   OPCLASS - class template taking one type argument.
//   T       - type argument; the static_assert requires OPCLASS<T> to derive from GpuKernel.
// The lambda is the factory and returns a raw pointer from `new OPCLASS<T>()`.
#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_gpu_kernel_reg)( \
|
|
|
|
|
#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
|
|
|
|
|
|
|
|
|
|
// Registration of mixed-accuracy kernels that are class templates with one type parameter.
|
|
|
|
|
// Registers the instantiation OPCLASS<T> as a GPU kernel for operator OPNAME.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and used in the variable name.
//   ATTR    - expression forwarded unchanged as the second GpuKernelRegister constructor argument.
//   OPCLASS - class template taking one type argument.
//   T       - type argument; the static_assert requires OPCLASS<T> to derive from GpuKernel.
// The lambda is the factory and returns a raw pointer from `new OPCLASS<T>()`.
#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_gpu_kernel_reg)( \
|
|
|
|
|
#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
|
|
|
|
|
|
|
|
|
|
// Registration of mixed-accuracy kernels that are class templates with two type parameters.
|
|
|
|
|
// Registers the instantiation OPCLASS<T, S> as a GPU kernel for operator OPNAME.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and used in the variable name.
//   ATTR    - expression forwarded unchanged as the second GpuKernelRegister constructor argument.
//   OPCLASS - class template taking two type arguments.
//   T, S    - type arguments; the static_assert requires OPCLASS<T, S> to derive from GpuKernel.
// The lambda is the factory and returns a raw pointer from `new OPCLASS<T, S>()`.
#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_##S##_gpu_kernel_reg)( \
|
|
|
|
|
#OPNAME, ATTR, []() { return new OPCLASS<T, S>(); });
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T, S>(); });
|
|
|
|
|
|
|
|
|
|
// Registration of mixed-accuracy kernels that are class templates with three type parameters.
|
|
|
|
|
// Registers the instantiation OPCLASS<T, S, G> as a GPU kernel for operator OPNAME.
//   OPNAME  - operator name; stringized (#OPNAME) for the registry and used in the variable name.
//   ATTR    - expression forwarded unchanged as the second GpuKernelRegister constructor argument.
//   OPCLASS - class template taking three type arguments.
//   T, S, G - type arguments; the static_assert requires OPCLASS<T, S, G> to derive from GpuKernel.
// The lambda is the factory and returns a raw pointer from `new OPCLASS<T, S, G>()`.
#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \
|
|
|
|
|
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg)( \
|
|
|
|
|
#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
|
|
|
|
|
static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
|
|
|
|
|
} // namespace kernel
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_
|
|
|
|
|