!8308 gpu support kernel register

From: @wilfchen
Reviewed-by: @cristoval,@limingqi107,@cristoval
Signed-off-by: @cristoval
pull/8308/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f0e203eca6

@ -78,34 +78,43 @@ class GpuKernelRegister {
// variable has been created.
#define uchar unsigned char
#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, KernelAttr(), []() { return new OPCLASS(); });
#define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(kernel, __COUNTER__)
#define KERNEL_NAME(kernel, cnt) MERGE(kernel, cnt)
#define MERGE(kernel, cnt) kernel##cnt
#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_gpu_kernel_reg)(#OPNAME, KernelAttr(), \
[]() { return new OPCLASS(); });
// regular register of fixed accuracy kernels
#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); });
#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_gpu_kernel_reg)(#OPNAME, ATTR, \
[]() { return new OPCLASS(); });
// register of mixed accuracy kernels which use template and maintain one typename, ignore input num
#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
// register of mixed accuracy kernels which use template and maintain one typename
#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
// register of mixed accuracy kernels which use template and maintain two typename
#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \
[]() { return new OPCLASS<T, S>(); });
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_##S##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T, S>(); });
// register of mixed accuracy kernels which use template and maintain three typename
#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \
static const GpuKernelRegister UNIQUE_KERNEL_NAME(g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg)( \
#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
} // namespace kernel
} // namespace mindspore

Loading…
Cancel
Save