|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -33,6 +34,9 @@ class VXXJitCode : public JitCode {
|
|
|
|
|
type_(type),
|
|
|
|
|
scalar_index_(scalar_index),
|
|
|
|
|
with_relu_(with_relu) {
|
|
|
|
|
if (!(type_ == operand_type::mul || type_ == operand_type::add)) {
|
|
|
|
|
LOG(FATAL) << "Do not support this operand type: " << type_;
|
|
|
|
|
}
|
|
|
|
|
this->genCode();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -78,11 +82,22 @@ class VXXJitCode : public JitCode {
|
|
|
|
|
ymm_t ymm_zero = ymm_t(3);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class VMulJitCode : public VXXJitCode {
|
|
|
|
|
public:
|
|
|
|
|
explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr)
|
|
|
|
|
: VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {}
|
|
|
|
|
};
|
|
|
|
|
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
|
|
|
|
|
class name##JitCode : public VXXJitCode { \
|
|
|
|
|
public: \
|
|
|
|
|
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
|
|
|
|
|
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
|
|
|
|
|
} \
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false);
|
|
|
|
|
DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false);
|
|
|
|
|
DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false);
|
|
|
|
|
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true);
|
|
|
|
|
DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false);
|
|
|
|
|
DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false);
|
|
|
|
|
|
|
|
|
|
#undef DECLARE_BLAS_JITCODE
|
|
|
|
|
|
|
|
|
|
} // namespace gen
|
|
|
|
|
} // namespace jit
|
|
|
|
|