parent
45bfa70cb8
commit
191948c933
@ -1,3 +1,5 @@
|
||||
|
||||
cc_library(jit_kernel_jitcode SRCS jitcode.cc DEPS jit_kernel_base xbyak)
|
||||
file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
|
||||
|
||||
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
|
||||
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
|
||||
|
||||
@ -0,0 +1,118 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
#include "paddle/fluid/operators/jitkernels/jitcode/blas.h"
|
||||
#include "paddle/fluid/operators/jitkernels/registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
namespace jitcode {
|
||||
|
||||
void VXXJitCode::genCode() {
|
||||
// do not need push stack, and do not need save avx512reg if do not use avx512
|
||||
int offset = 0;
|
||||
if (with_relu_) {
|
||||
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
||||
}
|
||||
if (scalar_index_ == 1) {
|
||||
vbroadcastss(ymm_src1, ptr[param1]);
|
||||
} else if (scalar_index_ == 2) {
|
||||
vbroadcastss(ymm_src2, ptr[param2]);
|
||||
}
|
||||
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
||||
if (scalar_index_ != 1) {
|
||||
vmovups(ymm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovups(ymm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
if (type_ == operand_type::mul) {
|
||||
vmulps(ymm_dst, ymm_src1, ymm_src2);
|
||||
} else if (type_ == operand_type::add) {
|
||||
vaddps(ymm_dst, ymm_src1, ymm_src2);
|
||||
}
|
||||
if (with_relu_) {
|
||||
vmaxps(ymm_dst, ymm_zero, ymm_dst);
|
||||
}
|
||||
vmovups(ptr[param3 + offset], ymm_dst);
|
||||
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
||||
}
|
||||
int rest = num_ % YMM_FLOAT_BLOCK;
|
||||
while (rest > 0) {
|
||||
int block = XMM_FLOAT_BLOCK;
|
||||
if (rest >= 4) {
|
||||
block = 4;
|
||||
if (scalar_index_ != 1) {
|
||||
vmovups(xmm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovups(xmm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
} else if (rest >= 2) {
|
||||
block = 2;
|
||||
if (scalar_index_ != 1) {
|
||||
vmovq(xmm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovq(xmm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
} else {
|
||||
block = 1;
|
||||
if (scalar_index_ != 1) {
|
||||
vmovss(xmm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovss(xmm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
}
|
||||
switch (type_) {
|
||||
case operand_type::mul:
|
||||
vmulps(xmm_dst, xmm_src1, xmm_src2);
|
||||
break;
|
||||
case operand_type::add:
|
||||
vaddps(xmm_dst, xmm_src1, xmm_src2);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
if (with_relu_) {
|
||||
vmaxps(xmm_dst, xmm_zero, xmm_dst);
|
||||
}
|
||||
if (rest >= 4) {
|
||||
vmovups(ptr[param3 + offset], xmm_dst);
|
||||
} else if (rest >= 2) {
|
||||
vmovq(ptr[param3 + offset], xmm_dst);
|
||||
} else {
|
||||
vmovss(ptr[param3 + offset], xmm_dst);
|
||||
}
|
||||
offset += sizeof(float) * block;
|
||||
rest -= block;
|
||||
}
|
||||
ret();
|
||||
}
|
||||
|
||||
} // namespace jitcode
|
||||
|
||||
template <>
|
||||
std::unique_ptr<JitBase> CreateJitCode<KernelType::vmul, float, int>(int attr) {
|
||||
if (UseJitCode<KernelType::vmul, float, int>(attr)) {
|
||||
return make_unique<jitcode::VMulJitCode>(
|
||||
attr, CodeSize<KernelType::vmul, float, int>(attr));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,88 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
namespace jitcode {
|
||||
|
||||
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
|
||||
class VXXJitCode : public JitCode {
|
||||
public:
|
||||
const char* name() const override {
|
||||
std::string base = "VXXJitCode";
|
||||
if (scalar_index_ == 1) {
|
||||
base += "_Scalar";
|
||||
} else {
|
||||
base += "_Vec";
|
||||
}
|
||||
if (type_ == operand_type::mul) {
|
||||
base += "_Mul";
|
||||
} else if (type_ == operand_type::add) {
|
||||
base += "_Add";
|
||||
}
|
||||
if (scalar_index_ == 2) {
|
||||
base += "_Scalar";
|
||||
} else {
|
||||
base += "_Vec";
|
||||
}
|
||||
base += (with_relu_ ? "_Relu" : "");
|
||||
return base.c_str();
|
||||
}
|
||||
explicit VXXJitCode(int d, operand_type type, int scalar_index,
|
||||
bool with_relu, size_t code_size = 256 * 1024,
|
||||
void* code_ptr = nullptr)
|
||||
: JitCode(code_size, code_ptr),
|
||||
num_(d),
|
||||
type_(type),
|
||||
scalar_index_(scalar_index),
|
||||
with_relu_(with_relu) {}
|
||||
// static bool init(int d, int scalar_index = 0);
|
||||
void genCode() override;
|
||||
|
||||
private:
|
||||
int num_;
|
||||
operand_type type_;
|
||||
int scalar_index_;
|
||||
bool with_relu_;
|
||||
reg64_t param1{abi_param1};
|
||||
reg64_t param2{abi_param2};
|
||||
reg64_t param3{abi_param3};
|
||||
|
||||
xmm_t xmm_src1 = xmm_t(0);
|
||||
xmm_t xmm_src2 = xmm_t(1);
|
||||
xmm_t xmm_dst = xmm_t(2);
|
||||
xmm_t xmm_zero = xmm_t(3);
|
||||
|
||||
ymm_t ymm_src1 = ymm_t(0);
|
||||
ymm_t ymm_src2 = ymm_t(1);
|
||||
ymm_t ymm_dst = ymm_t(2);
|
||||
ymm_t ymm_zero = ymm_t(3);
|
||||
};
|
||||
|
||||
class VMulJitCode : public VXXJitCode {
|
||||
public:
|
||||
explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr)
|
||||
: VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {}
|
||||
};
|
||||
|
||||
} // namespace jitcode
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
Loading…
Reference in new issue