commit 41eeb771e8
File diff suppressed because it is too large
@@ -1,43 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/cpu_lstm_compute.h"

namespace paddle {
namespace operators {
namespace math {
#ifdef __AVX__
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht) {
  namespace act = detail::forward::avx;
  // gates: W_ch, W_ih, W_fh, W_oh
  __m256 c, i, f, o;
  c = _mm256_loadu_ps(gates);
  i = _mm256_loadu_ps(gates + 8);
  f = _mm256_loadu_ps(gates + 16);
  o = _mm256_loadu_ps(gates + 24);

  /* C_t = C_t-1 * fgated + cand_gated * igated */
  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
  i = _mm256_loadu_ps(ct_1);
  f = _mm256_mul_ps(i, act::Sigmoid(f));
  f = _mm256_add_ps(c, f);
  _mm256_storeu_ps(ct, f);

  /* H_t = act_cell(C_t) * ogated */
  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
  _mm256_storeu_ps(ht, o);
}
#endif
}  // namespace math
}  // namespace operators
}  // namespace paddle
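
(Note: the deleted file above appears to be cpu_lstm_compute.cc, judging from the header it includes.) The AVX specialization computes one LSTM step for a block of 8 floats, with the gate pre-activations laid out as four consecutive 8-wide blocks: W_ch, W_ih, W_fh, W_oh. For reference, here is a minimal scalar sketch of the same update; the function name and the lane-count parameter n are illustrative, not part of the commit:

#include <cmath>
#include <cstddef>

// Scalar reference for the cell update the intrinsics above perform;
// n == 8 corresponds to one __m256 block.
void lstm_step_reference(const float* gates, const float* ct_1, float* ct,
                         float* ht, std::size_t n) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  for (std::size_t d = 0; d < n; ++d) {
    const float cand = std::tanh(gates[d]);      // W_ch block
    const float ig = sigmoid(gates[n + d]);      // W_ih block
    const float fg = sigmoid(gates[2 * n + d]);  // W_fh block
    const float og = sigmoid(gates[3 * n + d]);  // W_oh block
    ct[d] = ct_1[d] * fg + cand * ig;            // C_t = C_t-1 * f + cand * i
    ht[d] = std::tanh(ct[d]) * og;               // H_t = tanh(C_t) * o
  }
}
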
@@ -1,64 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif

namespace paddle {
namespace operators {
namespace math {

// TODO(TJ): ugly workaround, clean me
template <typename T>
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
  // gates: W_ch, W_ih, W_fh, W_oh
  vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
  vec_tanh<T, platform::jit::avx>(8, gates, gates);
  const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
  const T min = SIGMOID_THRESHOLD_MIN;
  const T max = SIGMOID_THRESHOLD_MAX;
  for (int d = 0; d < 8; ++d) {
    // C_t = C_t-1 * fgated + cand_gated * igated
    ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
    // H_t = act_cell(C_t) * ogated
    T tmp = ct[d] * 2;
    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
    vec_exp<T>(1, &tmp, &tmp);
    tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
    ht[d] = tmp * o[d];
  }
}

#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);

}  // namespace avx
}  // namespace forward
}  // namespace detail

template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
                              float* ht);

#endif

}  // namespace math
}  // namespace operators
}  // namespace paddle
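
(This second deleted file appears to be the matching cpu_lstm_compute.h.) Note that the generic loop never calls tanh directly on C_t: it relies on the identity tanh(y) = 2 / (1 + exp(-2y)) - 1, clamping the exponent to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] so vec_exp stays in a safe range. A small self-contained check of that identity, using the threshold values that jit_kernel.h below defines:

#include <cassert>
#include <cmath>

int main() {
  const float kMin = -40.0f;  // SIGMOID_THRESHOLD_MIN
  const float kMax = 13.0f;   // SIGMOID_THRESHOLD_MAX
  for (float y : {-3.0f, -0.5f, 0.0f, 0.5f, 3.0f}) {
    // Same steps as the loop body: clamp 2y, negate, exponentiate.
    float tmp = 2.0f * y;
    tmp = (tmp < kMin) ? kMin : ((tmp > kMax) ? kMax : tmp);
    const float via_exp = 2.0f / (1.0f + std::exp(-tmp)) - 1.0f;
    assert(std::fabs(via_exp - std::tanh(y)) < 1e-6f);
  }
  return 0;
}
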
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/jit_kernel.h"
#include <iostream>
#include <string>

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

namespace jit = platform::jit;

KernelPool& KernelPool::Instance() {
  static thread_local KernelPool g_jit_kernels;
  return g_jit_kernels;
}

std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
  if (kers_.find(key) == kers_.end()) {
    return nullptr;
  }
  return kers_.at(key);
}

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
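
The new jit_kernel.cc only supplies the pool plumbing: a thread_local singleton and a lookup by string key that returns nullptr on a miss. A hypothetical caller could look like the sketch below; the key string is made up for illustration, since real keys come from the templated Get<Ker>(args...) overload declared in the header:

#include <memory>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jk = paddle::operators::math::jitkernel;

void LookupExample() {
  // Each thread sees its own pool (static thread_local), so no locking here.
  jk::KernelPool& pool = jk::KernelPool::Instance();
  const std::string key = "vmulf8";  // made-up key, for illustration only
  // The string lookup returns nullptr when nothing is cached under the key.
  std::shared_ptr<const jk::Kernel> ker = pool.Get(key);
  if (ker == nullptr) {
    // Not cached yet; a real caller would go through Get<Ker>(...) instead.
  }
}
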
@@ -0,0 +1,142 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <functional>
#include <memory>  // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"

// Note: Only support on CPU yet.
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {

#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16

typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;

class Kernel {
 public:
  Kernel() = default;
  virtual ~Kernel() = default;
  int num_{0};
  int end_{0};
  int rest_{0};
  DISABLE_COPY_AND_ASSIGN(Kernel);
};

class KernelPool {
 public:
  static KernelPool &Instance();

  template <typename Ker, typename... ARGS>
  std::shared_ptr<const Ker> Get(ARGS... args);

  std::shared_ptr<const Kernel> Get(const std::string &key) const;

 private:
  KernelPool() = default;
  std::unordered_map<std::string, std::shared_ptr<const Kernel>> kers_;

  DISABLE_COPY_AND_ASSIGN(KernelPool);
};

template <typename T>
class VMulKernel : public Kernel {
 public:
  virtual void Compute(const T *x, const T *y, T *z) const = 0;
};

template <typename T>
class VAddKernel : public Kernel {
 public:
  virtual void Compute(const T *x, const T *y, T *z) const = 0;
};

template <typename T>
class VScalKernel : public Kernel {
 public:
  virtual void Compute(const T a, const T *x, T *y) const = 0;
  virtual void Compute(const T a, T *x) const = 0;
};

template <typename T>
class VAddBiasKernel : public Kernel {
 public:
  virtual void Compute(const T a, const T *x, T *y) const = 0;
};

template <typename T>
class VActKernel : public Kernel {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VReluKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VIdentityKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VExpKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VSigmoidKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class VTanhKernel : public VActKernel<T> {
 public:
  virtual void Compute(const T *x, T *y) const = 0;
};

template <typename T>
class LSTMKernel : public Kernel {
 public:
  virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht,
                           /* below only used in peephole */
                           const T *wp_data = nullptr,
                           T *checked = nullptr) const = 0;

  // compute c1 and h1 without c0 or h0
  virtual void ComputeC1H1(T *gates, T *ct, T *ht,
                           /* below only used in peephole */
                           const T *wp_data = nullptr) const = 0;
};

}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle
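
jit_kernel.h declares the abstract kernel family but no concrete implementations; those live in the suppressed diffs below. As a rough sketch of what one concrete kernel could look like under this interface (the class name and scalar loop are illustrative assumptions, not code from this commit):

#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jk = paddle::operators::math::jitkernel;

// Hypothetical scalar fallback; a real implementation would presumably pick
// a code path from the jit_block sizes (kLT8, kEQ8, ...) and use AVX where
// available, per the AVX_FLOAT_BLOCK constants above.
template <typename T>
class NaiveVMulKernel : public jk::VMulKernel<T> {
 public:
  explicit NaiveVMulKernel(int n) { this->num_ = n; }
  // z[i] = x[i] * y[i] over all num_ elements.
  void Compute(const T *x, const T *y, T *z) const override {
    for (int i = 0; i < this->num_; ++i) {
      z[i] = x[i] * y[i];
    }
  }
};
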
File diff suppressed because it is too large
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff