parent fa2ab3346c
commit 2b5edfbc37
File diff suppressed because it is too large
@@ -1,43 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/cpu_lstm_compute.h"

namespace paddle {
namespace operators {
namespace math {
#ifdef __AVX__
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht) {
  namespace act = detail::forward::avx;
  // gates: W_ch, W_ih, W_fh, W_oh
  __m256 c, i, f, o;
  c = _mm256_loadu_ps(gates);
  i = _mm256_loadu_ps(gates + 8);
  f = _mm256_loadu_ps(gates + 16);
  o = _mm256_loadu_ps(gates + 24);

  /* C_t = C_t-1 * fgated + cand_gated * igated */
  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
  i = _mm256_loadu_ps(ct_1);
  f = _mm256_mul_ps(i, act::Sigmoid(f));
  f = _mm256_add_ps(c, f);
  _mm256_storeu_ps(ct, f);

  /* H_t = act_cell(C_t) * ogated */
  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
  _mm256_storeu_ps(ht, o);
}
#endif
}  // namespace math
}  // namespace operators
}  // namespace paddle
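For reference, here is a minimal scalar sketch of the cell update that the AVX intrinsics above implement; lstm_ctht_ref and the sigmoid() helper are illustrative names, not part of the Paddle API:

#include <cmath>

// Scalar reference for one 8-wide LSTM step; gates is laid out as
// [candidate | input | forget | output], 8 floats per section.
static inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

void lstm_ctht_ref(const float* gates, const float* ct_1, float* ct,
                   float* ht) {
  for (int d = 0; d < 8; ++d) {
    float cand = std::tanh(gates[d]);   // W_ch: candidate cell value
    float ig = sigmoid(gates[8 + d]);   // W_ih: input gate
    float fg = sigmoid(gates[16 + d]);  // W_fh: forget gate
    float og = sigmoid(gates[24 + d]);  // W_oh: output gate
    ct[d] = ct_1[d] * fg + cand * ig;   // C_t = C_{t-1} * f + cand * i
    ht[d] = std::tanh(ct[d]) * og;      // H_t = tanh(C_t) * o
  }
}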
@@ -1,64 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif

namespace paddle {
namespace operators {
namespace math {

// TODO(TJ): ugly workaround, clean me
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
  // gates: W_ch, W_ih, W_fh, W_oh
  vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
  vec_tanh<T, platform::jit::avx>(8, gates, gates);
  const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
  const T min = SIGMOID_THRESHOLD_MIN;
  const T max = SIGMOID_THRESHOLD_MAX;
  for (int d = 0; d < 8; ++d) {
    // C_t = C_t-1 * fgated + cand_gated * igated
    ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
    // H_t = act_cell(C_t) * ogated
    T tmp = ct[d] * 2;
    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
    vec_exp<T>(1, &tmp, &tmp);
    tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
    ht[d] = tmp * o[d];
  }
}

#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);

}  // namespace avx
}  // namespace forward
}  // namespace detail

template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht);

#endif

}  // namespace math
}  // namespace operators
}  // namespace paddle
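The tail of the loop above evaluates tanh(C_t) through the identity tanh(y) = 2 / (1 + e^(-2y)) - 1, clamping the exponent into the sigmoid-threshold window first. Isolated as a standalone sketch (illustrative only; the threshold defaults mirror the SIGMOID_THRESHOLD_* macros):

#include <cmath>

// tanh(y) as 2 / (1 + e^{-2y}) - 1, with 2y clamped to [min, max] before
// negation and exponentiation, exactly as in the loop body above.
template <typename T>
T clipped_tanh(T y, T min = static_cast<T>(-40.0),
               T max = static_cast<T>(13.0)) {
  T tmp = y * 2;
  tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
  tmp = std::exp(tmp);
  return static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
}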
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/jit_kernel.h"
#include <iostream>
#include <string>

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

namespace jit = platform::jit;

KernelPool& KernelPool::Instance() {
  static thread_local KernelPool g_jit_kernels;
  return g_jit_kernels;
}

std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
  if (kers_.find(key) == kers_.end()) {
    return nullptr;
  }
  return kers_.at(key);
}

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
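A sketch of a typical call site: the pool caches one kernel instance per (kernel class, dtype, size) key, so repeated lookups return the same object. VMulKernel comes from the header below; the vmul8 wrapper is illustrative:

#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jk = paddle::operators::math::jitkernel;

// Multiply two length-8 float arrays via the cached jit kernel.
void vmul8(const float* x, const float* y, float* z) {
  const auto ker = jk::KernelPool::Instance().Get<jk::VMulKernel<float>>(8);
  ker->Compute(x, y, z);  // z[i] = x[i] * y[i] for i in [0, 8)
}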
@@ -0,0 +1,142 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <functional>
#include <memory>  // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"

// Note: Only supported on CPU for now.
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16

typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;

class Kernel {
 public:
  Kernel() = default;
  virtual ~Kernel() = default;
  int num_{0};
  int end_{0};
  int rest_{0};
  DISABLE_COPY_AND_ASSIGN(Kernel);
};

class KernelPool {
 public:
  static KernelPool &Instance();

  template <typename Ker, typename... ARGS>
  std::shared_ptr<const Ker> Get(ARGS... args);

  std::shared_ptr<const Kernel> Get(const std::string &key) const;

 private:
  KernelPool() = default;
  std::unordered_map<std::string, std::shared_ptr<const Kernel>> kers_;

  DISABLE_COPY_AND_ASSIGN(KernelPool);
};

template <typename T>
class VMulKernel : public Kernel {
 public:
  virtual void Compute(const T *x, const T *y, T *z) const = 0;
};

template <typename T>
class VAddKernel : public Kernel {
 public:
  virtual void Compute(const T *x, const T *y, T *z) const = 0;
};

template <typename T>
class VScalKernel : public Kernel {
 public:
  virtual void Compute(const T a, const T *x, T *y) const = 0;
  virtual void Compute(const T a, T *x) const = 0;
};

template <typename T>
class VAddBiasKernel : public Kernel {
 public:
  virtual void Compute(const T a, const T *x, T *y) const = 0;
};

template <typename T>
class VActKernel : public Kernel {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VReluKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VIdentityKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VExpKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VSigmoidKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VTanhKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class LSTMKernel : public Kernel {
 public:
  virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
                           /* below only used in peephole */
                           const T *wp_data = nullptr,
                           T *checked = nullptr) const = 0;

  // compute c1 and h1 without c0 or h0
  virtual void ComputeC1H1(T *gates, T *ct, T *ht,
                           /* below only used in peephole */
                           const T *wp_data = nullptr) const = 0;
};

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
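Concrete kernels subclass these interfaces once per ISA and block size. A minimal illustrative subclass, assuming only what the header above declares (VAddRefKernel is not from the Paddle source; the real implementations live in the suppressed diffs below):

// Hypothetical scalar fallback, shown only to illustrate the interface
// contract: the constructor records the vector width, Compute does the work.
template <typename T>
class VAddRefKernel : public VAddKernel<T> {
 public:
  explicit VAddRefKernel(int d) { this->num_ = d; }
  void Compute(const T *x, const T *y, T *z) const override {
    for (int i = 0; i < this->num_; ++i) {
      z[i] = x[i] + y[i];
    }
  }
};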
3 file diffs suppressed because they are too large
@@ -0,0 +1,111 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

namespace jit = platform::jit;

#define SEARCH_BLOCK(macro_, ker, dtype, isa)                   \
  if (d < AVX_FLOAT_BLOCK) {                                    \
    macro_(ker, dtype, isa, kLT8);                              \
  } else if (d == AVX_FLOAT_BLOCK) {                            \
    macro_(ker, dtype, isa, kEQ8);                              \
  } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) {   \
    macro_(ker, dtype, isa, kGT8LT16);                          \
  } else if (d == AVX512_FLOAT_BLOCK) {                         \
    macro_(ker, dtype, isa, kEQ16);                             \
  } else {                                                      \
    macro_(ker, dtype, isa, kGT16);                             \
  }

#define SEARCH_ISA_BLOCK(macro_, ker, dtype)        \
  if (jit::MayIUse(jit::avx512f)) {                 \
    SEARCH_BLOCK(macro_, ker, dtype, jit::avx512f); \
  } else if (jit::MayIUse(jit::avx2)) {             \
    SEARCH_BLOCK(macro_, ker, dtype, jit::avx2);    \
  } else if (jit::MayIUse(jit::avx)) {              \
    SEARCH_BLOCK(macro_, ker, dtype, jit::avx);     \
  } else {                                          \
    SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
  }

#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
  template <>                                   \
  std::shared_ptr<const ker_class<ker_dtype>>   \
  KernelPool::Get<ker_class<ker_dtype>, int>(int d)

#define JITKERNEL_KEY(ker_key, dtype_key) \
  #ker_key #dtype_key + std::to_string(d)

#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
  p = std::dynamic_pointer_cast<ker<dtype>>(   \
      std::make_shared<ker##Impl<dtype, isa, k>>(d))

#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
                             macro_declare, macro_key, macro_impl)     \
  macro_declare(ker_class, ker_dtype) {                                \
    std::string key = macro_key(ker_key, dtype_key);                   \
    if (kers_.find(key) == kers_.end()) {                              \
      std::shared_ptr<ker_class<ker_dtype>> p;                         \
      SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype);              \
      kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});       \
      return p;                                                        \
    }                                                                  \
    return std::dynamic_pointer_cast<const ker_class<ker_dtype>>(      \
        kers_.at(key));                                                \
  }

#define REGISTER_JITKERNEL(ker_key, ker_class)                           \
  JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE,  \
                       JITKERNEL_KEY, JITKERNEL_NEW_IMPL);               \
  JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
                       JITKERNEL_KEY, JITKERNEL_NEW_IMPL)

#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, macro_declare, macro_key, \
                                macro_impl)                                   \
  JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, macro_declare,           \
                       macro_key, macro_impl);                                \
  JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, macro_declare,          \
                       macro_key, macro_impl)

#define FOR_EACH_ISA(macro_, block) \
  macro_(jit::avx512f, block);      \
  macro_(jit::avx2, block);         \
  macro_(jit::avx, block);          \
  macro_(jit::isa_any, block)

#define FOR_EACH_BLOCK(macro_, isa) \
  macro_(isa, kLT8);                \
  macro_(isa, kEQ8);                \
  macro_(isa, kGT8LT16);            \
  macro_(isa, kEQ16);               \
  macro_(isa, kGT16)

#define FOR_EACH_ISA_BLOCK(macro_)      \
  FOR_EACH_BLOCK(macro_, jit::avx512f); \
  FOR_EACH_BLOCK(macro_, jit::avx2);    \
  FOR_EACH_BLOCK(macro_, jit::avx);     \
  FOR_EACH_BLOCK(macro_, jit::isa_any)

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
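To make the indirection concrete: for the float case, REGISTER_JITKERNEL(vmul, VMulKernel) expands to roughly the Get specialization below (a simplified sketch; the commented line shows the one branch of the SEARCH_ISA_BLOCK/SEARCH_BLOCK chain taken for d == 8 on an AVX2 machine, with VMulKernelImpl being the name implied by the ker##Impl paste):

template <>
std::shared_ptr<const VMulKernel<float>>
KernelPool::Get<VMulKernel<float>, int>(int d) {
  std::string key = "vmul" "f" + std::to_string(d);
  if (kers_.find(key) == kers_.end()) {
    std::shared_ptr<VMulKernel<float>> p;
    // Runtime ISA dispatch, then block-size bucketing of d, e.g.:
    //   p = std::dynamic_pointer_cast<VMulKernel<float>>(
    //       std::make_shared<VMulKernelImpl<float, jit::avx2, kEQ8>>(d));
    kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
    return p;
  }
  return std::dynamic_pointer_cast<const VMulKernel<float>>(kers_.at(key));
}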
Some files were not shown because too many files have changed in this diff.