Merge pull request #4929 from qingqing01/lstm
Forward implementation for LSTM operator.
commit
3f1062d711
@@ -0,0 +1,226 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/lstm_op.h"

namespace paddle {
namespace operators {

class LSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                   "Output(Hidden) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                   "Output(Cell) of LSTM should not be null.");

    auto x_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Input)'s rank must be 2.");

    if (ctx->HasInput("H0")) {
      PADDLE_ENFORCE(ctx->HasInput("C0"),
                     "Input(H0) and Input(C0) of LSTM should not "
                     "be null at the same time.");
      auto h_dims = ctx->GetInputDim("H0");
      auto c_dims = ctx->GetInputDim("C0");
      PADDLE_ENFORCE(h_dims == c_dims,
                     "The dimension of Input(H0) and Input(C0) "
                     "should be the same.");
    }

    int frame_size = x_dims[1] / 4;
    auto w_dims = ctx->GetInputDim("Weight");
    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
                      "The rank of Input(Weight) should be 2.");
    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
                      "The first dimension of Input(Weight) "
                      "should be %d.",
                      frame_size);
    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
                      "The second dimension of Input(Weight) "
                      "should be 4 * %d.",
                      frame_size);
    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
                      "The first dimension of Input(Bias) should be 1.");
    if (ctx->Attrs().Get<bool>("usePeepholes")) {
      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
                        "The second dimension of Input(Bias) should be "
                        "7 * %d if the peephole connections are enabled.",
                        frame_size);
    } else {
      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
                        "The second dimension of Input(Bias) should be "
                        "4 * %d if the peephole connections are disabled.",
                        frame_size);
    }
    ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
    ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
    ctx->SetOutputDim("BatchGate", x_dims);
    ctx->ShareLoD("Input", "Hidden");
    ctx->ShareLoD("Input", "Cell");
  }
};

class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(LoDTensor) the first input is an LoDTensor, which supports "
             "variable-length input sequences. The underlying tensor in "
             "this LoDTensor is a matrix with shape (T x 4D), where T is the "
             "total number of time steps in this mini-batch and D is the "
             "hidden size.");
    AddInput("H0",
             "(Tensor, optional) the initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size and D is the hidden size.");
    AddInput("C0",
             "(Tensor, optional) the initial cell state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size and D is the hidden size. `H0` and `C0` can be "
             "NULL, but only at the same time.");
    AddInput("Weight",
             "(Tensor) the learnable hidden-hidden weights."
             " - The shape is (D x 4D), where D is the hidden size. "
             " - Weight = {W_ch, W_ih, W_fh, W_oh}");
    AddInput("Bias",
             "(Tensor) the learnable bias, which contains two parts: the "
             "input-hidden bias and the peephole connection weights when "
             "`usePeepholes` is set to True. "
             "1. `usePeepholes = False` "
             " - The shape is (1 x 4D). "
             " - Bias = {b_c, b_i, b_f, b_o}."
             "2. `usePeepholes = True` "
             " - The shape is (1 x 7D). "
             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
    AddOutput("BatchGate",
              "(LoDTensor) This LoDTensor contains the input gate, forget "
              "gate and output gate after the nonlinear computation. It has "
              "the same shape as the reorganized input, which is also called "
              "the batch input. The LoD has two levels: the first level "
              "holds the batch offsets and the second level holds the "
              "indexes that map the reorganized rows back to positions in "
              "the raw input.")
        .AsIntermediate();
    AddOutput("Hidden",
              "(LoDTensor) the hidden state LoDTensor of the LSTM operator. "
              "The shape is (T x D), and the LoD is the same as that of "
              "`Input`.");
    AddOutput("Cell",
              "(LoDTensor) the cell state LoDTensor of the LSTM operator. "
              "The shape is (T x D), and the LoD is the same as that of "
              "`Input`.");
    AddAttr<bool>("usePeepholes",
                  "(bool, default: True) "
                  "whether to enable diagonal/peephole connections.")
        .SetDefault(true);
    AddAttr<bool>("isReverse",
                  "(bool, default: False) "
                  "whether to compute the reversed LSTM.")
        .SetDefault(false);
    AddAttr<std::string>(
        "gateActivation",
        "(string, default: sigmoid) "
        "The activation for the input gate, forget gate and output "
        "gate; `sigmoid` by default.")
        .SetDefault("sigmoid");
    AddAttr<std::string>("cellActivation",
                         "(string, default: tanh) "
                         "The activation for the cell output; `tanh` by "
                         "default.")
        .SetDefault("tanh");
    AddAttr<std::string>("candidateActivation",
                         "(string, default: tanh) "
                         "The activation for the candidate hidden state; "
                         "`tanh` by default.")
        .SetDefault("tanh");
    AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator

The default implementation uses diagonal/peephole connections [1]; the
formulas are as follows:

i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)

f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)

\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c)

o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)

c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t}

h_t = o_t ⊙ act_h(c_t)

where the W terms denote weight matrices (e.g. \f$W_{ix}\f$ is the matrix
of weights from the input to the input gate), and \f$W_{ic}, W_{fc}, W_{oc}\f$
are diagonal weight matrices for the peephole connections. In our
implementation, we use vectors to represent these diagonal weight matrices.
The b terms denote bias vectors (\f$b_i\f$ is the input gate bias vector),
\f$\sigma\f$ is the non-linear activation, such as the logistic sigmoid
function, and \f$i, f, o\f$ and \f$c\f$ are respectively the input gate,
forget gate, output gate and cell activation vectors, all of which have the
same size as the cell output activation vector \f$h\f$.

Here ⊙ is the element-wise product of vectors, and \f$act_g\f$ and \f$act_h\f$
are the cell input and cell output activation functions, for which `tanh` is
usually used. \f$\tilde{c_t}\f$ is also called the candidate hidden state,
computed from the current input and the previous hidden state.

Set `usePeepholes` to False to disable the peephole connections [2]. The
corresponding formulas are omitted here.

@note The \f$W_{ix}x_{t}, W_{fx}x_{t}, W_{cx}x_{t}, W_{ox}x_{t}\f$
projections of the input \f$x_{t}\f$ are NOT included in this operator.
Users can apply a fully-connected operator before the LSTM operator.

[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
INTERSPEECH, 2014.

[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory.
Neural Computation, 9(8):1735-1780, 1997.

)DOC");
  }
};

class LSTMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                   "Input(Hidden@GRAD) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
                   "Input(Cell@GRAD) should not be null");
    ctx->SetOutputDim(framework::GradVarName("Weight"),
                      ctx->GetInputDim("Weight"));
    ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::CPUPlace, float>,
                       ops::LSTMKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(lstm_grad,
                       ops::LSTMGradKernel<paddle::platform::CPUPlace, float>,
                       ops::LSTMGradKernel<paddle::platform::CPUPlace, double>);
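As a reading aid for the formulas in the operator comment above, here is a minimal, self-contained scalar sketch of one peephole LSTM step. It is illustrative only and is not part of the operator: the struct and member names are invented, every quantity is a scalar instead of a D-sized vector, sigmoid/tanh stand in for the configurable activations, and (unlike the operator, which expects the input projections to be done by a fully-connected operator beforehand) it also applies the W_*x terms.

#include <cmath>

// Hypothetical scalar reference for one LSTM time step with peepholes,
// mirroring the equations in the operator comment (illustration only).
struct ScalarLstmStep {
  double w_ix, w_ih, w_ic, b_i;  // input gate
  double w_fx, w_fh, w_fc, b_f;  // forget gate
  double w_cx, w_ch, b_c;        // candidate cell state (act_g = tanh here)
  double w_ox, w_oh, w_oc, b_o;  // output gate

  static double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

  // Consumes x_t, h_{t-1} and c_{t-1} (via *c); returns h_t and updates *c to c_t.
  double Step(double x, double h_prev, double* c) const {
    double c_prev = *c;
    double i = Sigmoid(w_ix * x + w_ih * h_prev + w_ic * c_prev + b_i);
    double f = Sigmoid(w_fx * x + w_fh * h_prev + w_fc * c_prev + b_f);
    double c_tilde = std::tanh(w_cx * x + w_ch * h_prev + b_c);
    *c = f * c_prev + i * c_tilde;  // c_t = f_t * c_{t-1} + i_t * c~_t
    double o = Sigmoid(w_ox * x + w_oh * h_prev + w_oc * (*c) + b_o);
    return o * std::tanh(*c);       // h_t = o_t * act_h(c_t)
  }
};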
@@ -0,0 +1,23 @@
#define EIGEN_USE_GPU
#include "paddle/operators/lstm_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel<paddle::platform::GPUPlace, float>,
                       ops::LSTMKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(lstm_grad,
                       ops::LSTMGradKernel<paddle::platform::GPUPlace, float>,
                       ops::LSTMGradKernel<paddle::platform::GPUPlace, double>);
@@ -0,0 +1,139 @@
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {

using framework::LoDTensor;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::LoDTensor>("Input");
    auto* weight = ctx.Input<framework::Tensor>("Weight");
    auto* bias = ctx.Input<framework::Tensor>("Bias");

    auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(ctx.GetPlace());
    auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden");
    hidden_out->mutable_data<T>(ctx.GetPlace());
    auto* cell_out = ctx.Output<framework::LoDTensor>("Cell");
    cell_out->mutable_data<T>(ctx.GetPlace());

    // The ShareLoD function in InferShape is not implemented yet,
    // so the LoD is copied here.
    ctx.ShareLoD("Input", "Hidden");
    ctx.ShareLoD("Input", "Cell");

    bool is_reverse = ctx.Attr<bool>("isReverse");
    math::LoDTensor2BatchFunctor<Place, T> to_batch;
    to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);

    auto in_dims = input->dims();
    int frame_size = static_cast<int>(in_dims[1] / 4);
    framework::DDim dims({in_dims[0], frame_size});

    if (bias) {
      Eigen::array<int, 2> extents({{1, 4 * frame_size}});
      Eigen::array<int, 2> offsets({{0, 0}});
      auto b = EigenMatrix<T>::From(*bias);
      auto gate = EigenMatrix<T>::From(*batch_gate);
      gate.device(ctx.GetEigenDevice<Place>()) =
          gate +
          b.slice(offsets, extents)
              .reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
              .broadcast(
                  Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
    }

    math::LstmMetaValue<T> lstm_value;
    T* bias_data = const_cast<T*>(bias->data<T>());
    // the code style in LstmMetaValue will be updated later.
    lstm_value.checkIg = bias_data + 4 * frame_size;
    lstm_value.checkFg = lstm_value.checkIg + frame_size;
    lstm_value.checkOg = lstm_value.checkFg + frame_size;
    lstm_value.prevStateValue = nullptr;

    framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
    batch_out.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
    batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());

    auto batch_starts = batch_gate->lod()[0];
    size_t num_batch = batch_starts.size() - 1;
    auto gate_act = ctx.Attr<std::string>("gateActivation");
    auto cell_act = ctx.Attr<std::string>("cellActivation");
    auto cand_act = ctx.Attr<std::string>("candidateActivation");

    for (size_t n = 0; n < num_batch; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate_t = batch_gate->Slice(bstart, bend);
      Tensor out_t = batch_out.Slice(bstart, bend);
      Tensor cell_t = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);

      int cur_batch_size = bend - bstart;

      if (n != 0) {
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end);
        math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
                               *weight, false, static_cast<T>(1.0), &gate_t,
                               static_cast<T>(1.0));
      }
      // else: FIXME support the initial hidden state and cell state.

      lstm_value.gateValue = gate_t.data<T>();
      lstm_value.outputValue = out_t.data<T>();
      lstm_value.stateValue = cell_t.data<T>();
      lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
      math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
                                               frame_size, cur_batch_size,
                                               gate_act, cell_act, cand_act);
      lstm_value.prevStateValue = lstm_value.stateValue;
    }

    math::Batch2LoDTensorFunctor<Place, T> to_seq;
    batch_out.set_lod(batch_gate->lod());
    // restore the output hidden in LoDTensor from the batch hidden
    to_seq(ctx.device_context(), batch_out, *hidden_out);

    batch_cell.set_lod(batch_gate->lod());
    // restore the output cell state in LoDTensor from the batch cell
    to_seq(ctx.device_context(), batch_cell, *cell_out);
  }
};

template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {}
};

}  // namespace operators
}  // namespace paddle
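A rough illustration of the Bias layout the kernel above assumes when `usePeepholes` is true: a (1 x 7D) tensor holding the 4D gate biases followed by the three D-sized peephole vectors wired into checkIg/checkFg/checkOg. The helper type and its member names are hypothetical and exist only for this sketch:

// Hypothetical view of the (1 x 7D) bias described in LSTMOpMaker:
//   [ b_c b_i b_f b_o | W_ic | W_fc | W_oc ]
// The first 4 * D entries are broadcast-added to every row of BatchGate;
// the remaining 3 * D entries become the peephole pointers of LstmMetaValue.
template <typename T>
struct PeepholeBiasView {
  const T* data;  // pointer to the 7 * D bias values
  int d;          // frame (hidden) size D

  const T* gate_bias() const { return data; }         // 4 * D gate biases
  const T* check_ig() const { return data + 4 * d; }  // -> lstm_value.checkIg
  const T* check_fg() const { return data + 5 * d; }  // -> lstm_value.checkFg
  const T* check_og() const { return data + 6 * d; }  // -> lstm_value.checkOg
};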
@@ -0,0 +1,5 @@
if(WITH_AVX)
  cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc)
else()
  cc_library(activation_functions SRCS hl_cpu_functions.cc)
endif()
@@ -0,0 +1,188 @@
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_

#include "hl_functions.h"
#include "paddle/operators/math/lstm_compute.h"

/**
 * Activation functions: sigmoid, relu, tanh and linear.
 */
#define FLOAT_ACTIVE_FUNCTION                                   \
  {                                                             \
    hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \
        hppl::typef::linear                                     \
  }

#define DOUBLE_ACTIVE_FUNCTION                                  \
  {                                                             \
    hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \
        hppl::typed::linear                                     \
  }

#define AVX_ACTIVE_FUNCTION \
  { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }

namespace hppl {

using activation_mode_t = paddle::operators::math::activation_mode_t;

/**
 * hppl supports the sigmoid, relu, tanh and linear activation functions
 * for the forward and backward passes of neural networks.
 */
template <class T>
class Active {
 public:
  typedef T (*forward)(T);
  typedef T (*backward)(T, T);
};

template <typename T>
struct ForwardActType;

template <>
struct ForwardActType<float> {
  using type = Active<float>::forward;
};

template <>
struct ForwardActType<double> {
  using type = Active<double>::forward;
};

template <typename T>
struct BackwardActType;

template <>
struct BackwardActType<float> {
  using type = Active<float>::backward;
};

template <>
struct BackwardActType<double> {
  using type = Active<double>::backward;
};

#ifdef __NVCC__
namespace gpu {
static __device__ Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;

static __device__ Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static __device__ Active<double>::backward backward_d[] =
    DOUBLE_ACTIVE_FUNCTION;

template <typename T>
struct ForwardAct {
  __device__ typename ForwardActType<T>::type operator()(
      activation_mode_t type);
};

template <>
struct ForwardAct<float> {
  __device__ ForwardActType<float>::type operator()(activation_mode_t type) {
    return forward[type];
  }
};

template <>
struct ForwardAct<double> {
  __device__ ForwardActType<double>::type operator()(activation_mode_t type) {
    return forward_d[type];
  }
};

template <typename T>
struct BackwardAct {
  __device__ typename BackwardActType<T>::type operator()(
      activation_mode_t type);
};

template <>
struct BackwardAct<float> {
  __device__ BackwardActType<float>::type operator()(activation_mode_t type) {
    return backward[type];
  }
};

template <>
struct BackwardAct<double> {
  __device__ BackwardActType<double>::type operator()(activation_mode_t type) {
    return backward_d[type];
  }
};

}  // namespace gpu
#else
namespace cpu {
static Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;

static Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static Active<double>::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION;

template <typename T>
struct ForwardAct {
  typename ForwardActType<T>::type operator()(activation_mode_t type);
};

template <>
struct ForwardAct<float> {
  ForwardActType<float>::type operator()(activation_mode_t type) {
    return forward[type];
  }
};

template <>
struct ForwardAct<double> {
  ForwardActType<double>::type operator()(activation_mode_t type) {
    return forward_d[type];
  }
};

template <typename T>
struct BackwardAct {
  typename BackwardActType<T>::type operator()(activation_mode_t type);
};

template <>
struct BackwardAct<float> {
  BackwardActType<float>::type operator()(activation_mode_t type) {
    return backward[type];
  }
};

template <>
struct BackwardAct<double> {
  BackwardActType<double>::type operator()(activation_mode_t type) {
    return backward_d[type];
  }
};

}  // namespace cpu

#ifdef __AVX__
namespace avx {
static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION;
}  // namespace avx
#endif
#endif

}  // namespace hppl

#endif  // HL_ACTIVATION_FUNCTIONS_H_
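To make the function-pointer tables above concrete, here is a hedged usage sketch for the CPU branch (i.e. a build without __NVCC__). Only the hppl types and tables come from the header; the two free functions are invented for this illustration:

#include "paddle/operators/math/detail/hl_activation_functions.h"

// Illustration only: look up an activation by mode and apply it.
float ApplyForward(hppl::activation_mode_t mode, float x) {
  hppl::ForwardActType<float>::type act = hppl::cpu::ForwardAct<float>()(mode);
  return act(x);  // e.g. sigmoid(x) when mode == HL_ACTIVATION_SIGMOID
}

float ApplyBackward(hppl::activation_mode_t mode, float grad, float out) {
  // Backward functions take (incoming gradient, forward output).
  hppl::BackwardActType<float>::type act =
      hppl::cpu::BackwardAct<float>()(mode);
  return act(grad, out);
}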
@@ -0,0 +1,70 @@
#include <immintrin.h>
#include "hl_functions.h"
// TODO(qingqing) refine this dependence
#include "paddle/cuda/src/avx_mathfun.h"

namespace hppl {

__m256 exp(__m256 a) { return exp256_ps(a); }

__m256 relu(const __m256 a) {
  __m256 tmp = _mm256_set1_ps(0.0f);
  return _mm256_max_ps(a, tmp);
}

__m256 sigmoid(const __m256 a) {
  __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
  __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
  __m256 tmp = _mm256_max_ps(a, min);
  tmp = _mm256_min_ps(tmp, max);
  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
  tmp = exp(tmp);
  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
  return tmp;
}

__m256 tanh(const __m256 a) {
  __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
  __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
  tmp = _mm256_min_ps(tmp, max);
  tmp = exp(tmp);
  return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
                                     _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
                       _mm256_set1_ps(1.0f));
}

__m256 linear(const __m256 a) { return a; }

__m256 relu(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(
      a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
                       _mm256_set1_ps(1.0f)));
}

__m256 sigmoid(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(_mm256_mul_ps(a, b),
                       _mm256_sub_ps(_mm256_set1_ps(1.0f), b));
}

__m256 tanh(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(
      a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
}

__m256 linear(const __m256 a, const __m256 b) { return a; }
}  // namespace hppl
@@ -0,0 +1,32 @@
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_

#include <immintrin.h>

namespace hppl {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);

__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
}  // namespace hppl

#endif  // HL_AVX_FUNCTIONS_H_
@@ -0,0 +1,89 @@
#include <math.h>
#include "hl_functions.h"

namespace hppl {
namespace typef {

float relu(const float a) {
  return a > static_cast<float>(0.0) ? a : static_cast<float>(0.0);
}

float sigmoid(const float a) {
  const float min = SIGMOID_THRESHOLD_MIN;
  const float max = SIGMOID_THRESHOLD_MAX;
  float tmp = (a < min) ? min : ((a > max) ? max : a);
  return static_cast<float>(1.0) / (static_cast<float>(1.0) + exp(-tmp));
}

float tanh(const float a) {
  float tmp = -2.0 * a;
  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
  return (2.0 / (1.0 + exp(tmp))) - 1.0;
}

float linear(const float a) { return a; }

float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); }

float sigmoid(const float a, const float b) {
  return a * b * (static_cast<float>(1) - b);
}

float tanh(const float a, const float b) {
  return a * (static_cast<float>(1) - b * b);
}

float linear(const float a, const float b) { return a; }

}  // namespace typef

namespace typed {
double relu(const double a) {
  return a > static_cast<double>(0.0) ? a : static_cast<double>(0.0);
}

double sigmoid(const double a) {
  const double min = SIGMOID_THRESHOLD_MIN;
  const double max = SIGMOID_THRESHOLD_MAX;
  double tmp = (a < min) ? min : ((a > max) ? max : a);
  return static_cast<double>(1.0) / (static_cast<double>(1.0) + exp(-tmp));
}

double tanh(const double a) {
  double tmp = -2.0 * a;
  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
  return (2.0 / (1.0 + exp(tmp))) - 1.0;
}

double linear(const double a) { return a; }

double relu(const double a, const double b) {
  return a * (b > 0.0 ? 1.0 : 0.0);
}

double sigmoid(const double a, const double b) {
  return a * b * (static_cast<double>(1) - b);
}

double tanh(const double a, const double b) {
  return a * (static_cast<double>(1) - b * b);
}

double linear(const double a, const double b) { return a; }

}  // namespace typed
}  // namespace hppl
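The two-argument overloads above follow an (incoming gradient, forward output) convention: they return the upstream gradient multiplied by the activation's derivative, expressed through the already-computed output (e.g. sigmoid(a, b) = a * b * (1 - b)). A small hedged sketch of the intended pairing; the function name is invented and the header is assumed to be on the include path:

#include "hl_functions.h"

// Illustration only: turn dL/dy into dL/dx for y = sigmoid(x), reusing y.
float SigmoidGradExample(float x, float upstream_grad) {
  float y = hppl::typef::sigmoid(x);              // forward pass
  return hppl::typef::sigmoid(upstream_grad, y);  // upstream_grad * y * (1 - y)
}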
@@ -0,0 +1,71 @@
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_

/**
 * sigmoid threshold minimum
 */
#define SIGMOID_THRESHOLD_MIN -40.0

/**
 * sigmoid threshold maximum
 */
#define SIGMOID_THRESHOLD_MAX 13.0

/**
 * The maximum input value for exp, used to avoid overflow problems.
 * Currently only used in the tanh function.
 */
#define EXP_MAX_INPUT 40.0

#ifndef __NVCC__
namespace hppl {
namespace typef {
float relu(const float a);
float sigmoid(const float a);
float tanh(const float a);
float linear(const float a);

float relu(const float a, const float b);
float sigmoid(const float a, const float b);
float tanh(const float a, const float b);
float linear(const float a, const float b);

}  // namespace typef

namespace typed {
double relu(const double a);
double sigmoid(const double a);
double tanh(const double a);
double linear(const double a);

double relu(const double a, const double b);
double sigmoid(const double a, const double b);
double tanh(const double a, const double b);
double linear(const double a, const double b);
}  // namespace typed

}  // namespace hppl

#ifdef __AVX__
#include "hl_avx_functions.h"
#endif

#else
#include "hl_gpu_functions.h"
#endif

#endif  // HL_FUNCTIONS_H_
@@ -0,0 +1,93 @@
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_

#include "hl_base.h"

namespace hppl {
namespace typef {

__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; }

__device__ static float sigmoid(const float a) {
  const float min = SIGMOID_THRESHOLD_MIN;
  const float max = SIGMOID_THRESHOLD_MAX;
  float tmp = (a < min) ? min : ((a > max) ? max : a);
  return __fdividef(1.0f, 1.0f + __expf(-tmp));
}

__device__ static float tanh(const float a) {
  float tmp = -2.0 * a;
  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
  return __fdividef(2.0f, (1.0f + __expf(tmp))) - 1.0f;
}

__device__ static float linear(const float a) { return a; }

__device__ static float relu(const float a, const float b) {
  return a * (b > 0.0f ? 1.0f : 0.0f);
}

__device__ static float sigmoid(const float a, const float b) {
  return a * b * (1.0f - b);
}

__device__ static float tanh(const float a, const float b) {
  return a * (1.0f - b * b);
}

__device__ static float linear(const float a, const float b) { return a; }

}  // namespace typef

namespace typed {

__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; }

__device__ static double sigmoid(const double a) {
  const double min = SIGMOID_THRESHOLD_MIN;
  const double max = SIGMOID_THRESHOLD_MAX;
  double tmp = (a < min) ? min : ((a > max) ? max : a);
  return 1.0 / (1.0 + exp(-tmp));
}

__device__ static double tanh(const double a) {
  double tmp = -2.0 * a;
  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
  return (2.0 / (1.0 + exp(tmp))) - 1.0;
}

__device__ static double linear(const double a) { return a; }

__device__ static double relu(const double a, const double b) {
  return a * (b > 0.0 ? 1.0 : 0.0);
}

__device__ static double sigmoid(const double a, const double b) {
  return a * b * (1 - b);
}

__device__ static double tanh(const double a, const double b) {
  return a * (1.0 - b * b);
}

__device__ static double linear(const double a, const double b) { return a; }

}  // namespace typed

}  // namespace hppl

#endif  // HL_GPU_FUNCTIONS_CUH_
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,138 @@
#include "paddle/operators/math/detail/hl_activation_functions.h"
|
||||
#include "paddle/platform/hostdevice.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
namespace detail {
|
||||
|
||||
namespace forward {
|
||||
|
||||
template <class T>
|
||||
class lstm {
|
||||
public:
|
||||
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
||||
T &prevState, T &state, T &stateAtv, T &output,
|
||||
T &checkI, T &checkF, T &checkO,
|
||||
typename hppl::ForwardActType<T>::type actInput,
|
||||
typename hppl::ForwardActType<T>::type actGate,
|
||||
typename hppl::ForwardActType<T>::type actState) {
|
||||
valueIn = actInput(valueIn);
|
||||
valueIg = actGate(valueIg + prevState * checkI);
|
||||
valueFg = actGate(valueFg + prevState * checkF);
|
||||
state = valueIn * valueIg + prevState * valueFg;
|
||||
valueOg = actGate(valueOg + state * checkO);
|
||||
stateAtv = actState(state);
|
||||
output = valueOg * stateAtv;
|
||||
}
|
||||
#ifndef __NVCC__
|
||||
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
|
||||
static const bool avx = false;
|
||||
#else
|
||||
// Only float support AVX optimization
|
||||
static const bool avx = std::is_same<T, float>::value;
|
||||
|
||||
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
||||
__m256 &valueOg, __m256 &prevState, __m256 &state,
|
||||
__m256 &stateAtv, __m256 &output, __m256 &checkI,
|
||||
__m256 &checkF, __m256 &checkO,
|
||||
hppl::Active<__m256>::forward actInput,
|
||||
hppl::Active<__m256>::forward actGate,
|
||||
hppl::Active<__m256>::forward actState) {
|
||||
valueIn = actInput(valueIn);
|
||||
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
|
||||
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
|
||||
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
|
||||
_mm256_mul_ps(prevState, valueFg));
|
||||
valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
|
||||
stateAtv = actState(state);
|
||||
output = _mm256_mul_ps(valueOg, stateAtv);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace forward
|
||||
|
||||
namespace backward {
|
||||
|
||||
template <class T>
|
||||
class lstm {
|
||||
public:
|
||||
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
||||
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
|
||||
T &prevState, T &prevStateGrad, T &state,
|
||||
T &stateGrad, T &stateAtv, T &outputGrad,
|
||||
T &checkI, T &checkF, T &checkO, T &checkIGrad,
|
||||
T &checkFGrad, T &checkOGrad,
|
||||
typename hppl::BackwardActType<T>::type actInput,
|
||||
typename hppl::BackwardActType<T>::type actGate,
|
||||
typename hppl::BackwardActType<T>::type actState) {
|
||||
gradOg = actGate(outputGrad * stateAtv, valueOg);
|
||||
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
|
||||
gradIn = actInput(stateGrad * valueIg, valueIn);
|
||||
gradIg = actGate(stateGrad * valueIn, valueIg);
|
||||
gradFg = actGate(stateGrad * prevState, valueFg);
|
||||
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
|
||||
checkIGrad = gradIg * prevState;
|
||||
checkFGrad = gradFg * prevState;
|
||||
checkOGrad = gradOg * state;
|
||||
}
|
||||
#ifndef __NVCC__
|
||||
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
|
||||
static const bool avx = false;
|
||||
#else
|
||||
// Only float support AVX optimization
|
||||
static const bool avx = std::is_same<T, float>::value;
|
||||
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
||||
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
|
||||
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
|
||||
__m256 &prevStateGrad, __m256 &state,
|
||||
__m256 &stateGrad, __m256 &stateAtv,
|
||||
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
|
||||
__m256 &checkO, __m256 &checkIGrad,
|
||||
__m256 &checkFGrad, __m256 &checkOGrad,
|
||||
hppl::Active<__m256>::backward actInput,
|
||||
hppl::Active<__m256>::backward actGate,
|
||||
hppl::Active<__m256>::backward actState) {
|
||||
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
|
||||
stateGrad = _mm256_add_ps(
|
||||
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
|
||||
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
|
||||
gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
|
||||
gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
|
||||
gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
|
||||
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
|
||||
_mm256_mul_ps(gradFg, checkF));
|
||||
prevStateGrad =
|
||||
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
|
||||
checkIGrad = _mm256_mul_ps(gradIg, prevState);
|
||||
checkFGrad = _mm256_mul_ps(gradFg, prevState);
|
||||
checkOGrad = _mm256_mul_ps(gradOg, state);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace backward
|
||||
|
||||
} // namespace detail
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@@ -0,0 +1,82 @@
#include "paddle/operators/math/lstm_compute.h"
|
||||
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
|
||||
#include "paddle/operators/math/detail/lstm_kernel.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <class T>
|
||||
struct LstmUnitFunctor<platform::CPUPlace, T> {
|
||||
static void compute(const platform::DeviceContext& context,
|
||||
LstmMetaValue<T> value, int frame_size, int batch_size,
|
||||
const std::string& gate_act, const std::string& cell_act,
|
||||
const std::string& cand_act) {
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
|
||||
ActiveType(cand_act), ActiveType(gate_act),
|
||||
ActiveType(cell_act));
|
||||
value.gateValue += frame_size * 4;
|
||||
value.stateValue += frame_size;
|
||||
value.stateActiveValue += frame_size;
|
||||
value.outputValue += frame_size;
|
||||
if (value.prevStateValue) {
|
||||
value.prevStateValue += frame_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct LstmUnitGradFunctor<platform::CPUPlace, T> {
|
||||
static void compute(const platform::DeviceContext& context,
|
||||
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
|
||||
int frame_size, int batch_size,
|
||||
const std::string& gate_act, const std::string& cell_act,
|
||||
const std::string& cand_act) {
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
|
||||
frame_size, ActiveType(cand_act),
|
||||
ActiveType(gate_act), ActiveType(cell_act));
|
||||
|
||||
value.gateValue += frame_size * 4;
|
||||
value.stateValue += frame_size;
|
||||
value.stateActiveValue += frame_size;
|
||||
value.outputValue += frame_size;
|
||||
if (value.prevStateValue) {
|
||||
value.prevStateValue += frame_size;
|
||||
}
|
||||
|
||||
grad.gateGrad += frame_size * 4;
|
||||
grad.stateGrad += frame_size;
|
||||
grad.stateActiveGrad += frame_size;
|
||||
grad.outputGrad += frame_size;
|
||||
if (grad.prevStateGrad) {
|
||||
grad.prevStateGrad += frame_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class LstmUnitFunctor<platform::CPUPlace, float>;
|
||||
template class LstmUnitFunctor<platform::CPUPlace, double>;
|
||||
template class LstmUnitGradFunctor<platform::CPUPlace, float>;
|
||||
template class LstmUnitGradFunctor<platform::CPUPlace, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@@ -0,0 +1,55 @@
#include "paddle/operators/math/detail/lstm_gpu_kernel.h"
|
||||
#include "paddle/operators/math/detail/lstm_kernel.h"
|
||||
#include "paddle/operators/math/lstm_compute.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <class T>
|
||||
struct LstmUnitFunctor<platform::GPUPlace, T> {
|
||||
static void compute(const platform::DeviceContext& context,
|
||||
LstmMetaValue<T> value, int frame_size, int batch_size,
|
||||
const std::string& gate_act, const std::string& cell_act,
|
||||
const std::string& cand_act) {
|
||||
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
|
||||
frame_size, batch_size, ActiveType(cand_act),
|
||||
ActiveType(gate_act), ActiveType(cell_act));
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct LstmUnitGradFunctor<platform::GPUPlace, T> {
|
||||
static void compute(const platform::DeviceContext& context,
|
||||
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
|
||||
int frame_size, int batch_size,
|
||||
const std::string& gate_act, const std::string& cell_act,
|
||||
const std::string& cand_act) {
|
||||
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
|
||||
frame_size, batch_size, ActiveType(cand_act),
|
||||
ActiveType(gate_act), ActiveType(cell_act));
|
||||
}
|
||||
};
|
||||
|
||||
template class LstmUnitFunctor<platform::GPUPlace, float>;
|
||||
template class LstmUnitFunctor<platform::GPUPlace, double>;
|
||||
template class LstmUnitGradFunctor<platform::GPUPlace, float>;
|
||||
template class LstmUnitGradFunctor<platform::GPUPlace, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@@ -0,0 +1,91 @@
#pragma once

#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace operators {
namespace math {

typedef enum {
  HL_ACTIVATION_SIGMOID = 0,
  HL_ACTIVATION_RELU = 1,
  HL_ACTIVATION_TANH = 2,
  HL_ACTIVATION_LINEAR = 3,
  HL_ACTIVATION_END
} activation_mode_t;

template <class T>
struct LstmMetaValue {
  T *gateValue;
  T *prevStateValue;
  T *stateValue;
  T *stateActiveValue;
  T *outputValue;
  T *checkIg;
  T *checkFg;
  T *checkOg;
};

template <class T>
struct LstmMetaGrad {
  T *gateGrad;
  T *prevStateGrad;
  T *stateGrad;
  T *stateActiveGrad;
  T *outputGrad;
  T *checkIgGrad;
  T *checkFgGrad;
  T *checkOgGrad;
};

inline activation_mode_t ActiveType(const std::string &type) {
  if (type == "sigmoid") {
    return HL_ACTIVATION_SIGMOID;
  } else if (type == "relu") {
    return HL_ACTIVATION_RELU;
  } else if (type == "tanh") {
    return HL_ACTIVATION_TANH;
  } else if (type == "linear" || type == "identity" || type == "") {
    return HL_ACTIVATION_LINEAR;
  } else {
    PADDLE_THROW("Unsupported activation type.");
  }
}

template <typename Place, typename T>
class LstmUnitFunctor {
 public:
  static void compute(const platform::DeviceContext &context,
                      LstmMetaValue<T> value, int frame_size, int batch_size,
                      const std::string &gate_act, const std::string &cell_act,
                      const std::string &cand_act);
};

template <typename Place, typename T>
class LstmUnitGradFunctor {
 public:
  static void compute(const platform::DeviceContext &context,
                      LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                      int frame_size, int batch_size,
                      const std::string &gate_act, const std::string &cell_act,
                      const std::string &cand_act);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,61 @@
#include "paddle/operators/math/sequence2batch.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T>
|
||||
class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::LoDTensor& src, const size_t* index,
|
||||
framework::LoDTensor& dst, bool is_src_index) {
|
||||
auto src_dims = src.dims();
|
||||
auto dst_dims = dst.dims();
|
||||
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
|
||||
"The src must be matrix with rank 2.");
|
||||
PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL,
|
||||
"The dst must be matrix with rank 2.");
|
||||
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
|
||||
"The width of src and dst must be same.");
|
||||
auto height = dst_dims[0];
|
||||
auto width = dst_dims[1];
|
||||
auto* src_data = src.data<T>();
|
||||
auto* dst_data = dst.data<T>();
|
||||
for (int i = 0; i < height; ++i) {
|
||||
if (is_src_index) {
|
||||
memcpy(dst_data + i * width, src_data + index[i] * width,
|
||||
width * sizeof(T));
|
||||
} else {
|
||||
memcpy(dst_data + index[i] * width, src_data + i * width,
|
||||
width * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
|
||||
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
|
||||
|
||||
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
|
||||
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
|
||||
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
|
||||
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@@ -0,0 +1,78 @@
#include "paddle/operators/math/sequence2batch.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
|
||||
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
|
||||
int64_t height, int64_t width,
|
||||
bool is_src_index) {
|
||||
int idx = threadIdx.x;
|
||||
int idy = threadIdx.y;
|
||||
int id = blockIdx.x + idy * GridDimX;
|
||||
while (id < height) {
|
||||
int src_idx = is_src_index ? index[id] : id;
|
||||
int dst_idx = is_src_index ? id : index[id];
|
||||
const T* src_data = src + src_idx * width;
|
||||
T* dst_data = dst + dst_idx * width;
|
||||
for (int i = idx; i < width; i += BlockDimX) {
|
||||
dst_data[i] = src_data[i];
|
||||
}
|
||||
id += BlockDimY * GridDimX;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::LoDTensor& src, const size_t* index,
|
||||
framework::LoDTensor& dst, bool is_src_index) {
|
||||
auto src_dims = src.dims();
|
||||
auto dst_dims = dst.dims();
|
||||
PADDLE_ENFORCE_EQ(src_dims.size(), 2,
|
||||
"The src must be matrix with rank 2.");
|
||||
PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
|
||||
"The dst must be matrix with rank 2.");
|
||||
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
|
||||
"The width of src and dst must be same.");
|
||||
auto height = dst_dims[0];
|
||||
auto width = dst_dims[1];
|
||||
auto* src_data = src.data<T>();
|
||||
auto* dst_data = dst.data<T>();
|
||||
|
||||
dim3 threads(128, 8);
|
||||
dim3 grid(8, 1);
|
||||
auto stream =
|
||||
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
|
||||
CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
|
||||
src_data, dst_data, index, height, width, is_src_index);
|
||||
}
|
||||
};
|
||||
|
||||
template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
|
||||
template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;
|
||||
|
||||
template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
|
||||
template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
|
||||
template class Batch2LoDTensorFunctor<platform::GPUPlace, float>;
|
||||
template class Batch2LoDTensorFunctor<platform::GPUPlace, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@@ -0,0 +1,148 @@
#pragma once
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename Place, typename T>
|
||||
class CopyMatrixRowsFunctor {
|
||||
public:
|
||||
// If is_src_index is true,
|
||||
// copy the indexed rows of input src to the output dst.
|
||||
// If is_src_index is false,
|
||||
// copy the input src to the indexed rows of output dst.
|
||||
// The indexed rows are based on the input index.
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::LoDTensor& src, const size_t* index,
|
||||
framework::LoDTensor& dst, bool is_src_index);
|
||||
};
|
||||
|
||||
template <typename Place, typename T>
|
||||
class LoDTensor2BatchFunctor {
|
||||
// Calculate the length of each sequence and
|
||||
// sort sequence index by the length.
|
||||
// example: sequences = {s0, s1, s2}
|
||||
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
|
||||
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
|
||||
//
|
||||
struct SeqInfo {
|
||||
SeqInfo(int start, int length, int seq_idx)
|
||||
: start(start), length(length), seq_idx(seq_idx) {}
|
||||
int start;
|
||||
int length;
|
||||
int seq_idx;
|
||||
};
|
||||
|
||||
public:
|
||||
void operator()(const platform::DeviceContext& context,
|
||||
const framework::LoDTensor& lod_tensor,
|
||||
framework::LoDTensor& batch, bool is_reverse) const {
|
||||
auto lods = lod_tensor.lod();
|
||||
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
|
||||
auto lod = lods[0];
|
||||
|
||||
std::vector<SeqInfo> seq_info;
|
||||
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
|
||||
int length = lod[seq_id + 1] - lod[seq_id];
|
||||
seq_info.emplace_back(lod[seq_id], length, seq_id);
|
||||
}
|
||||
|
||||
std::sort(seq_info.begin(), seq_info.end(),
|
||||
[](SeqInfo a, SeqInfo b) { return a.length > b.length; });
|
||||
|
||||
    // Calculate the start position of each batch.
    // (num_batch equals the maximum length of the sequences)
    // example: sequences = {s0, s1, s2}
    //          s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
    //          num_batch = 5,
    //          batchIndex = {b0, b1, b2, b3, b4}
    //          b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
    //          batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
    //          batch_start_positions[0] = 0
    //          batch_start_positions[1] = len(b0)
    //          batch_start_positions[2] = len(b0) + len(b1)
    //          batch_start_positions[3] = len(b0) + len(b1) + len(b2)
    //          ...
    //          seq2batch_idx[12] = {4, 0, 9,
    //                               5, 1, 10,
    //                               6, 2, 11,
    //                               7, 3,
    //                               8}
    // The batch number (num_batch) is the number of batches after rearranging
    // the input LoDTensor; it is also the maximum length of the input
    // sequences.

    paddle::framework::LoD batch_lods;
    batch_lods.emplace_back(std::vector<size_t>{0});
    batch_lods.emplace_back(std::vector<size_t>{0});

    // batch_lods[0] is the start positions for batch LoDTensor
    int num_batch = seq_info[0].length;
    batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
    // batch_lods[1] is the raw index in the input LoDTensor
    auto dims = lod_tensor.dims();
    batch_lods[1].resize(static_cast<size_t>(dims[0]));

    size_t* batch_starts = batch_lods[0].data();
    size_t* seq2batch_idx = batch_lods[1].data();
    batch_starts[0] = 0;
    for (size_t n = 0; n < num_batch; n++) {
      auto batch_id = static_cast<int>(batch_starts[n]);
      for (size_t i = 0; i < seq_info.size(); ++i) {
        size_t seq_len = seq_info[i].length;
        int start = seq_info[i].start;
        if (n < seq_len) {
          seq2batch_idx[batch_id] =
              is_reverse ? start + seq_len - 1 - n : start + n;
          batch_id++;
        } else {
          break;
        }
      }
      batch_starts[n + 1] = static_cast<size_t>(batch_id);
    }
    batch.set_lod(batch_lods);

    CopyMatrixRowsFunctor<Place, T> to_batch;
    to_batch(context, lod_tensor, seq2batch_idx, batch, true);
  }
};

template <typename Place, typename T>
class Batch2LoDTensorFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::LoDTensor& batch,
                  framework::LoDTensor& lod_tensor) const {
    auto in_lod = batch.lod();
    PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
                      "The LoD size of input `batch` should be 2.");
    auto out_lod = lod_tensor.lod()[0];
    auto num = out_lod[out_lod.size() - 1];
    PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
    PADDLE_ENFORCE_EQ(num, in_lod[1].size());
    PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
    CopyMatrixRowsFunctor<Place, T> to_seq;
    size_t* index = in_lod[1].data();
    to_seq(context, batch, index, lod_tensor, false);
  }
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
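To make the index computation concrete, the following Python sketch (an editor's illustration; lod_to_batch_index is a hypothetical name) mirrors the loop in LoDTensor2BatchFunctor::operator() and reproduces the worked example from the comments above (sequence starts 0, 4, 9 with lengths 4, 5, 3):

def lod_to_batch_index(lod, is_reverse=False):
    # lod holds the sequence offsets, e.g. [0, 4, 9, 12].
    seq_info = sorted(
        [(lod[i], lod[i + 1] - lod[i], i) for i in range(len(lod) - 1)],
        key=lambda info: info[1], reverse=True)  # longest sequence first
    num_batch = seq_info[0][1]  # number of batches == max sequence length
    batch_starts = [0]
    seq2batch_idx = []
    for n in range(num_batch):
        for start, length, _ in seq_info:
            if n < length:
                step = start + length - 1 - n if is_reverse else start + n
                seq2batch_idx.append(step)
        batch_starts.append(len(seq2batch_idx))
    return batch_starts, seq2batch_idx

assert lod_to_batch_index([0, 4, 9, 12]) == (
    [0, 3, 6, 9, 11, 12],
    [4, 0, 9, 5, 1, 10, 6, 2, 11, 7, 3, 8])

Batch2LoDTensorFunctor then undoes the reordering by passing the same index to CopyMatrixRowsFunctor with is_src_index set to false.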
@ -0,0 +1,185 @@
import unittest
import numpy as np
from op_test import OpTest

SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0


def identity(x):
    return x


def sigmoid(x):
    y = np.copy(x)
    y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
    y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
    return 1. / (1. + np.exp(-y))


def tanh(x):
    y = -2. * x
    y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
    return (2. / (1. + np.exp(y))) - 1.


def relu(x):
    return np.maximum(x, 0)


ACTVATION = {
    'identity': identity,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'relu': relu
}


def lstm(
        input,  # T x 4D
        lod,  # 1 x N
        h0=None,  # N x D
        c0=None,  # N x D
        w_h=None,  # D x 4D
        w_b=None,  # 1 x 4D
        w_c=None,  # 1 x 3D
        is_reverse=False,
        act_gate=None,
        act_cell=None,
        act_cand=None):
    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
        g = np.dot(h_pre, w_h)  # 1 x 4D
        g = g + x
        g = np.reshape(g, (1, g.size))
        c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
        if w_c is None:
            g_i = act_gate(g_i)  # 1 x D
            g_f = act_gate(g_f)  # 1 x D
        else:
            w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
        c = g_f * c_pre + g_i * act_cand(c_tmp)  # 1 x D

        if w_c is None:
            g_o = act_gate(g_o)  # 1 x D
        else:
            _, _, w_oc = np.split(w_c, 3, axis=1)
            g_o = act_gate(g_o + w_oc * c)  # 1 x D
        h = g_o * act_cell(c)
        bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
        return h, c, bg

    def _reverse(x, lod):
        y = np.zeros_like(x)
        for i in range(len(lod) - 1):
            b, e = lod[i], lod[i + 1]
            y[b:e, :] = np.flip(x[b:e, :], 0)
        return y

    offset = lod[0]
    batch_size = len(offset) - 1
    hidden = []
    cell = []
    gate = []
    input = _reverse(input, offset) if is_reverse else input
    if w_b is not None:
        input = input + np.tile(w_b, (offset[-1], 1))
    for i in range(batch_size):
        # compute one sequence
        seq_len = offset[i + 1] - offset[i]
        x = input[offset[i]:offset[i + 1], :]
        h_pre = h0[i]  # 1 x D
        c_pre = c0[i]  # 1 x D
        for j in range(seq_len):
            # compute one step
            h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
                                        act_cell, act_cand)
            hidden.append(h_pre.flatten())
            cell.append(c_pre.flatten())
            gate.append(g_pre.flatten())

    hidden = np.array(hidden).astype("float64")
    cell = np.array(cell).astype("float64")
    gate = np.array(gate).astype("float64")

    hidden = _reverse(hidden, offset) if is_reverse else hidden
    cell = _reverse(cell, offset) if is_reverse else cell

    assert gate.shape == input.shape
    assert hidden.shape == (input.shape[0], input.shape[1] // 4)
    assert cell.shape == (input.shape[0], input.shape[1] // 4)
    return hidden, cell, gate
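
# Editor's note (not part of the original test): _step above computes one step
# of an LSTM with peephole connections. With the gate layout
# [candidate, input, forget, output] along the 4D axis and the bias w_b
# already folded into `input`, it evaluates
#
#   i_t  = act_gate(x_i + h_{t-1} . W_i + w_ic * c_{t-1})
#   f_t  = act_gate(x_f + h_{t-1} . W_f + w_fc * c_{t-1})
#   c~_t = act_cand(x_c + h_{t-1} . W_c)
#   c_t  = f_t * c_{t-1} + i_t * c~_t
#   o_t  = act_gate(x_o + h_{t-1} . W_o + w_oc * c_t)
#   h_t  = o_t * act_cell(c_t)
#
# where * is elementwise multiplication, "." is a matrix product, and
# w_ic, w_fc, w_oc are the 1 x D peephole weights split from w_c. When w_c is
# None, the peephole terms are simply omitted.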


class TestLstmOp(OpTest):
    def set_data(self):
        self.lod = [[0, 2, 6, 9]]
        self.D = 64
        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

        self.act_gate = "sigmoid"
        self.act_cell = "tanh"
        self.act_cand = "tanh"

        self.is_reverse = False

    def setUp(self):
        self.set_data()
        self.op_type = "lstm"

        T = self.lod[0][-1]
        N = len(self.lod[0]) - 1

        x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
        h0 = np.zeros((N, self.D)).astype("float64")
        c0 = np.zeros((N, self.D)).astype("float64")
        w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
        b = np.random.normal(size=(1, 7 * self.D)).astype("float64")

        w_b = b[:, 0:4 * self.D]
        w_c = b[:, 4 * self.D:]
        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                       ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                       ACTVATION[self.act_cand])

        g_sort = np.zeros_like(x)
        for i, j in enumerate(self.sort_idx):
            g_sort[i, :] = g[j, :]

        self.inputs = {
            'Input': (x, self.lod),
            'H0': h0,
            'C0': c0,
            'Weight': w,
            'Bias': b
        }
        self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
        self.attrs = {
            'usePeepholes': True,
            'isReverse': self.is_reverse,
            'gateActivation': 'sigmoid',
            'cellActivation': 'tanh',
            'candidateActivation': 'tanh'
        }

    def test_check_output(self):
        self.check_output()


class TestLstmOpReverse(TestLstmOp):
    def set_data(self):
        self.lod = [[0, 2, 6, 9]]
        self.D = 64
        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

        self.act_gate = "sigmoid"
        self.act_cell = "tanh"
        self.act_cand = "tanh"

        self.is_reverse = True


if __name__ == "__main__":
    unittest.main()