Merge pull request #14958 from tensor-tang/refine/jit
enhance jitrevert-15207-remove_op_handle_lock_and_fix_var
commit
693e5e65ce
@ -0,0 +1,25 @@
|
||||
|
||||
set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
|
||||
file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n")
|
||||
file(APPEND ${jit_file} "\#pragma once\n")
|
||||
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
|
||||
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
|
||||
|
||||
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
|
||||
|
||||
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
|
||||
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
|
||||
cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
|
||||
|
||||
# refer must go first
|
||||
add_subdirectory(refer)
|
||||
add_subdirectory(more)
|
||||
if(WITH_XBYAK)
|
||||
add_subdirectory(gen)
|
||||
endif()
|
||||
|
||||
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
|
||||
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
|
||||
if(NOT WIN32)
|
||||
cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper)
|
||||
endif()
|
@ -0,0 +1,231 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "gflags/gflags.h"
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/operators/jit/kernels.h"
|
||||
#include "paddle/fluid/platform/device_tracer.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "paddle/fluid/platform/port.h"
|
||||
|
||||
DEFINE_int32(burning, 10, "Burning times.");
|
||||
DEFINE_int32(repeat, 3000, "Repeat times.");
|
||||
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
|
||||
|
||||
template <typename T>
|
||||
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
|
||||
const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
|
||||
std::mt19937 rng(seed);
|
||||
std::uniform_real_distribution<double> uniform_dist(0, 1);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> TestSizes() {
|
||||
std::vector<int> s;
|
||||
for (int i = 1; i <= FLAGS_max_size; ++i) {
|
||||
s.push_back(i);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
template <typename KernelTuples, typename... Args>
|
||||
struct BenchFunc {
|
||||
// return this function avg time
|
||||
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
|
||||
for (int i = 0; i < FLAGS_burning; ++i) {
|
||||
tgt(args...);
|
||||
}
|
||||
auto start = paddle::platform::PosixInNsec() / 1e-3;
|
||||
for (int i = 0; i < FLAGS_repeat; ++i) {
|
||||
tgt(args...);
|
||||
}
|
||||
auto end = paddle::platform::PosixInNsec() / 1e-3;
|
||||
return static_cast<double>(end - start) / FLAGS_repeat;
|
||||
}
|
||||
};
|
||||
|
||||
namespace jit = paddle::operators::jit;
|
||||
|
||||
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
|
||||
typename... Args>
|
||||
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
|
||||
BenchFunc<KernelTuples, Args...> benchmark;
|
||||
std::vector<std::pair<std::string, double>> infos;
|
||||
// test refer
|
||||
auto refer = jit::GetRefer<KT, KernelTuples>();
|
||||
if (!refer) {
|
||||
LOG(FATAL) << "Refer can not be empty!";
|
||||
}
|
||||
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
|
||||
|
||||
// test jitcode
|
||||
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
|
||||
if (jitcode) {
|
||||
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
|
||||
}
|
||||
// test all impls in more
|
||||
jit::KernelKey kkey(KT, PlaceType());
|
||||
auto& pool = jit::KernelPool().Instance().AllKernels();
|
||||
auto iter = pool.find(kkey);
|
||||
if (iter != pool.end()) {
|
||||
auto& impls = iter->second;
|
||||
for (auto& impl : impls) {
|
||||
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
|
||||
if (i && i->UseMe(attr)) {
|
||||
auto more = i->GetFunc();
|
||||
infos.push_back(
|
||||
std::make_pair(i->ImplType(), benchmark(more, args...)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Test result from Get function
|
||||
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
|
||||
if (!tgt) {
|
||||
LOG(FATAL) << "Target can not be empty!";
|
||||
}
|
||||
infos.push_back(std::make_pair("Target", benchmark(tgt, args...)));
|
||||
|
||||
// print
|
||||
std::ostringstream loginfos;
|
||||
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
|
||||
for (auto pair : infos) {
|
||||
loginfos << pair.first << " takes " << pair.second << " us; ";
|
||||
}
|
||||
LOG(INFO) << loginfos.str();
|
||||
}
|
||||
|
||||
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
||||
void BenchXYZNKernel() {
|
||||
for (int d : TestSizes()) {
|
||||
std::vector<T> x(d), y(d), z(d);
|
||||
RandomVec<T>(d, x.data());
|
||||
RandomVec<T>(d, y.data());
|
||||
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(),
|
||||
z.data(), d);
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
||||
void BenchAXYNKernel() {
|
||||
for (int d : TestSizes()) {
|
||||
const T a = static_cast<T>(3);
|
||||
std::vector<T> x(d), y(d);
|
||||
RandomVec<T>(d, x.data());
|
||||
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(),
|
||||
d);
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
||||
void BenchXYNKernel() {
|
||||
for (int d : TestSizes()) {
|
||||
std::vector<T> x(d), y(d);
|
||||
RandomVec<T>(d, x.data());
|
||||
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d);
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
||||
void BenchLSTMKernel() {
|
||||
for (bool use_peephole : {true, false}) {
|
||||
for (int d : TestSizes()) {
|
||||
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
|
||||
use_peephole);
|
||||
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
|
||||
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
|
||||
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
|
||||
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
|
||||
const T* ct_1_data = ct_1.data();
|
||||
const T* wp_data = wp.data();
|
||||
T* x_data = x.data();
|
||||
T* checked_data = checked.data();
|
||||
T* ct_data = ct.data();
|
||||
T* ht_data = ht.data();
|
||||
jit::lstm_t step;
|
||||
step.gates = x_data;
|
||||
step.ct_1 = ct_1_data;
|
||||
step.ct = ct_data;
|
||||
step.ht = ht_data;
|
||||
if (use_peephole) {
|
||||
step.wp = wp_data;
|
||||
step.checked = checked_data;
|
||||
}
|
||||
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
|
||||
void BenchGRUKernel() {
|
||||
for (int d : TestSizes()) {
|
||||
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
|
||||
std::vector<T> x(3 * d), ht_1(d), ht(d);
|
||||
RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
|
||||
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
|
||||
const T* ht_1_data = ht_1.data();
|
||||
T* x_data = x.data();
|
||||
T* ht_data = ht.data();
|
||||
jit::gru_t step;
|
||||
step.gates = x_data;
|
||||
step.ht_1 = ht_1_data;
|
||||
step.ht = ht_data;
|
||||
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark all jit kernels including jitcode, mkl and refer.
|
||||
// To use this tool, run command: ./benchmark [options...]
|
||||
// Options:
|
||||
// --burning: the burning time before count
|
||||
// --repeat: the repeat times
|
||||
// --max_size: the max size would be tested
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
|
||||
<< " times.";
|
||||
using T = float;
|
||||
using PlaceType = paddle::platform::CPUPlace;
|
||||
// xyzn
|
||||
BenchXYZNKernel<jit::kVMul, T, PlaceType>();
|
||||
BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
|
||||
BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
|
||||
BenchXYZNKernel<jit::kVSub, T, PlaceType>();
|
||||
|
||||
// axyn
|
||||
BenchAXYNKernel<jit::kVScal, T, PlaceType>();
|
||||
BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
|
||||
|
||||
// xyn
|
||||
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
|
||||
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
|
||||
BenchXYNKernel<jit::kVExp, T, PlaceType>();
|
||||
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
|
||||
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
|
||||
|
||||
// lstm and peephole
|
||||
BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
|
||||
BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
|
||||
|
||||
// gru functions
|
||||
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
|
||||
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
|
||||
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
|
||||
file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
|
||||
|
||||
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
|
||||
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
|
||||
|
||||
function(USE_JITKERNEL_GEN TARGET)
|
||||
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
|
||||
endfunction()
|
||||
|
||||
# use gen jitcode kernel by name
|
||||
USE_JITKERNEL_GEN(kVMul)
|
||||
USE_JITKERNEL_GEN(kVAdd)
|
||||
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
|
||||
USE_JITKERNEL_GEN(kVAddRelu)
|
||||
USE_JITKERNEL_GEN(kVScal)
|
||||
USE_JITKERNEL_GEN(kVAddBias)
|
||||
USE_JITKERNEL_GEN(kVRelu)
|
||||
USE_JITKERNEL_GEN(kVIdentity)
|
||||
USE_JITKERNEL_GEN(kVExp)
|
||||
USE_JITKERNEL_GEN(kVSigmoid)
|
||||
USE_JITKERNEL_GEN(kVTanh)
|
||||
USE_JITKERNEL_GEN(kLSTMCtHt)
|
||||
USE_JITKERNEL_GEN(kLSTMC1H1)
|
||||
USE_JITKERNEL_GEN(kGRUH1)
|
||||
USE_JITKERNEL_GEN(kGRUHtPart1)
|
||||
USE_JITKERNEL_GEN(kGRUHtPart2)
|
||||
USE_JITKERNEL_GEN(kNCHW16CMulNC)
|
@ -0,0 +1,135 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/gen/act.h"
|
||||
#include "paddle/fluid/operators/jit/registry.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = {
|
||||
REPEAT_8TIMES(1.f),
|
||||
REPEAT_8TIMES(2.f),
|
||||
REPEAT_8TIMES(0.5f),
|
||||
REPEAT_8TIMES(EXP_HIG),
|
||||
REPEAT_8TIMES(EXP_LOW),
|
||||
REPEAT_8TIMES(CEPHES_LOG2EF),
|
||||
REPEAT_8TIMES(CEPHES_EXP_C1),
|
||||
REPEAT_8TIMES(CEPHES_EXP_C2),
|
||||
REPEAT_8TIMES(CEPHES_EXP_P0),
|
||||
REPEAT_8TIMES(CEPHES_EXP_P1),
|
||||
REPEAT_8TIMES(CEPHES_EXP_P2),
|
||||
REPEAT_8TIMES(CEPHES_EXP_P3),
|
||||
REPEAT_8TIMES(CEPHES_EXP_P4),
|
||||
REPEAT_8TIMES(CEPHES_EXP_P5),
|
||||
REPEAT_8TIMES(EXP_MAX_INPUT),
|
||||
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
|
||||
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
|
||||
|
||||
const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)};
|
||||
int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0};
|
||||
|
||||
void VActJitCode::genCode() {
|
||||
int offset = 0;
|
||||
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
||||
vmovups(ymm_src, ptr[param1 + offset]);
|
||||
act<ymm_t>(ymm_dst, ymm_src, type_);
|
||||
vmovups(ptr[param2 + offset], ymm_dst);
|
||||
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
||||
}
|
||||
int rest = num_ % YMM_FLOAT_BLOCK;
|
||||
while (rest > 0) {
|
||||
int block = XMM_FLOAT_BLOCK;
|
||||
if (rest >= 4) {
|
||||
block = 4;
|
||||
vmovups(xmm_src, ptr[param1 + offset]);
|
||||
} else if (rest >= 2) {
|
||||
block = 2;
|
||||
vmovq(xmm_src, ptr[param1 + offset]);
|
||||
} else {
|
||||
block = 1;
|
||||
vmovss(xmm_src, ptr[param1 + offset]);
|
||||
}
|
||||
act<xmm_t>(xmm_dst, xmm_src, type_);
|
||||
if (rest >= 4) {
|
||||
vmovups(ptr[param2 + offset], xmm_dst);
|
||||
} else if (rest >= 2) {
|
||||
vmovq(ptr[param2 + offset], xmm_dst);
|
||||
} else {
|
||||
vmovss(ptr[param2 + offset], xmm_dst);
|
||||
}
|
||||
offset += sizeof(float) * block;
|
||||
rest -= block;
|
||||
}
|
||||
ret();
|
||||
}
|
||||
|
||||
#define DECLARE_ACT_CREATOR(name) \
|
||||
class name##Creator : public JitCodeCreator<int> { \
|
||||
public: \
|
||||
bool UseMe(const int& attr) const override { \
|
||||
return platform::MayIUse(platform::avx); \
|
||||
} \
|
||||
size_t CodeSize(const int& d) const override; \
|
||||
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
|
||||
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
|
||||
} \
|
||||
}
|
||||
|
||||
DECLARE_ACT_CREATOR(VRelu);
|
||||
DECLARE_ACT_CREATOR(VIdentity);
|
||||
DECLARE_ACT_CREATOR(VExp);
|
||||
DECLARE_ACT_CREATOR(VSigmoid);
|
||||
DECLARE_ACT_CREATOR(VTanh);
|
||||
|
||||
// TODO(TJ): tuning use me
|
||||
size_t VReluCreator::CodeSize(const int& d) const {
|
||||
return 96 /* init size */ +
|
||||
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
|
||||
8 /* average bytes for each instruction */;
|
||||
}
|
||||
|
||||
size_t VIdentityCreator::CodeSize(const int& d) const {
|
||||
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
|
||||
}
|
||||
|
||||
size_t VExpCreator::CodeSize(const int& d) const {
|
||||
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8;
|
||||
}
|
||||
|
||||
size_t VSigmoidCreator::CodeSize(const int& d) const {
|
||||
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8;
|
||||
}
|
||||
|
||||
size_t VTanhCreator::CodeSize(const int& d) const {
|
||||
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8;
|
||||
}
|
||||
|
||||
#undef DECLARE_ACT_CREATOR
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace gen = paddle::operators::jit::gen;
|
||||
|
||||
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,186 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/gen/blas.h"
|
||||
#include "paddle/fluid/operators/jit/registry.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
void VXXJitCode::genCode() {
|
||||
// do not need push stack, and do not need save avx512reg if do not use avx512
|
||||
int offset = 0;
|
||||
if (with_relu_) {
|
||||
vxorps(ymm_zero, ymm_zero, ymm_zero);
|
||||
}
|
||||
if (scalar_index_ == 1) {
|
||||
vbroadcastss(ymm_src1, ptr[param1]);
|
||||
} else if (scalar_index_ == 2) {
|
||||
vbroadcastss(ymm_src2, ptr[param2]);
|
||||
}
|
||||
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
||||
if (scalar_index_ != 1) {
|
||||
vmovups(ymm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovups(ymm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
if (type_ == operand_type::MUL) {
|
||||
vmulps(ymm_dst, ymm_src1, ymm_src2);
|
||||
} else if (type_ == operand_type::ADD) {
|
||||
vaddps(ymm_dst, ymm_src1, ymm_src2);
|
||||
}
|
||||
if (with_relu_) {
|
||||
vmaxps(ymm_dst, ymm_zero, ymm_dst);
|
||||
}
|
||||
vmovups(ptr[param3 + offset], ymm_dst);
|
||||
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
||||
}
|
||||
int rest = num_ % YMM_FLOAT_BLOCK;
|
||||
while (rest > 0) {
|
||||
int block = XMM_FLOAT_BLOCK;
|
||||
if (rest >= 4) {
|
||||
block = 4;
|
||||
if (scalar_index_ != 1) {
|
||||
vmovups(xmm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovups(xmm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
} else if (rest >= 2) {
|
||||
block = 2;
|
||||
if (scalar_index_ != 1) {
|
||||
vmovq(xmm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovq(xmm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
} else {
|
||||
block = 1;
|
||||
if (scalar_index_ != 1) {
|
||||
vmovss(xmm_src1, ptr[param1 + offset]);
|
||||
}
|
||||
if (scalar_index_ != 2) {
|
||||
vmovss(xmm_src2, ptr[param2 + offset]);
|
||||
}
|
||||
}
|
||||
switch (type_) {
|
||||
case operand_type::MUL:
|
||||
vmulps(xmm_dst, xmm_src1, xmm_src2);
|
||||
break;
|
||||
case operand_type::ADD:
|
||||
vaddps(xmm_dst, xmm_src1, xmm_src2);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
if (with_relu_) {
|
||||
vmaxps(xmm_dst, xmm_zero, xmm_dst);
|
||||
}
|
||||
if (rest >= 4) {
|
||||
vmovups(ptr[param3 + offset], xmm_dst);
|
||||
} else if (rest >= 2) {
|
||||
vmovq(ptr[param3 + offset], xmm_dst);
|
||||
} else {
|
||||
vmovss(ptr[param3 + offset], xmm_dst);
|
||||
}
|
||||
offset += sizeof(float) * block;
|
||||
rest -= block;
|
||||
}
|
||||
ret();
|
||||
}
|
||||
|
||||
void NCHW16CMulNCJitCode::genCode() {
|
||||
// RDI is ptr x_input
|
||||
// RSI is ptr y_input
|
||||
// RDX is ptr output
|
||||
// RCX is height
|
||||
// r8 is width
|
||||
|
||||
push(rbx);
|
||||
|
||||
xor_(rax, rax);
|
||||
xor_(r10, r10);
|
||||
vmovups(zmm3, ptr[rsi]);
|
||||
|
||||
L("h_loop");
|
||||
xor_(rbx, rbx);
|
||||
L("w_loop");
|
||||
vmovups(zmm2, ptr[rdi + rax]);
|
||||
vmulps(zmm1, zmm2, zmm3);
|
||||
vmovups(ptr[rdx + rax], zmm1);
|
||||
add(rax, 64);
|
||||
inc(rbx);
|
||||
cmp(r8, rbx);
|
||||
jnz("w_loop");
|
||||
inc(r10);
|
||||
cmp(r10, rcx);
|
||||
jnz("h_loop");
|
||||
|
||||
pop(rbx);
|
||||
ret();
|
||||
}
|
||||
|
||||
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
|
||||
public:
|
||||
bool UseMe(const int& attr) const override {
|
||||
return platform::MayIUse(platform::avx512f);
|
||||
}
|
||||
size_t CodeSize(const int& d) const override { return 256 * 1024; }
|
||||
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
|
||||
return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr));
|
||||
}
|
||||
};
|
||||
|
||||
#define DECLARE_BLAS_CREATOR(name) \
|
||||
class name##Creator : public JitCodeCreator<int> { \
|
||||
public: \
|
||||
bool UseMe(const int& attr) const override { \
|
||||
return platform::MayIUse(platform::avx); \
|
||||
} \
|
||||
size_t CodeSize(const int& d) const override { \
|
||||
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
|
||||
} \
|
||||
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
|
||||
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
|
||||
} \
|
||||
}
|
||||
|
||||
DECLARE_BLAS_CREATOR(VMul);
|
||||
DECLARE_BLAS_CREATOR(VAdd);
|
||||
DECLARE_BLAS_CREATOR(VSub);
|
||||
DECLARE_BLAS_CREATOR(VAddRelu);
|
||||
DECLARE_BLAS_CREATOR(VScal);
|
||||
DECLARE_BLAS_CREATOR(VAddBias);
|
||||
|
||||
#undef DECLARE_BLAS_CREATOR
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace gen = paddle::operators::jit::gen;
|
||||
|
||||
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
|
||||
// TODO(TJ): enable sub
|
||||
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
|
||||
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
|
||||
REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
|
@ -0,0 +1,117 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
|
||||
class VXXJitCode : public JitCode {
|
||||
public:
|
||||
explicit VXXJitCode(int d, operand_type type, int scalar_index,
|
||||
bool with_relu, size_t code_size = 256 * 1024,
|
||||
void* code_ptr = nullptr)
|
||||
: JitCode(code_size, code_ptr),
|
||||
num_(d),
|
||||
type_(type),
|
||||
scalar_index_(scalar_index),
|
||||
with_relu_(with_relu) {
|
||||
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
|
||||
LOG(FATAL) << "Do not support this operand type: " << type_;
|
||||
}
|
||||
this->genCode();
|
||||
}
|
||||
|
||||
virtual const char* name() const {
|
||||
std::string base = "VXXJitCode";
|
||||
if (scalar_index_ == 1) {
|
||||
base += "_Scalar";
|
||||
} else {
|
||||
base += "_Vec";
|
||||
}
|
||||
if (type_ == operand_type::MUL) {
|
||||
base += "_Mul";
|
||||
} else if (type_ == operand_type::ADD) {
|
||||
base += "_Add";
|
||||
}
|
||||
if (scalar_index_ == 2) {
|
||||
base += "_Scalar";
|
||||
} else {
|
||||
base += "_Vec";
|
||||
}
|
||||
base += (with_relu_ ? "_Relu" : "");
|
||||
return base.c_str();
|
||||
}
|
||||
void genCode() override;
|
||||
|
||||
private:
|
||||
int num_;
|
||||
operand_type type_;
|
||||
int scalar_index_;
|
||||
bool with_relu_;
|
||||
reg64_t param1{abi_param1};
|
||||
reg64_t param2{abi_param2};
|
||||
reg64_t param3{abi_param3};
|
||||
|
||||
xmm_t xmm_src1 = xmm_t(0);
|
||||
xmm_t xmm_src2 = xmm_t(1);
|
||||
xmm_t xmm_dst = xmm_t(2);
|
||||
xmm_t xmm_zero = xmm_t(3);
|
||||
|
||||
ymm_t ymm_src1 = ymm_t(0);
|
||||
ymm_t ymm_src2 = ymm_t(1);
|
||||
ymm_t ymm_dst = ymm_t(2);
|
||||
ymm_t ymm_zero = ymm_t(3);
|
||||
};
|
||||
|
||||
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
|
||||
class name##JitCode : public VXXJitCode { \
|
||||
public: \
|
||||
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
|
||||
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
|
||||
} \
|
||||
};
|
||||
|
||||
DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
|
||||
DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
|
||||
DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
|
||||
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
|
||||
DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
|
||||
DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
|
||||
|
||||
#undef DECLARE_BLAS_JITCODE
|
||||
|
||||
// nChw16c = nChw16c .* NC
|
||||
class NCHW16CMulNCJitCode : public JitCode {
|
||||
public:
|
||||
DECLARE_JIT_CODE(NCHW16CMulNCJitCode);
|
||||
explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size,
|
||||
void* code_ptr = nullptr)
|
||||
: JitCode(code_size, code_ptr) {
|
||||
this->genCode();
|
||||
}
|
||||
void genCode() override;
|
||||
};
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,116 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/gen/gru.h"
|
||||
#include <stddef.h> // offsetof
|
||||
#include "paddle/fluid/operators/jit/registry.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
void GRUJitCode::genCode() {
|
||||
reg64_t reg_ptr_gates = rax;
|
||||
reg64_t reg_ptr_ht_1 = r9;
|
||||
reg64_t reg_ptr_ht = r10;
|
||||
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
|
||||
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
|
||||
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
|
||||
ymm_t ymm_one = ymm_t(0);
|
||||
|
||||
if (id_ == 2) {
|
||||
reg64_t reg_ptr_tmp = r11;
|
||||
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
|
||||
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
|
||||
}
|
||||
int offset = 0;
|
||||
int d = num_ * sizeof(float);
|
||||
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
|
||||
ymm_t ymm_u = ymm_t(1);
|
||||
ymm_t ymm_r = ymm_t(2);
|
||||
ymm_t ymm_s = ymm_t(3);
|
||||
ymm_t ymm_ht_1 = ymm_t(4);
|
||||
// W: {W_update, W_reset; W_state}
|
||||
if (id_ == 0 || id_ == 2) {
|
||||
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
|
||||
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
|
||||
}
|
||||
if (id_ == 1) {
|
||||
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
|
||||
}
|
||||
if (id_ == 1 || id_ == 2) {
|
||||
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
|
||||
}
|
||||
|
||||
if (id_ == 0) {
|
||||
// ht = act_gate(u) * act_cand(s)
|
||||
act<ymm_t>(ymm_u, ymm_u, act_gate_);
|
||||
act<ymm_t>(ymm_s, ymm_s, act_cand_);
|
||||
vmulps(ymm_s, ymm_s, ymm_u);
|
||||
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
|
||||
} else if (id_ == 1) {
|
||||
// ht = act_gate(r) * ht_1
|
||||
act<ymm_t>(ymm_r, ymm_r, act_gate_);
|
||||
vmulps(ymm_r, ymm_r, ymm_ht_1);
|
||||
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
|
||||
} else if (id_ == 2) {
|
||||
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
|
||||
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
|
||||
act<ymm_t>(ymm_u, ymm_u, act_gate_);
|
||||
act<ymm_t>(ymm_s, ymm_s, act_cand_);
|
||||
vmulps(ymm_s, ymm_s, ymm_u);
|
||||
vsubps(ymm_u, ymm_one_inner, ymm_u);
|
||||
vmulps(ymm_u, ymm_ht_1, ymm_u);
|
||||
vaddps(ymm_u, ymm_s, ymm_u);
|
||||
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
|
||||
}
|
||||
offset += sizeof(float) * YMM_FLOAT_BLOCK;
|
||||
}
|
||||
ret();
|
||||
}
|
||||
|
||||
#define DECLARE_GRU_CREATOR(name) \
|
||||
class name##Creator : public JitCodeCreator<gru_attr_t> { \
|
||||
public: \
|
||||
/* TODO(TJ): enable more */ \
|
||||
bool UseMe(const gru_attr_t& attr) const override { \
|
||||
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
|
||||
} \
|
||||
size_t CodeSize(const gru_attr_t& attr) const override { \
|
||||
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
|
||||
} \
|
||||
std::unique_ptr<GenBase> CreateJitCode( \
|
||||
const gru_attr_t& attr) const override { \
|
||||
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
|
||||
} \
|
||||
}
|
||||
|
||||
DECLARE_GRU_CREATOR(GRUH1);
|
||||
DECLARE_GRU_CREATOR(GRUHtPart1);
|
||||
DECLARE_GRU_CREATOR(GRUHtPart2);
|
||||
|
||||
#undef DECLARE_GRU_CREATOR
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace gen = paddle::operators::jit::gen;
|
||||
|
||||
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
|
||||
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
|
||||
REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);
|
@ -0,0 +1,113 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/operators/jit/gen/act.h"
|
||||
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
class GRUJitCode : public VActFunc {
|
||||
public:
|
||||
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
|
||||
void* code_ptr = nullptr)
|
||||
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
|
||||
auto typeExchange = [](KernelType type) -> gen::operand_type {
|
||||
if (type == KernelType::kVSigmoid) {
|
||||
return operand_type::SIGMOID;
|
||||
} else if (type == KernelType::kVRelu) {
|
||||
return operand_type::RELU;
|
||||
} else if (type == KernelType::kVTanh) {
|
||||
return operand_type::TANH;
|
||||
} else if (type == KernelType::kVIdentity) {
|
||||
return operand_type::IDENTITY;
|
||||
} else {
|
||||
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
|
||||
}
|
||||
return operand_type::IDENTITY;
|
||||
};
|
||||
act_gate_ = typeExchange(attr.act_gate);
|
||||
act_cand_ = typeExchange(attr.act_cand);
|
||||
|
||||
this->genCode();
|
||||
}
|
||||
|
||||
const char* name() const override {
|
||||
std::string base = "GRUJitCode";
|
||||
if (id_ == 0) {
|
||||
base += "_H1";
|
||||
} else if (id_ == 1) {
|
||||
base += "_HtPart1";
|
||||
} else if (id_ == 2) {
|
||||
base += "_HtPart2";
|
||||
}
|
||||
auto AddTypeStr = [&](operand_type type) {
|
||||
switch (type) {
|
||||
case operand_type::RELU:
|
||||
base += "_Relu";
|
||||
break;
|
||||
case operand_type::EXP:
|
||||
base += "_Exp";
|
||||
break;
|
||||
case operand_type::SIGMOID:
|
||||
base += "_Sigmoid";
|
||||
break;
|
||||
case operand_type::TANH:
|
||||
base += "_Tanh";
|
||||
break;
|
||||
case operand_type::IDENTITY:
|
||||
base += "_Identity";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
};
|
||||
AddTypeStr(act_gate_);
|
||||
AddTypeStr(act_cand_);
|
||||
return base.c_str();
|
||||
}
|
||||
void genCode() override;
|
||||
|
||||
protected:
|
||||
int id_;
|
||||
int num_;
|
||||
operand_type act_gate_;
|
||||
operand_type act_cand_;
|
||||
reg64_t param1{abi_param1};
|
||||
};
|
||||
|
||||
// Defines one thin subclass per GRU variant that pins `id` and forwards
// the remaining constructor arguments to GRUJitCode.
#define DECLARE_GRU_JITCODE(name, id)                               \
  class name##JitCode : public GRUJitCode {                         \
   public:                                                          \
    explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
                           void* code_ptr = nullptr)                \
        : GRUJitCode(id, attr, code_size, code_ptr) {}              \
  };

DECLARE_GRU_JITCODE(GRUH1, 0);
DECLARE_GRU_JITCODE(GRUHtPart1, 1);
DECLARE_GRU_JITCODE(GRUHtPart2, 2);

#undef DECLARE_GRU_JITCODE
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,126 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "paddle/fluid/operators/jit/gen_base.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
#define XBYAK_USE_MMAP_ALLOCATOR
|
||||
#include "xbyak/xbyak.h"
|
||||
#include "xbyak/xbyak_util.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
// Application Binary Interface
// First four integer argument registers of the x86-64 System V calling
// convention (RDI, RSI, RDX, RCX); generated functions read their
// parameters from these.
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
    abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
    abi_param4(Xbyak::Operand::RCX);

// Callee-saved registers: generated code that uses any of these must push
// them in preCode() and pop them in postCode().
constexpr Xbyak::Operand::Code g_abi_regs[] = {
    Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
    Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};

constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);

// Short aliases for the Xbyak register/label types used by the generators.
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;

// Elementwise operations / activations a generator can emit.
typedef enum {
  MUL = 0,
  ADD,
  SUB,
  RELU,
  EXP,
  SIGMOID,
  TANH,
  IDENTITY
} operand_type;

// Defines name() for JitCode subclasses whose name is a fixed string.
#define DECLARE_JIT_CODE(codename) \
  const char* name() const override { return #codename; }
|
||||
|
||||
// Base class of all Xbyak-generated kernels: combines the GenBase kernel
// interface with an Xbyak code buffer and provides ABI prologue/epilogue
// helpers for the concrete generators.
class JitCode : public GenBase, public Xbyak::CodeGenerator {
 public:
  // Rounds the requested code size up to a whole number of 4KB pages
  // before handing it to Xbyak.
  explicit JitCode(size_t code_size, void* code_ptr = nullptr)
      : Xbyak::CodeGenerator(
            (code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
            code_ptr) {}

  virtual const char* name() const = 0;
  // Emits the kernel's machine code; implemented by the concrete subclass.
  virtual void genCode() = 0;

  size_t getSize() const override { return CodeGenerator::getSize(); }
  // Finalizes and returns the generated code buffer.
  const unsigned char* getCodeInternal() override {
    const Xbyak::uint8* code = CodeGenerator::getCode();
    return code;
  }

 protected:
  Xbyak::Reg64 param1{abi_param1};  // first function argument
  const int EVEX_max_8b_offt = 0x200;
  const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;

  // Standard prologue: save callee-saved registers, and preload the EVEX
  // offset helper register when AVX-512 is available.
  virtual void preCode() {
    for (int i = 0; i < num_g_abi_regs; ++i) {
      push(Xbyak::Reg64(g_abi_regs[i]));
    }
    if (platform::MayIUse(platform::avx512f)) {
      mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
    }
  }
  // Standard epilogue: restore callee-saved registers in reverse order,
  // then return.
  virtual void postCode() {
    for (int i = 0; i < num_g_abi_regs; ++i) {
      pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
    }
    ret();
  }
  void L(const char* label) { Xbyak::CodeGenerator::L(label); }
  void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
  // Enhanced vector extension
  // Rewrites large displacements into base + helper-register * scale form
  // so they fit the EVEX compressed 8-bit displacement encoding.
  Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
                                    bool bcast = false) {
    int scale = 0;
    // Learn from https://github.com/intel/mkl-dnn
    if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
      offt = offt - 2 * EVEX_max_8b_offt;
      scale = 1;
    } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
      offt = offt - 4 * EVEX_max_8b_offt;
      scale = 2;
    }
    auto re = Xbyak::RegExp() + base + offt;
    if (scale) {
      re = re + reg_EVEX_max_8b_offt * scale;
    }
    if (bcast) {
      return zword_b[re];
    } else {
      return zword[re];
    }
  }
};
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,142 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/gen/lstm.h"
|
||||
#include <stddef.h> // offsetof
|
||||
#include "paddle/fluid/operators/jit/registry.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
// Emits AVX code for one LSTM step over a d-sized hidden state, handling
// YMM_FLOAT_BLOCK floats per loop iteration:
//   C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f)
//   H_t = act_cell(C_t) * act_gate(o)
// When compute_c1h1_ is set there is no previous cell state, so the
// C_t-1/forget-gate terms are skipped; when use_peephole_ is set the
// gates also see the cell state through the peephole weights wp.
void LSTMJitCode::genCode() {
  // r12 is callee-saved and only used for the peephole weight pointer, so
  // the full save/restore prologue is only needed on the peephole path.
  if (use_peephole_) {
    preCode();
  }
  reg64_t reg_ptr_gates = rax;
  reg64_t reg_ptr_ct_1 = r9;
  reg64_t reg_ptr_ct = r10;
  reg64_t reg_ptr_ht = r11;
  reg64_t reg_ptr_wp = r12;
  // Load the data pointers out of the lstm_t* passed as the first argument.
  mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
  mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
  mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
  mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
  if (use_peephole_) {
    mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
  }

  int offset = 0;
  int d = num_ * sizeof(float);  // byte stride between the four gate blocks
  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
    /* gates: W_ch, W_ih, W_fh, W_oh */
    ymm_t ymm_c = ymm_t(0);
    ymm_t ymm_i = ymm_t(1);
    ymm_t ymm_f = ymm_t(2);
    ymm_t ymm_o = ymm_t(3);
    ymm_t ymm_ct_1 = ymm_t(4);
    ymm_t ymm_wp0 = ymm_t(5);
    ymm_t ymm_wp1 = ymm_t(6);
    ymm_t ymm_wp2 = ymm_t(7);
    vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
    vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
    vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
    vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
    if (!compute_c1h1_) {
      vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
    }
    if (use_peephole_) {
      vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
      vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
      vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
    }
    /* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
    // act_cand(c)
    act<ymm_t>(ymm_c, ymm_c, act_cand_);
    // act_gate(i) or act_gate(ct_1 * wp0 + i)
    if (!compute_c1h1_ && use_peephole_) {
      vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
      vaddps(ymm_i, ymm_i, ymm_wp0);
    }
    act<ymm_t>(ymm_i, ymm_i, act_gate_);
    vmulps(ymm_c, ymm_c, ymm_i);
    if (!compute_c1h1_) {
      // act_gate(f) or act_gate(ct_1 * wp1 + f)
      if (use_peephole_) {
        vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
        vaddps(ymm_f, ymm_f, ymm_wp1);
      }
      act<ymm_t>(ymm_f, ymm_f, act_gate_);
      // ct
      vmulps(ymm_f, ymm_f, ymm_ct_1);
      vaddps(ymm_f, ymm_f, ymm_c);
    }
    /* H_t = act_cell(C_t) * act_gate(o) */
    // act_cell(C_t); ymm_i is free again at this point and reused as temp
    ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
    ymm_t ymm_tmp = ymm_i;
    act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
    // act_gate(o) or act_gate(ct * wp2 + o)
    if (use_peephole_) {
      vmulps(ymm_wp2, ymm_ct, ymm_wp2);
      vaddps(ymm_o, ymm_o, ymm_wp2);
    }
    act<ymm_t>(ymm_o, ymm_o, act_gate_);
    // ht
    vmulps(ymm_o, ymm_o, ymm_tmp);
    // save ct and ht
    vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
    vmovups(ptr[reg_ptr_ht + offset], ymm_o);
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }

  if (use_peephole_) {
    postCode();
  } else {
    ret();
  }
}
|
||||
|
||||
// Defines the JitCodeCreator for one LSTM variant. UseMe() gates on AVX
// support and a hidden size that is a whole number of YMM float blocks;
// CodeSize() is an upper-bound estimate of the generated code bytes.
#define DECLARE_LSTM_CREATOR(name)                                   \
  class name##Creator : public JitCodeCreator<lstm_attr_t> {         \
   public:                                                           \
    /* TODO(TJ): enable more */                                      \
    bool UseMe(const lstm_attr_t& attr) const override {             \
      return platform::MayIUse(platform::avx) && attr.d % 8 == 0;    \
    }                                                                \
    size_t CodeSize(const lstm_attr_t& attr) const override {        \
      return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;             \
    }                                                                \
    std::unique_ptr<GenBase> CreateJitCode(                          \
        const lstm_attr_t& attr) const override {                    \
      return make_unique<name##JitCode>(attr, CodeSize(attr));       \
    }                                                                \
  }

DECLARE_LSTM_CREATOR(LSTMCtHt);
DECLARE_LSTM_CREATOR(LSTMC1H1);

#undef DECLARE_LSTM_CREATOR

}  // namespace gen
}  // namespace jit
}  // namespace operators
}  // namespace paddle

namespace gen = paddle::operators::jit::gen;

// Register the creators so helper.h's GetJitCode can find them by type.
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
|
@ -0,0 +1,118 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/operators/jit/gen/act.h"
|
||||
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
namespace gen {
|
||||
|
||||
class LSTMJitCode : public VActFunc {
|
||||
public:
|
||||
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
|
||||
size_t code_size, void* code_ptr = nullptr)
|
||||
: VActFunc(code_size, code_ptr),
|
||||
num_(attr.d),
|
||||
compute_c1h1_(compute_c1h1),
|
||||
use_peephole_(attr.use_peephole) {
|
||||
auto typeExchange = [](KernelType type) -> gen::operand_type {
|
||||
if (type == KernelType::kVSigmoid) {
|
||||
return operand_type::SIGMOID;
|
||||
} else if (type == KernelType::kVRelu) {
|
||||
return operand_type::RELU;
|
||||
} else if (type == KernelType::kVTanh) {
|
||||
return operand_type::TANH;
|
||||
} else if (type == KernelType::kVIdentity) {
|
||||
return operand_type::IDENTITY;
|
||||
} else {
|
||||
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
|
||||
}
|
||||
return operand_type::IDENTITY;
|
||||
};
|
||||
act_gate_ = typeExchange(attr.act_gate);
|
||||
act_cand_ = typeExchange(attr.act_cand);
|
||||
act_cell_ = typeExchange(attr.act_cell);
|
||||
|
||||
this->genCode();
|
||||
}
|
||||
|
||||
const char* name() const override {
|
||||
std::string base = "LSTMJitCode";
|
||||
if (use_peephole_) {
|
||||
base += "_Peephole";
|
||||
}
|
||||
if (compute_c1h1_) {
|
||||
base += "_C1H1";
|
||||
}
|
||||
auto AddTypeStr = [&](operand_type type) {
|
||||
switch (type) {
|
||||
case operand_type::RELU:
|
||||
base += "_Relu";
|
||||
break;
|
||||
case operand_type::EXP:
|
||||
base += "_Exp";
|
||||
break;
|
||||
case operand_type::SIGMOID:
|
||||
base += "_Sigmoid";
|
||||
break;
|
||||
case operand_type::TANH:
|
||||
base += "_Tanh";
|
||||
break;
|
||||
case operand_type::IDENTITY:
|
||||
base += "_Identity";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
};
|
||||
AddTypeStr(act_gate_);
|
||||
AddTypeStr(act_cand_);
|
||||
AddTypeStr(act_cell_);
|
||||
return base.c_str();
|
||||
}
|
||||
void genCode() override;
|
||||
|
||||
protected:
|
||||
int num_;
|
||||
bool compute_c1h1_;
|
||||
bool use_peephole_;
|
||||
operand_type act_gate_;
|
||||
operand_type act_cand_;
|
||||
operand_type act_cell_;
|
||||
reg64_t param1{abi_param1};
|
||||
};
|
||||
|
||||
// Defines one thin subclass per LSTM variant that pins `compute_c1h1`
// and forwards the remaining constructor arguments to LSTMJitCode.
#define DECLARE_LSTM_JITCODE(name, compute_c1h1)                     \
  class name##JitCode : public LSTMJitCode {                         \
   public:                                                           \
    explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \
                           void* code_ptr = nullptr)                 \
        : LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {}    \
  };

DECLARE_LSTM_JITCODE(LSTMCtHt, false);
DECLARE_LSTM_JITCODE(LSTMC1H1, true);

#undef DECLARE_LSTM_JITCODE
|
||||
|
||||
} // namespace gen
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,43 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/gen_base.h"
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
// refer do not need useme, it would be the last one.
|
||||
void GenBase::dumpCode(const unsigned char* code) const {
|
||||
if (code) {
|
||||
static int counter = 0;
|
||||
std::ostringstream filename;
|
||||
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
|
||||
counter++;
|
||||
std::ofstream fout(filename.str(), std::ios::out);
|
||||
if (fout.is_open()) {
|
||||
fout.write(reinterpret_cast<const char*>(code), this->getSize());
|
||||
fout.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,70 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <memory> // for unique_ptr
|
||||
#include "paddle/fluid/operators/jit/kernel_base.h"
|
||||
|
||||
DECLARE_bool(dump_jitcode);
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
// Base class of all JIT-generated kernels: owns the naming/size/code
// interface; concrete code buffers live in subclasses (gen/jitcode.h).
class GenBase : public Kernel {
 public:
  virtual ~GenBase() = default;
  virtual const char* name() const = 0;
  virtual size_t getSize() const = 0;
  virtual const unsigned char* getCodeInternal() = 0;
  // Returns the generated code as a callable of type Func, optionally
  // dumping the bytes to disk first (--dump_jitcode).
  template <typename Func>
  Func getCode() {
    const unsigned char* code = this->getCodeInternal();
    if (FLAGS_dump_jitcode) {
      this->dumpCode(code);
    }
    // Intentional data-to-function-pointer cast: the buffer is
    // executable memory produced by the code generator.
    return reinterpret_cast<Func>(const_cast<unsigned char*>(code));
  }

 protected:
  // Writes the code bytes to a file for debugging (gen_base.cc).
  void dumpCode(const unsigned char* code) const;
};
|
||||
|
||||
// Creator is used to create the jitcode and save it in the pool.
// Every JitCode should have one creator.
class GenCreator {
 public:
  virtual ~GenCreator() = default;
};

// Typed creator interface: one per (JitCode class, attribute type) pair.
// The dispatcher (helper.h) walks the registered creators and uses the
// first one whose UseMe() accepts the attributes.
template <typename Attr>
class JitCodeCreator : public GenCreator {
 public:
  virtual ~JitCodeCreator() = default;

  // condition when this jit code can be used.
  virtual bool UseMe(const Attr& attr) const = 0;

  // estimate this code size
  virtual size_t CodeSize(const Attr& attr) const = 0;

  // create this code
  virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,76 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/helper.h"
|
||||
#include <algorithm> // tolower
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
#define ONE_CASE(key) \
|
||||
case key: \
|
||||
return #key
|
||||
|
||||
const char* to_string(KernelType kt) {
|
||||
switch (kt) {
|
||||
ONE_CASE(kVMul);
|
||||
ONE_CASE(kVAdd);
|
||||
ONE_CASE(kVAddRelu);
|
||||
ONE_CASE(kVSub);
|
||||
ONE_CASE(kVScal);
|
||||
ONE_CASE(kVAddBias);
|
||||
ONE_CASE(kVRelu);
|
||||
ONE_CASE(kVIdentity);
|
||||
ONE_CASE(kVExp);
|
||||
ONE_CASE(kVSigmoid);
|
||||
ONE_CASE(kVTanh);
|
||||
ONE_CASE(kLSTMCtHt);
|
||||
ONE_CASE(kLSTMC1H1);
|
||||
ONE_CASE(kGRUH1);
|
||||
ONE_CASE(kGRUHtPart1);
|
||||
ONE_CASE(kGRUHtPart2);
|
||||
ONE_CASE(kCRFDecoding);
|
||||
ONE_CASE(kLayerNorm);
|
||||
ONE_CASE(kNCHW16CMulNC);
|
||||
default:
|
||||
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
|
||||
return "NOT JITKernel";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
#undef ONE_CASE
|
||||
|
||||
// Parses an activation name into its KernelType. Matching is
// case-insensitive and accepts both the bare name ("tanh") and the
// kernel-style name ("vtanh"); an empty string maps to identity.
// Unknown names throw.
KernelType to_kerneltype(const std::string& act) {
  // Normalize to lower case so "Tanh", "VTanh" and "tanh" all match.
  std::string key = act;
  for (auto& ch : key) {
    ch = static_cast<char>(::tolower(static_cast<unsigned char>(ch)));
  }
  if (key == "relu" || key == "vrelu") {
    return kVRelu;
  }
  if (key == "identity" || key == "videntity" || key.empty()) {
    return kVIdentity;
  }
  if (key == "exp" || key == "vexp") {
    return kVExp;
  }
  if (key == "sigmoid" || key == "vsigmoid") {
    return kVSigmoid;
  }
  if (key == "tanh" || key == "vtanh") {
    return kVTanh;
  }
  PADDLE_THROW("Not support type: %s, or forget to add this case", act);
  return kNone;
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,140 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/operators/jit/gen_base.h"
|
||||
#include "paddle/fluid/operators/jit/kernel_base.h"
|
||||
#include "paddle/fluid/operators/jit/kernel_key.h"
|
||||
#include "paddle/fluid/operators/jit/kernel_pool.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
// JIT-code lookup path, enabled only for float data on CPUPlace (the only
// combination the generators support). Returns nullptr when no generator
// accepts `attr`, so the caller can fall back to other implementations.
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
    std::is_same<typename KernelTuples::data_type, float>::value &&
        std::is_same<PlaceType, platform::CPUPlace>::value,
    typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
  using Func = typename KernelTuples::func_type;
  using Attr = typename KernelTuples::attr_type;
  size_t key = JitCodeKey<Attr>(attr);
  // Cache hit: reuse code previously generated for the same attr key.
  auto& codes = JitCodePool<KT>().Instance();
  if (codes.Has(key)) {
    return codes.AllKernels().at(key)->template getCode<Func>();
  }

  // creator is not related with attr, so can use KernelKey as key
  KernelKey kkey(KT, PlaceType());
  // pool: (KernelKey(type, place), vector<GenCreatorPtr>)
  auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
  auto iter = creator_map.find(kkey);
  if (iter != creator_map.end()) {
    auto& creators = iter->second;
    // Ask each registered creator whether it can generate code for attr;
    // the first one that accepts produces the kernel and caches it.
    for (auto& cur : creators) {
      auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
      if (i && i->UseMe(attr)) {
        auto p = i->CreateJitCode(attr);
        if (p) {
          auto f = p->template getCode<Func>();
          codes.Insert(key, std::move(p));
          return f;
        }
      }
    }
  }
  return nullptr;
}

// Fallback overload for non-float data or non-CPU places: no JIT code.
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
    !std::is_same<typename KernelTuples::data_type, float>::value ||
        !std::is_same<PlaceType, platform::CPUPlace>::value,
    typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
  return nullptr;
}
|
||||
|
||||
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
  auto& ref_pool = ReferKernelPool().Instance().AllKernels();
  KernelKey kkey(KT, platform::CPUPlace());
  auto ref_iter = ref_pool.find(kkey);
  // A reference implementation is mandatory for every kernel type.
  PADDLE_ENFORCE(ref_iter != ref_pool.end(),
                 "Every Kernel should have reference function.");
  auto& ref_impls = ref_iter->second;
  // Return the first registered refer kernel matching this tuple type.
  for (auto& impl : ref_impls) {
    auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
    if (i) {
      return i->GetFunc();
    }
  }
  return nullptr;
}
|
||||
|
||||
// Main entry point: picks the best available implementation for kernel
// type KT. Search order: (1) JIT-generated code, (2) registered "more"
// kernels whose UseMe() accepts attr, (3) the reference implementation.
template <KernelType KT, typename KernelTuples,
          typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
    const typename KernelTuples::attr_type& attr) {
  auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
  if (jitfunc) {
    return jitfunc;
  }

  // pool: (KernelKey(type, place), vector<KernelPtr>)
  KernelKey kkey(KT, PlaceType());
  auto& pool = KernelPool().Instance().AllKernels();
  auto iter = pool.find(kkey);
  if (iter != pool.end()) {
    auto& impls = iter->second;
    for (auto& impl : impls) {
      auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
      if (i && i->UseMe(attr)) {
        return i->GetFunc();
      }
    }
  }

  // The last implementation should be reference function on CPUPlace.
  return GetRefer<KT, KernelTuples>();
}
|
||||
|
||||
// Declared here, defined in helper.cc.
const char* to_string(KernelType kt);

// Parses an activation name (case-insensitive, e.g. "tanh"/"vtanh") into
// its KernelType; defined in helper.cc.
KernelType to_kerneltype(const std::string& act);

// Human-readable dumps of the RNN attributes (used for logs/benchmarks).
inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) {
  os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
     << "],act_cand[" << to_string(attr.act_cand) << "],act_cell["
     << to_string(attr.act_cell) << "],use_peephole["
     << (attr.use_peephole ? "True" : "False") << "]";
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
  os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
     << "],act_cand[" << to_string(attr.act_cand) << "]";
  return os;
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,172 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/operators/jit/macro.h"
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
// All kernel kinds known to the jit subsystem; used as the lookup key in
// the kernel pools. kNone marks "no kernel".
typedef enum {
  kNone = 0,
  kVMul = 1,
  kVAdd = 2,
  kVAddRelu,
  kVSub,
  kVScal,
  kVAddBias,
  kVRelu,
  kVIdentity,
  kVExp,
  kVSigmoid,
  kVTanh,
  kLSTMCtHt,
  kLSTMC1H1,
  kGRUH1,
  kGRUHtPart1,
  kGRUHtPart2,
  kCRFDecoding,
  kLayerNorm,
  kNCHW16CMulNC,
} KernelType;

// Each *Tuples struct bundles a kernel family's data type, attribute type
// and function-pointer signature; helper.h dispatches on these typedefs.

// f(x, y, z, n): two const inputs, one output, element count.
template <typename T>
struct XYZNTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, const T*, T*, int);
};

// Same signature shape as XYZN; presumably the first operand is a scalar
// `a` (e.g. for scale/add-bias) — TODO confirm against the refer kernels.
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};

// f(x, y, n): one input, one output, element count (activations etc.).
template <typename T>
struct XYNTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, T*, int);
};

// Argument pack of one LSTM step; pointers are void* so a single struct
// serves every data type.
typedef struct {
  void* gates;  // gates: x_ch, x_ih, x_fh, x_oh
  const void* ct_1;  // previous cell state
  void* ct;          // output cell state
  void* ht;          // output hidden state
  /* weight_peephole and checked data are only used in peephole*/
  const void* wp{nullptr};  // W_ic, W_fc, W_oc
  void* checked{nullptr};   // size: 2 * d
} lstm_t;

// Argument pack of one GRU step.
typedef struct {
  void* gates;  // gates: {x_update, x_reset; x_state}
  const void* ht_1;  // previous hidden state
  void* ht;          // output hidden state
} gru_t;

// Attributes shared by the RNN kernels: hidden size and activations.
struct rnn_attr_s {
  int d;  // hidden dimension
  KernelType act_gate, act_cand;
  rnn_attr_s() = default;
  explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
      : d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};

// LSTM additionally has a cell activation and optional peephole links.
struct lstm_attr_s : public rnn_attr_s {
  bool use_peephole;
  KernelType act_cell;
  lstm_attr_s() = default;
  explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
                       KernelType _act_cell, bool _use_peephole = false)
      : rnn_attr_s(_d, _act_gate, _act_cand),
        use_peephole(_use_peephole),
        act_cell(_act_cell) {}
};

typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
|
||||
|
||||
// One LSTM step: f(lstm_t*, const lstm_attr_t*).
template <typename T>
struct LSTMTuples {
  typedef T data_type;
  typedef lstm_attr_t attr_type;
  typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};

// One GRU step: f(gru_t*, const gru_attr_t*).
template <typename T>
struct GRUTuples {
  typedef T data_type;
  typedef gru_attr_t attr_type;
  typedef void (*func_type)(gru_t*, const gru_attr_t*);
};

// CRF decoding kernel signature.
template <typename T>
struct CRFDecodingTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};

// Layer normalization kernel signature.
template <typename T>
struct LayerNormTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
                            const float, int);
};

// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
  typedef T data_type;
  typedef int attr_type;
  typedef void (*func_type)(const T*, const T*, T*, int, int);
};
|
||||
|
||||
// Just for adding to kernel pool without template: a non-template base so
// heterogeneous kernel implementations can be stored behind one pointer type.
class Kernel {
 public:
  Kernel() = default;
  // Virtual: instances are owned and destroyed through Kernel*.
  virtual ~Kernel() = default;
  // Project macro: disables copying (kernels are identity objects in a pool).
  DISABLE_COPY_AND_ASSIGN(Kernel);
};
|
||||
|
||||
// A kernel implementation bound to one KernelTuples signature. Concrete
// implementations derive from this, fill in `func`, and report via UseMe()
// whether they can serve a given attribute value.
template <typename KernelTuples>
class KernelMore : public Kernel {
 public:
  using T = typename KernelTuples::data_type;
  using Func = typename KernelTuples::func_type;
  using Attr = typename KernelTuples::attr_type;
  // Returns the stored function pointer; nullptr if the derived class
  // never set `func`.
  virtual Func GetFunc() const { return func; }
  // True when this implementation is applicable for `attr`.
  virtual bool UseMe(const Attr& attr) const = 0;
  // Short human-readable implementation name, e.g. "Refer".
  virtual const char* ImplType() const = 0;

 protected:
  Func func{nullptr};
};
|
||||
|
||||
template <typename KernelTuples>
|
||||
class ReferKernel : public KernelMore<KernelTuples> {
|
||||
public:
|
||||
// Refer code can always be used
|
||||
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
|
||||
return true;
|
||||
}
|
||||
const char* ImplType() const override { return "Refer"; }
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,47 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jit/kernel_key.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
template <>
|
||||
size_t JitCodeKey<int>(const int& d) {
|
||||
return d;
|
||||
}
|
||||
|
||||
constexpr int act_type_shift = 3;  // bits reserved per activation type: supports 2^3 act types
|
||||
|
||||
template <>
|
||||
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
|
||||
size_t key = attr.d;
|
||||
int gate_key = static_cast<int>(attr.act_gate) << 1;
|
||||
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
|
||||
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
|
||||
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
|
||||
attr.use_peephole;
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
|
||||
size_t key = attr.d;
|
||||
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
|
||||
(static_cast<int>(attr.act_cand) << act_type_shift);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,53 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/operators/jit/kernel_base.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jit {
|
||||
|
||||
struct KernelKey {
|
||||
struct Hash {
|
||||
size_t operator()(const KernelKey& key) const {
|
||||
int place = key.place_.which(); // less than 2^8
|
||||
int type = static_cast<int>(key.type_) << 8; // less than 2^(32-8)
|
||||
std::hash<int> hasher;
|
||||
return hasher(place + type);
|
||||
}
|
||||
};
|
||||
|
||||
KernelType type_;
|
||||
platform::Place place_;
|
||||
|
||||
KernelKey(KernelType type, platform::Place place)
|
||||
: type_(type), place_(place) {}
|
||||
size_t hash_key() const { return Hash()(*this); }
|
||||
|
||||
bool operator==(const KernelKey& o) const {
|
||||
return platform::places_are_same_class(place_, o.place_) &&
|
||||
type_ == o.type_;
|
||||
}
|
||||
bool operator!=(const KernelKey& o) const { return !(*this == o); }
|
||||
};
|
||||
|
||||
// Every JitCode should have a method to get the key from its attributes:
// this template maps an attribute value to an integer key. Specializations
// for int, lstm_attr_t and gru_attr_t live in kernel_key.cc.
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue