parent
439af8d50a
commit
77236e33fc
@ -0,0 +1,17 @@
|
||||
|
||||
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
|
||||
|
||||
cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
|
||||
|
||||
add_subdirectory(more)
|
||||
add_subdirectory(refer)
|
||||
|
||||
if(WITH_XBYAK)
|
||||
add_subdirectory(jitcode)
|
||||
endif()
|
||||
|
||||
# Debug
|
||||
message(STATUS "--------${JIT_KERNEL_DEPS}")
|
||||
|
||||
cc_library(jit_kernel SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
|
||||
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel)
|
@ -0,0 +1 @@
|
||||
TBD
|
@ -0,0 +1,3 @@
|
||||
|
||||
cc_library(jit_kernel_jitcode SRCS jitcode.cc DEPS jit_kernel_base xbyak)
|
||||
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
|
@ -0,0 +1,15 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
|
@ -0,0 +1,54 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "paddle/fluid/operators/jitkernels/kernels.h"
|
||||
|
||||
#define XBYAK_USE_MMAP_ALLOCATOR
|
||||
#include "xbyak/xbyak.h"
|
||||
#include "xbyak/xbyak_util.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
namespace jitcode {
|
||||
|
||||
// Application Binary Interface
|
||||
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
|
||||
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
|
||||
abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);
|
||||
|
||||
template <KernelType KT, typename Attr>
|
||||
class JitCode : public JitBase, public Xbyak::CodeGenerator {
|
||||
public:
|
||||
JitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
|
||||
: Xbyak::CodeGenerator(code_size, code_ptr) {
|
||||
this->genCode();
|
||||
}
|
||||
|
||||
virtual const char* name() const = 0;
|
||||
virtual void genCode() = 0;
|
||||
|
||||
const unsigned char* getCodeInternal() override {
|
||||
const Xbyak::uint8* code = CodeGenerator::getCode();
|
||||
return code;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jitcode
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,40 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
|
||||
|
||||
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
// refer do not need useme, it would be the last one.
|
||||
void JitBase::dumpCode(const unsigned char* code) const {
|
||||
if (code) {
|
||||
static int counter = 0;
|
||||
std::ostringstream filename;
|
||||
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
|
||||
counter++;
|
||||
std::ofstream fout(filename.str(), std::ios::out);
|
||||
if (fout.is_open()) {
|
||||
fout.write(reinterpret_cast<const char*>(code), getSize());
|
||||
fout.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,73 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
|
||||
DECLARE_bool(dump_jitcode);
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
// TODO(TJ): make these functions as virtual of a class
|
||||
|
||||
// Every JitCode should estimate the code size itself
|
||||
template <KernelType KT, typename Attr>
|
||||
size_t CodeSize(Attr attr) {
|
||||
return 4096;
|
||||
}
|
||||
|
||||
// Every JitCode should have a condition when to use this JitCode
|
||||
template <KernelType KT, typename T, typename Attr>
|
||||
bool UseJitCode(Attr attr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Every JitCode should have a method to get the key from attribution
|
||||
template <typename Attr>
|
||||
size_t GetKey(Attr attr);
|
||||
|
||||
template <>
|
||||
size_t GetKey<int>(int d) {
|
||||
return d;
|
||||
}
|
||||
|
||||
class JitBase {
|
||||
public:
|
||||
JitBase() = default;
|
||||
virtual ~JitBase() = default;
|
||||
virtual const char* name() const = 0;
|
||||
virtual const unsigned char* getCodeInternal() = 0;
|
||||
|
||||
template <typename FUNC>
|
||||
const FUNC getCode() {
|
||||
const unsigned char* code = this->getCodeInternal();
|
||||
if (FLAGS_dump_jitcode) {
|
||||
this->dumpCode(code);
|
||||
}
|
||||
return reinterpret_cast<const FUNC>(code);
|
||||
}
|
||||
DISABLE_COPY_AND_ASSIGN(JitBase);
|
||||
|
||||
protected:
|
||||
void dumpCode(const unsigned char* code);
|
||||
};
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,47 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
|
||||
|
||||
// Just for adding to kernel pool without template
|
||||
class Kernel {
|
||||
public:
|
||||
Kernel() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(Kernel);
|
||||
};
|
||||
|
||||
template <typename T, typename Func, typename Attr> // TODO(TJ): use tuple
|
||||
class KernelImpl : public Kernel {
|
||||
public:
|
||||
using ELEMENT_TYPE = T; // TODO(TJ): remove me?
|
||||
KernelImpl() = default;
|
||||
virtual ~KernelImpl() = default;
|
||||
|
||||
virtual Func GetFunc() { return func; }
|
||||
virtual bool UseMe(Attr attr) const = 0;
|
||||
|
||||
protected:
|
||||
Func func{nullptr};
|
||||
};
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,49 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
struct KernelKey {
|
||||
struct Hash {
|
||||
size_t operator()(const KernelKey& key) const {
|
||||
int place = key.place_.which(); // less than 2^8
|
||||
int type = static_cast<int>(key.type_) << 8; // less than 2^(32-8)
|
||||
std::hash<int> hasher;
|
||||
return hasher(place + type);
|
||||
}
|
||||
};
|
||||
|
||||
KernelType type_;
|
||||
platform::Place place_;
|
||||
|
||||
KernelKey(KernelType type, platform::Place place)
|
||||
: type_(type), place_(place) {}
|
||||
size_t hash_key() const { return Hash()(*this); }
|
||||
|
||||
bool operator==(const KernelKey& o) const {
|
||||
return platform::places_are_same_class(place_, o.place_) &&
|
||||
type_ == o.type_;
|
||||
}
|
||||
bool operator!=(const KernelKey& o) const { return !(*this == o); }
|
||||
};
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,33 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jitkernels/kernels.h"
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
// refer do not need useme, it would be the last one.
|
||||
|
||||
KernelPool& KernelPool::Instance() {
|
||||
static KernelPool g_kernel_pool;
|
||||
return g_kernel_pool;
|
||||
}
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,142 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"
|
||||
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
|
||||
#include "paddle/fluid/operators/jitkernels/kernel_key.h"
|
||||
|
||||
#ifdef PADDLE_WITH_XBYAK
|
||||
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
template <KernelType KT>
|
||||
class JitCodePool {
|
||||
public:
|
||||
static JitCodePool& Instance() {
|
||||
static thread_local JitCodePool<KT> g_jit_codes;
|
||||
return g_jit_codes;
|
||||
}
|
||||
|
||||
std::shared_ptr<const JitBase> Get(size_t key) const {
|
||||
if (codes_.find(key) == codes_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return codes_.at(key);
|
||||
}
|
||||
|
||||
void Insert(size_t key, const std::shared_ptr<const JitBase>& value) {
|
||||
codes_.insert({key, value});
|
||||
}
|
||||
|
||||
private:
|
||||
JitCodePool() = default;
|
||||
std::unordered_map<size_t, std::shared_ptr<const JitBase>> codes_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(JitCodePool);
|
||||
};
|
||||
|
||||
// std::tuple<T, Func, Attr>
|
||||
template <typename T, typename Func, typename Attr>
|
||||
struct KernelAttr {
|
||||
typedef T data_type;
|
||||
typedef Func return_type;
|
||||
typedef Attr attr_type;
|
||||
};
|
||||
|
||||
class KernelPool {
|
||||
public:
|
||||
static KernelPool& Instance();
|
||||
|
||||
typedef std::unique_ptr<const Kernel> KernelPtr;
|
||||
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
|
||||
KernelMap;
|
||||
KernelMap& AllKernels() { return pool_; }
|
||||
|
||||
void Insert(const KernelKey& key, KernelPtr value) {
|
||||
if (pool_.find(key) == pool_.end()) {
|
||||
pool_.emplace(key, std::vector<KernelPtr>());
|
||||
}
|
||||
pool_.at(key).emplace_back(std::move(value));
|
||||
}
|
||||
KernelPool() = default;
|
||||
|
||||
private:
|
||||
KernelMap pool_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(KernelPool);
|
||||
};
|
||||
|
||||
// TODO(TJ): create_jitcode;
|
||||
|
||||
// TODO(TJ): make tuple? named KernelAttr
|
||||
template <KernelType KT, typename T, typename Func, typename Attr,
|
||||
typename PlaceType = platform::CPUPlace>
|
||||
Func Get(Attr attr) {
|
||||
size_t key = GetKey<Attr>(attr);
|
||||
auto jitcode = JitCodePool<KT>().Instance().Get(key);
|
||||
if (jitcode) {
|
||||
return jitcode->template getCode<Func>();
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_XBYAK
|
||||
// // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK
|
||||
// if (std::is_same<PlaceType, platform::CPUPlace>::value) {
|
||||
// if (UseJitCode<KT, T, Attr>(attr)) {
|
||||
// std::shared_ptr<JitBase> p(std::make_shared<jitcode::JitCode<KT, Attr>>(
|
||||
// attr, CodeSize<KT, Attr>(attr)));
|
||||
// JitCodePool<KT>().Instance().Insert(key, p);
|
||||
// return p->getCode<Func>();
|
||||
// }
|
||||
// }
|
||||
#endif
|
||||
|
||||
// (KernelKey(type, place), vector<Kernel>)
|
||||
auto& pool = KernelPool().Instance().AllKernels();
|
||||
KernelKey kkey(KT, PlaceType());
|
||||
auto iter = pool.find(kkey);
|
||||
if (iter != pool.end()) {
|
||||
auto impls = iter->second;
|
||||
for (auto impl : impls) {
|
||||
auto i = std::dynamic_pointer_cast<KernelImpl<T, Func, Attr>>(impl.get());
|
||||
if (i && i->UseMe(attr)) {
|
||||
return i->GetFunc();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The last implementation should be reference function on CPU
|
||||
// Every kernel should have refer code.
|
||||
|
||||
// because of test refer should have it's own pool
|
||||
// PADDLE_ENFORCE_GT(list.size(), 1) << "Should have refer implemtation";
|
||||
// const auto& refer = KernelRefer<KT, T>().AllKernels();
|
||||
// return refer.Get<Func>();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,7 @@
|
||||
|
||||
|
||||
if(WITH_MKLML)
|
||||
add_subdirectory(mkl)
|
||||
endif()
|
||||
|
||||
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} PARENT_SCOPE)
|
@ -0,0 +1,3 @@
|
||||
|
||||
cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
|
||||
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
|
@ -0,0 +1,44 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jitkernels/more/mkl/mkl.h"
|
||||
#include "paddle/fluid/operators/jitkernels/registry.h"
|
||||
#include "paddle/fluid/platform/dynload/mklml.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
namespace more {
|
||||
namespace mkl {
|
||||
|
||||
template <>
|
||||
void VMul<float>(const float* x, const float* y, float* z, int n) {
|
||||
platform::dynload::vsMul(n, x, y, z);
|
||||
}
|
||||
|
||||
template <>
|
||||
void VMul<double>(const double* x, const double* y, double* z, int n) {
|
||||
platform::dynload::vdMul(n, x, y, z);
|
||||
}
|
||||
|
||||
} // namespace mkl
|
||||
} // namespace more
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace mkl = paddle::operators::jitkernels::more::mkl;
|
||||
|
||||
REGISTER_JITKERNEL_MORE(vmul, mkl, mkl::VMulKernel<float>,
|
||||
mkl::VMulKernel<double>);
|
@ -0,0 +1,55 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
|
||||
#include "paddle/fluid/platform/cpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
namespace more {
|
||||
namespace mkl {
|
||||
|
||||
template <typename T>
|
||||
void VMul(const T* x, const T* y, T* z, int n);
|
||||
|
||||
// template <typename T>
|
||||
// struct VMulTypes{
|
||||
// typedef T date_type;
|
||||
// typedef void (*func)(const T*, const T*, T*, int) func_type;
|
||||
// typedef int attr_type;
|
||||
// };
|
||||
|
||||
template <typename T>
|
||||
class VMulKernel
|
||||
: public KernelImpl<T, void (*)(const T*, const T*, T*, int), int> {
|
||||
public:
|
||||
VMulKernel() { this->func = VMul<T>; }
|
||||
bool UseMe(int d) const override {
|
||||
if (std::is_same<T, float>::value) {
|
||||
return platform::jit::MayIUse(platform::jit::avx512f) && d > 512;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mkl
|
||||
} // namespace more
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,15 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
@ -0,0 +1,3 @@
|
||||
|
||||
cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base)
|
||||
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE)
|
@ -0,0 +1,20 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/jitkernels/refer/refer.h"
|
||||
#include "paddle/fluid/operators/jitkernels/registry.h"
|
||||
|
||||
namespace refer = paddle::operators::jitkernels::refer;
|
||||
|
||||
// REGISTER_JITKERNEL_REFER(vmul, refer::VMul<float>, refer::VMul<double>);
|
@ -0,0 +1,33 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
namespace refer {
|
||||
|
||||
template <typename T>
|
||||
void VMul(const T* x, const T* y, T* z, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = x[i] * y[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace refer
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,134 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "paddle/fluid/operators/jitkernels/kernel_base.h"
|
||||
#include "paddle/fluid/operators/jitkernels/kernels.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace jitkernels {
|
||||
|
||||
// make_unique is supported from c++14
|
||||
template <typename T, typename... Args>
|
||||
inline std::unique_ptr<T> make_unique(Args&&... args) {
|
||||
static_assert(!std::is_array<T>::value, "T must not be array");
|
||||
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
template <typename PlaceType, bool IsEnd, size_t I, typename... KernelImpls>
|
||||
struct JitKernelRegistrarFunctor;
|
||||
|
||||
template <typename PlaceType, size_t I, typename... KernelImpls>
|
||||
struct JitKernelRegistrarFunctor<PlaceType, true, I, KernelImpls...> {
|
||||
void operator()(KernelType kt) const {}
|
||||
};
|
||||
|
||||
template <typename PlaceType, size_t I, typename... KernelImpls>
|
||||
struct JitKernelRegistrarFunctor<PlaceType, false, I, KernelImpls...> {
|
||||
using KERNEL_IMPL_TYPE =
|
||||
typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;
|
||||
|
||||
void operator()(KernelType kt) const {
|
||||
KernelKey kkey(kt, PlaceType());
|
||||
KernelPool().Instance().Insert(
|
||||
kkey, std::move(make_unique<const KERNEL_IMPL_TYPE>()));
|
||||
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
|
||||
JitKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelImpls...>
|
||||
func;
|
||||
func(kt);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PlaceType, typename... KernelImpls>
|
||||
class JitKernelRegistrar {
|
||||
public:
|
||||
explicit JitKernelRegistrar(KernelType kt) {
|
||||
JitKernelRegistrarFunctor<PlaceType, false, 0, KernelImpls...> func;
|
||||
func(kt);
|
||||
}
|
||||
};
|
||||
|
||||
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
|
||||
struct __test_global_namespace_##uniq_name##__ {}; \
|
||||
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
|
||||
__test_global_namespace_##uniq_name##__>::value, \
|
||||
msg)
|
||||
|
||||
// kernel_type: should be in paddle::operators::jitkernels::KernelType
|
||||
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
|
||||
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
|
||||
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
|
||||
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
|
||||
"REGISTER_KERNEL_MORE must be called in global namespace"); \
|
||||
static ::paddle::operators::jitkernels::JitKernelRegistrar< \
|
||||
::paddle::platform::place_type, __VA_ARGS__> \
|
||||
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##__( \
|
||||
::paddle::operators::jitkernels::KernelType::kernel_type)
|
||||
// TODO(TJ): Add Touch and use me
|
||||
|
||||
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
|
||||
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
|
||||
|
||||
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
|
||||
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
|
||||
|
||||
/*
|
||||
REGISTER_JITKERNEL_JITCODE(vmul, JitKernelCode<vmul, int>);
|
||||
|
||||
// refer must be only one and at least one
|
||||
REGISTER_JITKERNEL_REFER(vmul, VMul); // Refer need support dtype
|
||||
|
||||
// you can register more implementations and the condition when use it
|
||||
REGISTER_JITKERNEL_MORE(vmul, mkl::VMUL<float>, UseMe<float>, mkl::VMUL<double>,
|
||||
UseMe<double>)
|
||||
|
||||
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
|
||||
struct __test_global_namespace_##uniq_name##__ {}; \
|
||||
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
|
||||
__test_global_namespace_##uniq_name##__>::value, \
|
||||
msg)
|
||||
|
||||
// Register a new pass that can be applied on the IR.
|
||||
#define REGISTER_PASS(pass_type, pass_class) \
|
||||
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
|
||||
__reg_pass__##pass_type, \
|
||||
"REGISTER_PASS must be called in global namespace"); \
|
||||
static ::paddle::framework::ir::PassRegistrar<pass_class> \
|
||||
__pass_registrar_##pass_type##__(#pass_type); \
|
||||
int TouchPassRegistrar_##pass_type() { \
|
||||
__pass_registrar_##pass_type##__.Touch(); \
|
||||
return 0; \
|
||||
} \
|
||||
static ::paddle::framework::ir::PassRegistrar<pass_class>& \
|
||||
__pass_tmp_registrar_##pass_type##__ UNUSED = \
|
||||
__pass_registrar_##pass_type##__
|
||||
|
||||
#define USE_PASS(pass_type) \
|
||||
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
|
||||
__use_pass_itself_##pass_type, \
|
||||
"USE_PASS must be called in global namespace"); \
|
||||
extern int TouchPassRegistrar_##pass_type(); \
|
||||
static int use_pass_itself_##pass_type##_ UNUSED = \
|
||||
TouchPassRegistrar_##pass_type()
|
||||
*/
|
||||
|
||||
} // namespace jitkernels
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,36 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#include <cstring> // for memcpy
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "gflags/gflags.h"
|
||||
#include "glog/logging.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/operators/math/jit_kernel.h"
|
||||
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
|
||||
#include "paddle/fluid/platform/port.h"
|
||||
|
||||
constexpr int repeat = 20000;
|
||||
|
||||
inline double GetCurrentUS() {
|
||||
struct timeval time;
|
||||
gettimeofday(&time, NULL);
|
||||
return 1e+6 * time.tv_sec + time.tv_usec;
|
||||
}
|
||||
|
||||
TEST(JitKernel, vmul) {}
|
||||
|
||||
TEST(JitKernel, pool) {}
|
Loading…
Reference in new issue