commit
f698a49ce3
@ -0,0 +1,103 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/lstm_unit_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class LstmUnitOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(const framework::InferShapeContext &ctx) const override {
|
||||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
|
||||
"Input(X) of LSTM should not be null.");
|
||||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"),
|
||||
"Input(C_prev) of LSTM should not be null.");
|
||||
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"),
|
||||
"Output(C) of LSTM should not be null.");
|
||||
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"),
|
||||
"Output(H) of LSTM should not be null.");
|
||||
|
||||
auto *x = ctx.Input<framework::Tensor>("X");
|
||||
auto *c_prev = ctx.Input<framework::Tensor>("C_prev");
|
||||
|
||||
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
|
||||
PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
|
||||
"Batch size of inputs and states must be equal");
|
||||
PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4,
|
||||
"Dimension of FC should equal to prev state * 4");
|
||||
|
||||
int b_size = c_prev->dims()[0]; // batch size
|
||||
int s_dim = c_prev->dims()[1]; // state dim
|
||||
ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim});
|
||||
ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AttrType>
|
||||
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
LstmUnitOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X", "FC input before the non-linear activation.");
|
||||
AddInput(
|
||||
"C_prev",
|
||||
"The cell state tensor of last time-step in the Lstm Unit operator.");
|
||||
AddOutput("C", "The cell tensor of Lstm Unit operator.");
|
||||
AddOutput("H", "The hidden state tensor of Lstm Unit operator.");
|
||||
|
||||
AddComment(R"DOC(Lstm-Unit Operator
|
||||
|
||||
Equation:
|
||||
i, f, o, j = split(X)
|
||||
C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
|
||||
H = C * sigm(o)
|
||||
|
||||
)DOC");
|
||||
AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
|
||||
.SetDefault(0.0);
|
||||
}
|
||||
};
|
||||
|
||||
class LstmUnitGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(const framework::InferShapeContext &ctx) const override {
|
||||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")),
|
||||
"Input(C@GRAD) should not be null");
|
||||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")),
|
||||
"Input(H@GRAD) should not be null");
|
||||
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
|
||||
->Resize(ctx.Input<Tensor>("X")->dims());
|
||||
ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev"))
|
||||
->Resize(ctx.Input<Tensor>("C_prev")->dims());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker<float>,
|
||||
lstm_unit_grad, ops::LstmUnitGradOp);
|
||||
REGISTER_OP_CPU_KERNEL(lstm_unit,
|
||||
ops::LstmUnitKernel<paddle::platform::CPUPlace, float>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>);
|
@ -0,0 +1,173 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/cross_entropy_op.h"
|
||||
#include "paddle/platform/assert.h"
|
||||
#include "paddle/platform/hostdevice.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename Dtype>
|
||||
__device__ Dtype cuda_sigmoid(const Dtype x) {
|
||||
return Dtype(1) / (Dtype(1) + exp(-x));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
__device__ Dtype cuda_tanh(const Dtype x) {
|
||||
return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void LSTMUnitKernel(const int nthreads, const int dim,
|
||||
const T* C_prev, const T* X, T* C, T* H,
|
||||
const T forget_bias) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
const int n = index / dim;
|
||||
const int d = index % dim;
|
||||
|
||||
const T* X_offset = X + 4 * dim * n;
|
||||
const T i = cuda_sigmoid(X_offset[d]);
|
||||
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
|
||||
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
|
||||
const T g = cuda_tanh(X_offset[3 * dim + d]);
|
||||
const T c_prev = C_prev[index];
|
||||
const T c = f * c_prev + i * g;
|
||||
C[index] = c;
|
||||
const T tanh_c = cuda_tanh(c);
|
||||
H[index] = o * tanh_c;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
|
||||
const T* C_prev, const T* X, const T* C,
|
||||
const T* H, const T* C_diff,
|
||||
const T* H_diff, T* C_prev_diff,
|
||||
T* X_diff, const T forget_bias) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
const int n = index / dim;
|
||||
const int d = index % dim;
|
||||
const T* X_offset = X + 4 * dim * n;
|
||||
T* c_prev_diff = C_prev_diff + index;
|
||||
T* X_diff_offset = X_diff + 4 * dim * n;
|
||||
T* i_diff = X_diff_offset + d;
|
||||
T* f_diff = X_diff_offset + 1 * dim + d;
|
||||
T* o_diff = X_diff_offset + 2 * dim + d;
|
||||
T* g_diff = X_diff_offset + 3 * dim + d;
|
||||
|
||||
const T i = cuda_sigmoid(X_offset[d]);
|
||||
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
|
||||
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
|
||||
const T g = cuda_tanh(X_offset[3 * dim + d]);
|
||||
const T c_prev = C_prev[index];
|
||||
const T c = C[index];
|
||||
const T tanh_c = cuda_tanh(c);
|
||||
const T c_term_diff =
|
||||
C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
|
||||
*c_prev_diff = c_term_diff * f;
|
||||
*i_diff = c_term_diff * g * i * (1 - i);
|
||||
*f_diff = c_term_diff * c_prev * f * (1 - f);
|
||||
*o_diff = H_diff[index] * tanh_c * o * (1 - o);
|
||||
*g_diff = c_term_diff * i * (1 - g * g);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AttrType = T>
|
||||
class LstmUnitOpCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"It must use GPUPlace.");
|
||||
|
||||
auto* x_tensor = ctx.Input<framework::Tensor>("X");
|
||||
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
|
||||
auto* c_tensor = ctx.Output<framework::Tensor>("C");
|
||||
auto* h_tensor = ctx.Output<framework::Tensor>("H");
|
||||
|
||||
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||
|
||||
int b_size = c_tensor->dims()[0];
|
||||
int D = c_tensor->dims()[1];
|
||||
|
||||
const T* X = x_tensor->data<T>();
|
||||
const T* C_prev = c_prev_tensor->data<T>();
|
||||
|
||||
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int block = 512;
|
||||
int n = b_size * D;
|
||||
int grid = (n + block - 1) / block;
|
||||
|
||||
LSTMUnitKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, forget_bias);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename AttrType = T>
|
||||
class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"It must use GPUPlace.");
|
||||
|
||||
auto x_tensor = ctx.Input<Tensor>("X");
|
||||
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
|
||||
auto c_tensor = ctx.Input<Tensor>("C");
|
||||
auto h_tensor = ctx.Input<Tensor>("H");
|
||||
|
||||
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
|
||||
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
|
||||
|
||||
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto c_prev_diff_tensor =
|
||||
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
|
||||
|
||||
auto* X = x_tensor->data<T>();
|
||||
auto* C_prev = c_prev_tensor->data<T>();
|
||||
auto* C = c_tensor->data<T>();
|
||||
auto* H = h_tensor->data<T>();
|
||||
|
||||
auto* H_diff = hdiff_tensor->data<T>();
|
||||
auto* C_diff = cdiff_tensor->data<T>();
|
||||
|
||||
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int N = c_tensor->dims()[0];
|
||||
int D = c_tensor->dims()[1];
|
||||
|
||||
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||
|
||||
int block = 512;
|
||||
int n = N * D;
|
||||
int grid = (n + block - 1) / block;
|
||||
|
||||
LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
|
||||
H_diff, C_prev_diff, X_diff,
|
||||
forget_bias);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
|
||||
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
|
@ -0,0 +1,148 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::LoDTensor;
|
||||
using framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
inline T sigmoid(T x) {
|
||||
return 1. / (1. + exp(-x));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T tanh(T x) {
|
||||
return 2. * sigmoid(2. * x) - 1.;
|
||||
}
|
||||
|
||||
template <typename Place, typename T, typename AttrType = T>
|
||||
class LstmUnitKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
||||
"It must use CPUPlace.");
|
||||
|
||||
auto* x_tensor = ctx.Input<framework::Tensor>("X");
|
||||
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
|
||||
auto* c_tensor = ctx.Output<framework::Tensor>("C");
|
||||
auto* h_tensor = ctx.Output<framework::Tensor>("H");
|
||||
|
||||
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||
|
||||
int b_size = c_tensor->dims()[0];
|
||||
int D = c_tensor->dims()[1];
|
||||
|
||||
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
const T* X = x_tensor->data<T>();
|
||||
const T* C_prev = c_prev_tensor->data<T>();
|
||||
|
||||
for (int n = 0; n < b_size; ++n) {
|
||||
for (int d = 0; d < D; ++d) {
|
||||
const T i = sigmoid(X[d]);
|
||||
const T f = sigmoid(X[1 * D + d] + forget_bias);
|
||||
const T o = sigmoid(X[2 * D + d]);
|
||||
const T g = tanh(X[3 * D + d]);
|
||||
const T c_prev = C_prev[d];
|
||||
const T c = f * c_prev + i * g;
|
||||
C[d] = c;
|
||||
const T tanh_c = tanh(c);
|
||||
H[d] = o * tanh_c;
|
||||
}
|
||||
C_prev += D;
|
||||
X += 4 * D;
|
||||
C += D;
|
||||
H += D;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Place, typename T, typename AttrType = T>
|
||||
class LstmUnitGradKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
||||
"It must use CPUPlace.");
|
||||
|
||||
auto x_tensor = ctx.Input<Tensor>("X");
|
||||
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
|
||||
auto c_tensor = ctx.Input<Tensor>("C");
|
||||
auto h_tensor = ctx.Input<Tensor>("H");
|
||||
|
||||
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
|
||||
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
|
||||
|
||||
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||
auto c_prev_diff_tensor =
|
||||
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
|
||||
|
||||
auto* X = x_tensor->data<T>();
|
||||
auto* C_prev = c_prev_tensor->data<T>();
|
||||
auto* C = c_tensor->data<T>();
|
||||
auto* H = h_tensor->data<T>();
|
||||
|
||||
auto* H_diff = hdiff_tensor->data<T>();
|
||||
auto* C_diff = cdiff_tensor->data<T>();
|
||||
|
||||
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int N = c_tensor->dims()[0];
|
||||
int D = c_tensor->dims()[1];
|
||||
|
||||
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int d = 0; d < D; ++d) {
|
||||
T* c_prev_diff = C_prev_diff + d;
|
||||
T* i_diff = X_diff + d;
|
||||
T* f_diff = X_diff + 1 * D + d;
|
||||
T* o_diff = X_diff + 2 * D + d;
|
||||
T* g_diff = X_diff + 3 * D + d;
|
||||
|
||||
const T i = sigmoid(X[d]);
|
||||
const T f = sigmoid(X[1 * D + d] + forget_bias);
|
||||
const T o = sigmoid(X[2 * D + d]);
|
||||
const T g = tanh(X[3 * D + d]);
|
||||
const T c_prev = C_prev[d];
|
||||
const T c = C[d];
|
||||
const T tanh_c = tanh(c);
|
||||
const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c);
|
||||
*c_prev_diff = c_term_diff * f;
|
||||
*i_diff = c_term_diff * g * i * (1 - i);
|
||||
*f_diff = c_term_diff * c_prev * f * (1 - f);
|
||||
*o_diff = H_diff[d] * tanh_c * o * (1 - o);
|
||||
*g_diff = c_term_diff * i * (1 - g * g);
|
||||
}
|
||||
C_prev += D;
|
||||
X += 4 * D;
|
||||
C += D;
|
||||
H += D;
|
||||
C_diff += D;
|
||||
H_diff += D;
|
||||
X_diff += 4 * D;
|
||||
C_prev_diff += D;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,38 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def sigmoid_np(x):
|
||||
return 1. / (1. + np.exp(-x))
|
||||
|
||||
|
||||
def tanh_np(x):
|
||||
return 2 * sigmoid_np(2. * x) - 1.
|
||||
|
||||
|
||||
class LstmUnitTest(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "lstm_unit"
|
||||
x_np = np.random.normal(size=(5, 16)).astype("float32")
|
||||
c_np = np.random.normal(size=(5, 4)).astype("float32")
|
||||
i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1)
|
||||
forget_bias_np = 0.
|
||||
self.attrs = {'forget_bias': 0.}
|
||||
|
||||
new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np(
|
||||
i_np) * tanh_np(j_np)
|
||||
new_h = tanh_np(new_c) * sigmoid_np(o_np)
|
||||
|
||||
self.inputs = {'X': x_np, 'C_prev': c_np}
|
||||
self.outputs = {'C': new_c, 'H': new_h}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue