refine refer code and add lstm refer code

test=develop
local_add_cudnn_lstm
tensor-tang 6 years ago
parent c2cfb03a72
commit ce31deb7e9

@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
@ -31,49 +32,6 @@ namespace math {
namespace jitkernel {
namespace jit = platform::jit;
// Reference element-wise product: z[k] = x[k] * y[k] for every k in [0, n).
// Elements are processed in ascending order; nothing is written when n <= 0.
template <typename T>
void VMulRefer(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    const T product = x[k] * y[k];
    z[k] = product;
    ++k;
  }
}
// Reference element-wise sum: z[k] = x[k] + y[k] for every k in [0, n).
// Elements are processed in ascending order; nothing is written when n <= 0.
template <typename T>
void VAddRefer(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    const T sum = x[k] + y[k];
    z[k] = sum;
    ++k;
  }
}
// Reference fused add + ReLU: z[k] = max(x[k] + y[k], 0) for k in [0, n).
// z[k] is first assigned the raw sum and then overwritten with the rectified
// value, so z may refer to the same buffer as x or y.
template <typename T>
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
  for (int k = 0; k != n && k < n; ++k) {
    T& out = z[k];
    out = x[k] + y[k];
    if (out > 0) {
      // Positive sums pass through unchanged.
    } else {
      out = 0;
    }
  }
}
// Reference scaling: y[k] = a[0] * x[k] for every k in [0, n).
// The scalar is re-read from a[0] on every iteration rather than cached.
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    y[k] = (*a) * x[k];
    ++k;
  }
}
// Reference bias add: y[k] = a[0] + x[k] for every k in [0, n).
// The bias is re-read from a[0] on every iteration rather than cached.
template <typename T>
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    y[k] = (*a) + x[k];
    ++k;
  }
}
// Reference ReLU: y[k] = x[k] when x[k] > 0, otherwise 0, for k in [0, n).
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    if (x[k] > 0) {
      y[k] = x[k];
    } else {
      y[k] = 0;
    }
    ++k;
  }
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
VScalRefer<float>(a, x, y, n);
refer::VScal<float>(a, x, y, n);
}
}
@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
VScalRefer<double>(a, x, y, n);
refer::VScal<double>(a, x, y, n);
}
}
@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
return;
}
#endif
this->Compute = VMulRefer<T>;
this->Compute = refer::VMul<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
return;
}
#endif
this->Compute = VAddRefer<T>;
this->Compute = refer::VAdd<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -242,7 +200,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
return;
}
#endif
this->Compute = VAddReluRefer<T>;
this->Compute = refer::VAddRelu<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -280,7 +238,7 @@ class VScalKernelImpl : public VScalKernel<T> {
return;
}
#endif
this->Compute = VScalRefer<T>;
this->Compute = refer::VScal<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -324,7 +282,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
}
#endif
this->Compute = VAddBiasRefer<T>;
this->Compute = refer::VAddBias<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -358,7 +316,7 @@ class VReluKernelImpl : public VReluKernel<T> {
}
#endif
this->Compute = VReluRefer<T>;
this->Compute = refer::VRelu<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -374,16 +332,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
}
#endif
// No-op reference kernel: neither reads x nor writes y, whatever n is.
template <typename T>
inline void VIdentityRefer(const T* x, T* y, int n) {
  // Intentionally empty; the casts only mark the parameters as used.
  (void)x;
  (void)y;
  (void)n;
}
/* An empty JitKernel */
// Kernel implementation whose constructor only selects the Compute routine.
// NOTE(review): Compute is assigned twice below, so only the second assignment
// (refer::VIdentity<T>) is in effect after construction; the pair looks like
// the before/after lines of a change -- confirm that the first is stale.
template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
// The size argument d is accepted but not used inside this constructor.
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = VIdentityRefer<T>;
this->Compute = refer::VIdentity<T>;
}
};

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
@ -35,38 +35,6 @@ namespace math {
namespace jitkernel {
namespace jit = platform::jit;
// TODO(TJ): move refer codes to one file
// Reference code focuses only on correctness
// Reference exponential: y[k] = std::exp(x[k]) for every k in [0, n).
template <typename T>
void VExpRefer(const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    const T value = x[k];
    y[k] = std::exp(value);
    ++k;
  }
}
// Reference sigmoid: y[k] = 1 / (1 + e^-x[k]) for every k in [0, n).
// Each input is first limited to the range
// [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] before exponentiation.
template <typename T>
void VSigmoidRefer(const T* x, T* y, int n) {
  const T lower = SIGMOID_THRESHOLD_MIN;
  const T upper = SIGMOID_THRESHOLD_MAX;
  const T one = static_cast<T>(1);
  for (int k = 0; k < n; ++k) {
    T clipped = x[k];
    if (clipped < lower) {
      clipped = lower;
    } else if (clipped > upper) {
      clipped = upper;
    }
    y[k] = one / (one + std::exp(-clipped));
  }
}
// Reference tanh, built on the identity y = 2 * sigmoid(2x) - 1.
// y is used as the working buffer for all three stages.
template <typename T>
void VTanhRefer(const T* x, T* y, int n) {
  const T two = static_cast<T>(2);
  const T one = static_cast<T>(1);
  // Stage 1: y <- 2x
  int k = 0;
  while (k < n) {
    y[k] = two * x[k];
    ++k;
  }
  // Stage 2: y <- sigmoid(y), computed in place
  VSigmoidRefer(y, y, n);
  // Stage 3: y <- 2y - 1
  k = 0;
  while (k < n) {
    y[k] = two * y[k] - one;
    ++k;
  }
}
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template <typename T>
@ -129,7 +97,7 @@ class VExpKernelImpl : public VExpKernel<T> {
return;
}
#endif
this->Compute = VExpRefer<T>;
this->Compute = refer::VExp<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -182,7 +150,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
return;
}
#endif
this->Compute = VSigmoidRefer<T>;
this->Compute = refer::VSigmoid<T>;
}
#ifdef PADDLE_WITH_XBYAK
@ -234,7 +202,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
return;
}
#endif
this->Compute = VTanhRefer<T>;
this->Compute = refer::VTanh<T>;
}
#ifdef PADDLE_WITH_XBYAK

@ -38,9 +38,13 @@ typedef struct {
void* checked{nullptr};
} lstm_t;
// Attributes of an LSTM reference step.
//   d        - length used for each gate slice by the refer LSTM routines.
//   act_gate - activation name applied to the input/forget/output gates.
//   act_cand - activation name applied to the candidate values.
//   act_cell - activation name applied to the cell state.
// Only one struct header is kept (the anonymous `typedef struct {` line was a
// stale duplicate that left the definition unterminated), and d now has a
// defined value (0) when the object is default-constructed.
typedef struct lstm_attr_s {
  int d{0};
  std::string act_gate, act_cand, act_cell;
  lstm_attr_s() = default;
  lstm_attr_s(int _d, const std::string& _act_gate,
              const std::string& _act_cand, const std::string& _act_cell)
      : d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {}
} lstm_attr_t;
} // namespace jitkernel

@ -0,0 +1,171 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace refer {
/* Reference code focuses only on correctness */
// Element-wise product: z[k] = x[k] * y[k] for every k in [0, n).
// Elements are processed in ascending order; nothing is written when n <= 0.
template <typename T>
void VMul(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    const T product = x[k] * y[k];
    z[k] = product;
    ++k;
  }
}
// Element-wise sum: z[k] = x[k] + y[k] for every k in [0, n).
// Elements are processed in ascending order; nothing is written when n <= 0.
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    const T sum = x[k] + y[k];
    z[k] = sum;
    ++k;
  }
}
// Fused add + ReLU: z[k] = max(x[k] + y[k], 0) for every k in [0, n).
// z[k] is first assigned the raw sum and then overwritten with the rectified
// value, so z may refer to the same buffer as x or y.
template <typename T>
void VAddRelu(const T* x, const T* y, T* z, int n) {
  int k = 0;
  while (k < n) {
    T& out = z[k];
    out = x[k] + y[k];
    if (!(out > 0)) {
      out = 0;
    }
    ++k;
  }
}
// Scaling: y[k] = a[0] * x[k] for every k in [0, n).
// The scalar is re-read from a[0] on every iteration rather than cached.
template <typename T>
void VScal(const T* a, const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    y[k] = (*a) * x[k];
    ++k;
  }
}
// Bias add: y[k] = a[0] + x[k] for every k in [0, n).
// The bias is re-read from a[0] on every iteration rather than cached.
template <typename T>
void VAddBias(const T* a, const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    y[k] = (*a) + x[k];
    ++k;
  }
}
// ReLU: y[k] = x[k] when x[k] > 0, otherwise 0, for every k in [0, n).
template <typename T>
void VRelu(const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    if (x[k] > 0) {
      y[k] = x[k];
    } else {
      y[k] = 0;
    }
    ++k;
  }
}
// Identity activation: y[k] = x[k] for every k in [0, n).
//
// The previous empty body was only correct for in-place calls. getActFunc
// hands this function out as a regular activation, and the LSTM reference
// steps call act_cell(ct, gates + d2, d) with distinct input and output
// buffers, so the output must actually be written.
//
// x: input buffer of at least n elements.
// y: output buffer of at least n elements; may be the same pointer as x.
// n: number of elements; nothing is written when n <= 0.
template <typename T>
inline void VIdentity(const T* x, T* y, int n) {
  if (x == y) {
    // In-place call: already the identity, keep it a no-op.
    return;
  }
  for (int i = 0; i < n; ++i) {
    y[i] = x[i];
  }
}
// Exponential: y[k] = std::exp(x[k]) for every k in [0, n).
template <typename T>
void VExp(const T* x, T* y, int n) {
  int k = 0;
  while (k < n) {
    const T value = x[k];
    y[k] = std::exp(value);
    ++k;
  }
}
// Sigmoid: y[k] = 1 / (1 + e^-x[k]) for every k in [0, n).
// Each input is first limited to the range
// [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] before exponentiation.
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
  const T lower = SIGMOID_THRESHOLD_MIN;
  const T upper = SIGMOID_THRESHOLD_MAX;
  const T one = static_cast<T>(1);
  for (int k = 0; k < n; ++k) {
    T clipped = x[k];
    if (clipped < lower) {
      clipped = lower;
    } else if (clipped > upper) {
      clipped = upper;
    }
    y[k] = one / (one + std::exp(-clipped));
  }
}
// Tanh, built on the identity y = 2 * sigmoid(2x) - 1.
// y is used as the working buffer for all three stages.
template <typename T>
void VTanh(const T* x, T* y, int n) {
  const T two = static_cast<T>(2);
  const T one = static_cast<T>(1);
  // Stage 1: y <- 2x
  int k = 0;
  while (k < n) {
    y[k] = two * x[k];
    ++k;
  }
  // Stage 2: y <- sigmoid(y), computed in place
  VSigmoid(y, y, n);
  // Stage 3: y <- 2y - 1
  k = 0;
  while (k < n) {
    y[k] = two * y[k] - one;
    ++k;
  }
}
// Resolves an activation name to the matching reference routine.
//   "sigmoid"        -> VSigmoid<T>
//   "relu"           -> VRelu<T>
//   "tanh"           -> VTanh<T>
//   "identity" or "" -> VIdentity<T>
// Any other name is reported through PADDLE_THROW.
template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) {  // NOLINT
  if (type == "sigmoid") {
    return VSigmoid<T>;
  }
  if (type == "relu") {
    return VRelu<T>;
  }
  if (type == "tanh") {
    return VTanh<T>;
  }
  if (type.empty() || type == "identity") {
    return VIdentity<T>;
  }
  PADDLE_THROW("Not support type: %s", type);
  return nullptr;
}
// Reference LSTM step that consumes the previous cell state:
//   C_t = C_t-1 * fgated + cand_gated * igated
//   H_t = act_cell(C_t) * ogated
// step->gates is treated as four consecutive slices of length attr->d in the
// order W_ch, W_ih, W_fh, W_oh and is overwritten as scratch space.
template <typename T>
void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
  T* cand = reinterpret_cast<T*>(step->gates);
  const T* prev_cell = reinterpret_cast<const T*>(step->ct_1);
  T* cell = reinterpret_cast<T*>(step->ct);
  T* hidden = reinterpret_cast<T*>(step->ht);
  auto gate_fn = getActFunc<T>(attr->act_gate);
  auto cand_fn = getActFunc<T>(attr->act_cand);
  auto cell_fn = getActFunc<T>(attr->act_cell);
  const int len = attr->d;
  // Named views of the slices that follow the candidate slice.
  T* in_gate = cand + len;
  T* forget_gate = cand + len * 2;
  T* out_gate = cand + len * 3;

  // Activate the input, forget and output gates in one pass.
  gate_fn(in_gate, in_gate, len * 3);

  // Cell state: candidate * input gate + previous cell * forget gate.
  cand_fn(cand, cand, len);
  VMul(cand, in_gate, in_gate, len);
  VMul(prev_cell, forget_gate, forget_gate, len);
  VAdd(in_gate, forget_gate, cell, len);

  // Hidden state: the forget slice is reused to hold act_cell(C_t).
  cell_fn(cell, forget_gate, len);
  VMul(forget_gate, out_gate, hidden, len);
}
// Reference LSTM step that does not read a previous cell state:
//   C_t = igated * cgated
//   H_t = act_cell(C_t) * ogated
// step->gates is addressed as four consecutive slices of length attr->d
// (offsets 0, d, 2d, 3d) and is overwritten as scratch space; the slice at
// offset 2d is not activated and only serves as a buffer for act_cell(C_t).
//
// step: supplies the gates buffer and receives the results in ct and ht.
// attr: supplies d and the activation names resolved through getActFunc.
template <typename T>
void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
  T* gates = reinterpret_cast<T*>(step->gates);
  T* ct = reinterpret_cast<T*>(step->ct);
  T* ht = reinterpret_cast<T*>(step->ht);
  auto act_gate = getActFunc<T>(attr->act_gate);
  auto act_cand = getActFunc<T>(attr->act_cand);
  auto act_cell = getActFunc<T>(attr->act_cell);
  int d = attr->d;
  int d2 = d * 2;
  int d3 = d * 3;
  /* C_t = igated * cgated */
  act_gate(gates + d, gates + d, d);
  act_cand(gates, gates, d);
  VMul(gates, gates + d, ct, d);
  /* H_t = act_cell(C_t) * ogated */
  act_gate(gates + d3, gates + d3, d);
  act_cell(ct, gates + d2, d);
  // Fixed: the call was spelled `Vmul`, which names no function and fails to
  // compile as soon as this template is instantiated.
  VMul(gates + d2, gates + d3, ht, d);
}
} // namespace refer
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save