enable blas jitcode vmul, vadd, vaddrelu, vscal and vaddbias

revert-15207-remove_op_handle_lock_and_fix_var
tensor-tang 7 years ago
parent 5e97be7ba7
commit fd0a954fbf

@ -10,3 +10,8 @@ endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(vmul)
USE_JITKERNEL_GEN(vadd)
#USE_JITKERNEL_GEN(vsub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(vaddrelu)
USE_JITKERNEL_GEN(vscal)
USE_JITKERNEL_GEN(vaddbias)

@ -104,18 +104,28 @@ void VXXJitCode::genCode() {
ret();
}
class VMulCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx);
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
size_t CodeSize(const int& d) const override {
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<VMulJitCode>(attr, CodeSize(attr));
}
};
DECLARE_BLAS_CREATOR(VMul);
DECLARE_BLAS_CREATOR(VAdd);
DECLARE_BLAS_CREATOR(VSub);
DECLARE_BLAS_CREATOR(VAddRelu);
DECLARE_BLAS_CREATOR(VScal);
DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen
} // namespace jit
@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator<int> {
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(vmul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(vadd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(vsub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(vaddrelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(vscal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(vaddbias, gen::VAddBiasCreator);

@ -15,6 +15,7 @@
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
@ -33,6 +34,9 @@ class VXXJitCode : public JitCode {
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::mul || type_ == operand_type::add)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
@ -78,12 +82,23 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3);
};
class VMulJitCode : public VXXJitCode {
public:
explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr)
: VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {}
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::mul, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::add, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::sub, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::add, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::mul, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::add, 1, false);
#undef DECLARE_BLAS_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators

Loading…
Cancel
Save