|
|
|
@ -17,14 +17,14 @@
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
|
|
|
|
|
#include "paddle/fluid/operators/jitkernels/kernel_pool.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernel_base.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernel_pool.h"
|
|
|
|
|
#include "paddle/fluid/platform/place.h"
|
|
|
|
|
#include "paddle/fluid/platform/variant.h" // for UNUSED
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace jitkernels {
|
|
|
|
|
namespace jit {
|
|
|
|
|
|
|
|
|
|
// make_unique is supported since c++14
|
|
|
|
|
template <typename T, typename... Args>
|
|
|
|
@ -76,21 +76,21 @@ class JitKernelRegistrar {
|
|
|
|
|
msg)
|
|
|
|
|
|
|
|
|
|
// Refer always on CPUPlace
|
|
|
|
|
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
|
|
|
|
|
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
|
|
|
|
|
"REGISTER_KERNEL_REFER must be called in global namespace"); \
|
|
|
|
|
static ::paddle::operators::jitkernels::JitKernelRegistrar< \
|
|
|
|
|
::paddle::operators::jitkernels::ReferKernelPool, \
|
|
|
|
|
::paddle::platform::CPUPlace, __VA_ARGS__> \
|
|
|
|
|
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
|
|
|
|
|
::paddle::operators::jitkernels::KernelType::kernel_type); \
|
|
|
|
|
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
|
|
|
|
|
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
|
|
|
|
|
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
|
|
|
|
|
"REGISTER_KERNEL_REFER must be called in global namespace"); \
|
|
|
|
|
static ::paddle::operators::jit::JitKernelRegistrar< \
|
|
|
|
|
::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \
|
|
|
|
|
__VA_ARGS__> \
|
|
|
|
|
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
|
|
|
|
|
::paddle::operators::jit::KernelType::kernel_type); \
|
|
|
|
|
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
|
|
|
|
|
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// kernel_type: should be in paddle::operators::jitkernels::KernelType
|
|
|
|
|
// kernel_type: should be in paddle::operators::jit::KernelType
|
|
|
|
|
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
|
|
|
|
|
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
|
|
|
|
|
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
|
|
|
|
@ -99,11 +99,11 @@ class JitKernelRegistrar {
|
|
|
|
|
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
|
|
|
|
|
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
|
|
|
|
|
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
|
|
|
|
|
static ::paddle::operators::jitkernels::JitKernelRegistrar< \
|
|
|
|
|
::paddle::operators::jitkernels::KernelPool, \
|
|
|
|
|
::paddle::platform::place_type, __VA_ARGS__> \
|
|
|
|
|
static ::paddle::operators::jit::JitKernelRegistrar< \
|
|
|
|
|
::paddle::operators::jit::KernelPool, ::paddle::platform::place_type, \
|
|
|
|
|
__VA_ARGS__> \
|
|
|
|
|
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
|
|
|
|
|
::paddle::operators::jitkernels::KernelType::kernel_type); \
|
|
|
|
|
::paddle::operators::jit::KernelType::kernel_type); \
|
|
|
|
|
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
|
|
|
|
|
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
|
|
|
|
|
.Touch(); \
|
|
|
|
@ -139,6 +139,6 @@ class JitKernelRegistrar {
|
|
|
|
|
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
|
|
|
|
|
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
|
|
|
|
|
|
|
|
|
|
} // namespace jitkernels
|
|
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|