@@ -14,6 +14,11 @@ limitations under the License. */

#include "paddle/fluid/operators/gru_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"

DECLARE_int32(paddle_num_threads);

namespace paddle {
namespace operators {

@@ -211,6 +216,158 @@ class GRUGradOp : public framework::OperatorWithKernel {
  }
};

template <typename T>
class GRUCPUKernel : public framework::OpKernel<T> {
 public:
  void BatchCompute(const framework::ExecutionContext& context) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    auto* input = context.Input<LoDTensor>("Input");
    auto* h0 = context.Input<Tensor>("H0");
    auto* weight = context.Input<Tensor>("Weight");
    const T* weight_data = weight->data<T>();
    auto* bias = context.Input<Tensor>("Bias");
    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
    batch_gate->mutable_data<T>(context.GetPlace());
    auto* batch_reset_hidden_prev =
        context.Output<LoDTensor>("BatchResetHiddenPrev");
    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
    batch_hidden->mutable_data<T>(context.GetPlace());
    auto* hidden = context.Output<LoDTensor>("Hidden");
    hidden->mutable_data<T>(context.GetPlace());

    auto hidden_dims = hidden->dims();

    bool is_reverse = context.Attr<bool>("is_reverse");
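    // LoDTensor2Batch reorders the LoD-based input sequences into batch-major
    // layout (grouping the same time step of every sequence) and records the
    // reordering information in batch_gate's LoD, which is used below.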
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);

    if (bias) {
      math::RowwiseAdd<DeviceContext, T> add_bias;
      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
    }

    int frame_size = hidden_dims[1];
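    // "Weight" is stored as two contiguous blocks: the first
    // 2 * frame_size * frame_size elements hold the update/reset gate weights,
    // and the remaining frame_size * frame_size elements hold the candidate
    // (output) state weights.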
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = const_cast<T*>(weight_data);
    gru_value.state_weight =
        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
    Tensor ordered_h0;

    framework::Vector<size_t> order(batch_gate->lod()[2]);

    if (h0) {
      // Since batch computing for GRU reorders the input sequences by length,
      // the initial state H0 has to be reordered in the same way.
      ReorderInitState<DeviceContext, T>(
          context.template device_context<DeviceContext>(), *h0, order,
          &ordered_h0, true);
      gru_value.prev_out_value = ordered_h0.data<T>();
    } else {
      gru_value.prev_out_value = nullptr;
    }
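    // batch_gate->lod()[0] holds the start offset of each time-step batch:
    // batch n groups the n-th step of every sequence that is still active, so
    // the loops below walk over time steps rather than over sequences.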
    auto batch_starts = batch_gate->lod()[0];
    size_t seq_len = batch_starts.size() - 1;
    auto active_node = math::detail::GetActivationType(
        context.Attr<std::string>("activation"));
    auto active_gate = math::detail::GetActivationType(
        context.Attr<std::string>("gate_activation"));

#ifdef PADDLE_WITH_MKLML
    // Use MKL packed GEMM to speed up the weight multiplications.
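    // The gate and state weight matrices are packed once before the time-step
    // loop and reused by every GEMM_COMPUTE call, so the one-time packing cost
    // is amortized over the whole sequence. This path is only taken when
    // several threads are available (FLAGS_paddle_num_threads >= 4); otherwise
    // the plain GRUUnitFunctor path below is used.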
    if (FLAGS_paddle_num_threads >= 4) {
      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
                                       frame_size * 2 /*width of weight*/,
                                       frame_size /*height of weight*/);
      PADDLE_ENFORCE(packed_gate);
      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
                     packed_gate);
      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
                                        frame_size /*width of weight*/,
                                        frame_size /*height of weight*/);
      PADDLE_ENFORCE(packed_state);
      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
                     frame_size, T(1.0), gru_value.state_weight, frame_size,
                     packed_state);
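      // Per time-step batch: accumulate h_prev * packed gate weights into the
      // update/reset gate pre-activations, apply the gate activation, then
      // accumulate reset_output_value * packed state weights into the
      // candidate pre-activation, and finally compute the hidden output.
      // prev_out_value is advanced so the next step sees this step's output.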
      for (size_t n = 0; n < seq_len; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;

        Tensor gate_t = batch_gate->Slice(bstart, bend);
        Tensor reset_hidden_prev_t =
            batch_reset_hidden_prev->Slice(bstart, bend);
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
        gru_value.output_value = hidden_t.data<T>();
        gru_value.gate_value = gate_t.data<T>();
        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

        if (gru_value.prev_out_value) {
          blas.GEMM_COMPUTE(
              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
        }

        math::detail::forward_reset_output(
            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
            cur_batch_size, active_gate);

        if (gru_value.prev_out_value) {
          blas.GEMM_COMPUTE(
              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
              gru_value.reset_output_value, frame_size, packed_state,
              frame_size, T(1), gru_value.gate_value + frame_size * 2,
              frame_size * 3);
        }

        math::detail::forward_final_output(
            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
            cur_batch_size, active_node);

        gru_value.prev_out_value = gru_value.output_value;
      }

      blas.GEMM_FREE(packed_gate);
      blas.GEMM_FREE(packed_state);
    } else {
#endif
      for (size_t n = 0; n < seq_len; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;

        Tensor gate_t = batch_gate->Slice(bstart, bend);
        Tensor reset_hidden_prev_t =
            batch_reset_hidden_prev->Slice(bstart, bend);
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
        gru_value.output_value = hidden_t.data<T>();
        gru_value.gate_value = gate_t.data<T>();
        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
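        // Fallback path without MKL packing: GRUUnitFunctor performs the same
        // gate GEMMs and activations for this time-step batch with regular
        // BLAS calls.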
        math::GRUUnitFunctor<DeviceContext, T>::compute(
            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
            active_gate);

        gru_value.prev_out_value = gru_value.output_value;
      }
#ifdef PADDLE_WITH_MKLML
    }
#endif
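    // Convert the batch-major hidden results back to the original LoD
    // (sequence) layout expected by the "Hidden" output.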
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden->set_lod(batch_gate->lod());
    to_seq(dev_ctx, *batch_hidden, hidden);
  }

  void Compute(const framework::ExecutionContext& context) const override {
    BatchCompute(context);
  }
};

}  // namespace operators
}  // namespace paddle

@@ -218,9 +375,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
REGISTER_OP_CPU_KERNEL(
    gru, ops::GRUKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
                       ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
    gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);