You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/operators/rnn_op.h

2017 lines
85 KiB

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
using TensorList = std::vector<framework::Tensor>;
// DEFINE_MODE_DETECTOR(name, MODE) defines `inline bool is_<name>(ctx)`,
// which returns true when the op's "mode" string attribute equals the
// literal spelling of MODE (e.g. is_lstm(ctx) checks for "LSTM").
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR)                      \
  inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \
    return ctx.Attr<std::string>("mode") == #MODE_STR;                 \
  }
DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);
// Exchanges the two Tensor pointers stored in `*a` and `*b`. The RNN layers
// use it to ping-pong between "previous" and "current" state buffers without
// copying tensor data. (The misspelled name is kept for existing callers.)
//
// Declared `inline`: this is a non-template function defined in a header, so
// without `inline` every translation unit including the header would emit
// its own external definition and violate the one-definition rule.
inline void SwapPoniter(Tensor** a, Tensor** b) { std::swap(*a, *b); }
// Builds the per-step validity mask for variable-length batches.
//
// `mask_matrix` must already be resized to {time_step, batch}; its data is
// allocated here on the context's place. Entry [t][b] is 1 when step t lies
// within sample b's length and 0 otherwise. With `is_reverse` the zeros are
// placed at the front of each sample instead, matching the reversed step
// order used by the reverse direction of a layer. `min_seq_len` receives the
// smallest sequence length seen (time_step if none is shorter).
//
// The transpose goes through platform::CPUDeviceContext, so this helper is
// CPU-only.
template <typename T>
void create_mask_matrix(const framework::ExecutionContext& context,
                        const Tensor* sequence_length, Tensor* mask_matrix,
                        const bool& is_reverse, int* min_seq_len) {
  const auto& seq_len_vec = GetDataFromTensor<int>(sequence_length);
  const int table_width = static_cast<int>(mask_matrix->dims()[0]);
  const int64_t batch_size = mask_matrix->dims()[1];
  // Each length fills one row of `temp`; extra lengths would write past the
  // end of the buffer.
  PADDLE_ENFORCE_LE(
      static_cast<int64_t>(seq_len_vec.size()), batch_size,
      platform::errors::InvalidArgument(
          "The number of sequence lengths must not exceed the batch size of "
          "the RNN input, but received %d lengths for batch size %d.",
          static_cast<int64_t>(seq_len_vec.size()), batch_size));
  // Build the mask as [batch, time_step] so every sample is one contiguous
  // row, then transpose it into mask_matrix.
  Tensor temp;
  temp.Resize(framework::make_ddim({batch_size, mask_matrix->dims()[0]}));
  T* data_temp = temp.mutable_data<T>(context.GetPlace());
  std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast<T>(1.0));
  *min_seq_len = table_width;
  for (size_t i = 0; i < seq_len_vec.size(); ++i) {
    const int seq_len = seq_len_vec[i];
    // A length outside [0, time_step] would turn the fill ranges below into
    // inverted or out-of-row ranges and corrupt memory.
    PADDLE_ENFORCE_GE(
        seq_len, 0,
        platform::errors::InvalidArgument(
            "Each sequence length of the RNN input must be non-negative, "
            "but received %d.",
            seq_len));
    PADDLE_ENFORCE_LE(
        seq_len, table_width,
        platform::errors::InvalidArgument(
            "Each sequence length of the RNN input must not exceed the "
            "number of time steps %d, but received %d.",
            table_width, seq_len));
    *min_seq_len = std::min(seq_len, *min_seq_len);
    if (seq_len == table_width) {
      continue;  // full-length sample: its row stays all ones
    }
    T* row = data_temp + i * table_width;
    if (is_reverse) {
      // reversed order: padding comes first
      std::fill(row, row + (table_width - seq_len), static_cast<T>(0));
    } else {
      // forward order: padding comes last
      std::fill(row + seq_len, row + table_width, static_cast<T>(0));
    }
  }
  mask_matrix->mutable_data<T>(context.GetPlace());
  const std::vector<int> trans_vec{1, 0};
  auto& dev_ctx = context.template device_context<platform::CPUDeviceContext>();
  TransCompute<platform::CPUDeviceContext, T>(2, dev_ctx, temp, mask_matrix,
                                              trans_vec);
}
// Interface for one recurrent-cell time step on CPU. Concrete cells below
// override operator(); the base implementation is intentionally a no-op.
// As used by the cells in this file:
//   input         - precomputed input projection for the step; the
//                   hidden-to-hidden product is accumulated into it
//   weight_hh     - hidden-to-hidden weight
//   init_h/init_c - previous hidden / cell state
//   last_c        - receives the new cell state (LSTM) or scratch data (GRU)
//   last_c_act    - receives the activated cell state (LSTM), may be null
//   output        - receives the step's hidden output
//   bias_hh       - hidden bias (GRU reads its candidate block)
//   weight_hh_gru - GRU-only variant of weight_hh for the fused matmul
template <typename T>
struct Cell {
  virtual ~Cell() = default;
  virtual void operator()(const platform::CPUDeviceContext* device_ctx,
                          Tensor* input, const Tensor* weight_hh,
                          const Tensor* init_h, const Tensor* init_c,
                          Tensor* last_h, Tensor* last_c, Tensor* last_c_act,
                          Tensor* output, const Tensor* bias_hh,
                          Tensor* weight_hh_gru) const {}
};
// One step of a plain (Elman) RNN:
//   input  += init_h * weight_hh^T
//   output  = EigenActivationFunctor(input)
// `act_type` is part of the type but is not read by this implementation.
template <typename T, template <typename> class EigenActivationFunctor,
          math::detail::ActivationType act_type>
struct SimpleRNNCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    auto gemm = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    // Fold init_h's batch dimension into its row count so one plain GEMM
    // replaces a batched matmul.
    auto lhs_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    lhs_desc.height_ *= lhs_desc.batch_size_;
    lhs_desc.batch_size_ = 0;
    auto rhs_desc = math::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    // beta = 1: accumulate onto the precomputed input projection.
    gemm.MatMul(*init_h, lhs_desc, *weight_hh, rhs_desc, static_cast<T>(1.0),
                input, static_cast<T>(1.0));

    auto pre_activation = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(input, "Input", "z", "Activation"));
    auto post_activation = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(output, "Output", "hidden", "Activation"));
    auto& cpu_device = *device_ctx->eigen_device();
    EigenActivationFunctor<T> activation;
    activation(cpu_device, pre_activation, post_activation);
  }
};
// One GRU step. The fused hidden-to-hidden GEMM uses `weight_hh_gru` (the
// Layer in this file passes a copy of weight_hh whose third block is zeroed).
// GRUUnitFunctorV2 then finishes the step, reading the candidate weights
// (third block of weight_hh) and the reset bias (third block of bias_hh).
// `last_c` receives the reset-output values.
template <typename T>
struct GRUCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    auto gemm = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    // Treat init_h as one 2-D matrix (batch folded into rows).
    auto lhs_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    lhs_desc.height_ *= lhs_desc.batch_size_;
    lhs_desc.batch_size_ = 0;
    auto rhs_desc =
        math::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true);
    gemm.MatMul(*init_h, lhs_desc, *weight_hh_gru, rhs_desc,
                static_cast<T>(1.0), input, static_cast<T>(1.0));

    const size_t frame_size = init_h->dims()[2];
    const size_t batch_size = init_h->dims()[1];
    const size_t third_block = 2 * frame_size;

    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = weight_hh->data<T>();
    gru_value.state_weight = weight_hh->data<T>() + third_block * frame_size;
    gru_value.reset_bias = bias_hh->data<T>() + third_block;
    gru_value.gate_value = input->data<T>();
    gru_value.reset_output_value = last_c->data<T>();
    gru_value.output_value = output->data<T>();
    gru_value.prev_out_value = init_h->data<T>();

    const auto gate_activation = math::detail::GetActivationType("sigmoid_v2");
    const auto candidate_activation =
        math::detail::GetActivationType("tanh_v2");
    math::GRUUnitFunctorV2<platform::CPUDeviceContext, T>::compute(
        *device_ctx, gru_value, frame_size, batch_size, candidate_activation,
        gate_activation);
  }
};
// One LSTM step without peephole connections and without cell clipping.
// input += init_h * weight_hh^T, then LstmUnitFunctor computes the new cell
// state (into last_c), its activation (into last_c_act) and the hidden
// output (into output). When last_c_act is null (the inference path in this
// file), a local scratch tensor receives the activated cell state instead.
template <typename T>
struct LSTMCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    auto gemm = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    // Treat init_h as one 2-D matrix (batch folded into rows).
    auto lhs_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    lhs_desc.height_ *= lhs_desc.batch_size_;
    lhs_desc.batch_size_ = 0;
    auto rhs_desc = math::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    gemm.MatMul(*init_h, lhs_desc, *weight_hh, rhs_desc, static_cast<T>(1.0),
                input, static_cast<T>(1.0));

    const size_t frame_size = init_h->dims()[2];
    const size_t batch_size = init_h->dims()[1];

    // Inference passes no buffer for the activated cell state; use scratch.
    Tensor scratch_c_act;
    Tensor* c_act_target = last_c_act;
    if (c_act_target == nullptr) { /* is test */
      scratch_c_act.mutable_data<T>(init_h->dims(), device_ctx->GetPlace());
      c_act_target = &scratch_c_act;
    }

    math::LstmMetaValue<T> lstm_value;
    lstm_value.check_ig = nullptr;  // no peephole weights
    lstm_value.check_fg = nullptr;
    lstm_value.check_og = nullptr;
    lstm_value.prev_state_value = init_c->data<T>();
    lstm_value.gate_value = input->data<T>();
    lstm_value.output_value = output->data<T>();
    lstm_value.state_value = last_c->data<T>();
    lstm_value.state_active_value = c_act_target->data<T>();

    const auto gate_activation = math::detail::GetActivationType("sigmoid_v2");
    const auto cell_activation = math::detail::GetActivationType("tanh_v2");
    const auto candidate_activation =
        math::detail::GetActivationType("tanh_v2");
    T no_clip = 0.0;
    math::LstmUnitFunctor<platform::CPUDeviceContext, T>::compute(
        *device_ctx, lstm_value, frame_size, batch_size, no_clip,
        gate_activation, cell_activation, candidate_activation, false);
  }
};
// Applies an existing dropout mask: y = x * mask / (1 - dropout_prob).
// When dropout_prob is exactly 1 the division is avoided and y = 0 * x.
// x and y may be the same tensor.
template <typename T>
void dropout_helper(const framework::ExecutionContext& context, Tensor* x,
                    Tensor* y, const Tensor* mask, const float& dropout_prob) {
  auto& cpu_ctx =
      context.template device_context<platform::CPUDeviceContext>();
  auto& eigen_place = *cpu_ctx.eigen_device();
  auto keep_flags = EigenVector<uint8_t>::Flatten(*mask);
  auto src = EigenVector<T>::Flatten(*x);
  auto dst = EigenVector<T>::Flatten(*y);
  if (dropout_prob != 1.0f) {
    const T scale = static_cast<T>(1.0f - dropout_prob);
    dst.device(eigen_place) = src * keep_flags.cast<T>() / scale;
    return;
  }
  dst.device(eigen_place) = static_cast<T>(0) * src;
}
// Training-time dropout from x into y (a no-op when is_test is true).
// The mask is generated only while *is_has_reset is false; it is then
// flagged as generated so later calls sharing the flag reuse the same mask.
// Each mask element is 0 with probability dropout_prob (all 0 when
// dropout_prob == 1), drawn from the CPU engine for seed_number.
template <typename T>
void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
                                  Tensor* x, Tensor* y, Tensor* mask,
                                  const float& dropout_prob,
                                  const int& seed_number, const bool& is_test,
                                  bool* is_has_reset) {
  if (is_test) return;

  const size_t numel = framework::product(x->dims());
  auto* keep = mask->data<uint8_t>();
  if (!*is_has_reset) {
    if (dropout_prob == 1.0f) {
      // Special case: everything is dropped.
      std::fill(keep, keep + numel, static_cast<uint8_t>(0));
    } else {
      auto engine = framework::GetCPURandomEngine(seed_number);
      std::uniform_real_distribution<float> uniform01(0, 1);
      for (size_t idx = 0; idx < numel; ++idx) {
        keep[idx] = (uniform01(*engine) < dropout_prob) ? 0 : 1;
      }
    }
    *is_has_reset = true;
  }
  dropout_helper<T>(context, x, y, mask, dropout_prob);
}
// Backward of the dropout above, done in place on grad_x: the gradient is
// multiplied by the saved mask and divided by (1 - dropout_prob), or scaled
// by 0 when dropout_prob == 1 (see dropout_helper).
template <typename T>
void dropout_cpu_grad_function_inplace(
    const framework::ExecutionContext& context, Tensor* grad_x,
    const Tensor* mask, const float& dropout_prob) {
  Tensor* const grad_in_out = grad_x;  // read and written in place
  dropout_helper<T>(context, grad_in_out, grad_in_out, mask, dropout_prob);
}
// Single-direction RNN layer driver shared by SingleLayer and BidirLayer.
// It precomputes the input projection for all time steps (preprocess), then
// steps `cell_` through time, applying the sequence-length mask where needed
// (postprocess).
template <typename T, typename CellType>
struct Layer {
explicit Layer(const CellType& cell) : cell_(cell) {}
virtual ~Layer() {}
// Computes cache_input = input * weight^T + bias_ih + bias_hh for all time
// steps at once.
//   cache_input - resized to {input.dims()[0], input.dims()[1],
//                 weight.dims()[0]}; memory is allocated here only when
//                 is_test is true (in training the callers in this file pass
//                 a slice of pre-allocated reserve data)
//   bias_hh     - for GRU its third block (the candidate part) is zeroed
//                 before being added; GRUCell applies that block itself
void preprocess(const framework::ExecutionContext& context,
const Tensor* input, const Tensor& weight,
const Tensor& bias_ih, const Tensor& bias_hh,
Tensor* cache_input, bool is_test) {
// create the temp input for X * W_ih^T + Bias_ih
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
// projected width (weight.dims()[0]), not necessarily the hidden size itself
const int& hidden_size = weight.dims()[0];
cache_input->Resize(framework::make_ddim(
{input->dims()[0], input->dims()[1], hidden_size}));
if (is_test) {
cache_input->mutable_data<T>(context.GetPlace());
}
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
auto mat_dim_a = math::CreateMatrixDescriptor(input->dims(), 0, false);
auto mat_dim_b = math::CreateMatrixDescriptor(weight.dims(), 0, true);
// fold the batch dimension into the row count so one plain (faster) matmul
// replaces the batched matmul
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
blas.MatMul(*input, mat_dim_a, weight, mat_dim_b, static_cast<T>(1.0),
cache_input, static_cast<T>(0));
// view the result as a 2-D [rows, width] matrix and add bias_ih to each row
auto in = framework::EigenMatrix<T>::Reshape(
*cache_input, cache_input->dims().size() - 1);
auto bias_ih_tmp = framework::EigenMatrix<T>::From(
bias_ih, framework::make_ddim({1, bias_ih.dims()[0]}));
const int& row_num =
framework::product(cache_input->dims()) / cache_input->dims()[2];
in = in + bias_ih_tmp.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
if (is_gru(context)) {
// mask for the bias_hh blocks: reset_gate, update_gate, cell_gate = [1, 1, 0]
Tensor bias_hh_tmp;
bias_hh_tmp.Resize({bias_hh.numel()});
bias_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(bias_hh, context.GetPlace(), dev_ctx, &bias_hh_tmp);
bias_hh_tmp.Resize({3, bias_hh_tmp.numel() / 3});
auto bias_hh_tmp_unbind = Unbind(bias_hh_tmp);
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast<T>(0.0));
auto bias_hh_after_mask = framework::EigenMatrix<T>::From(
bias_hh_tmp, framework::make_ddim({1, bias_hh.dims()[0]}));
in = in + bias_hh_after_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
} else {
auto bias_hh_no_mask = framework::EigenMatrix<T>::From(
bias_hh, framework::make_ddim({1, bias_hh.dims()[0]}));
in = in + bias_hh_no_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
}
}
// Applies one time step's mask row. Where the mask is 0, the step's output
// is zeroed and last_h takes init_h's value (for LSTM, last_c likewise takes
// init_c); where it is 1, last_h takes the step's output and last_c keeps
// the freshly computed cell state.
// mask_tensor.dims()[1] is used as the number of rows to mask (presumably a
// {1, batch} slice of the mask matrix; TODO confirm Unbind's result shape).
void postprocess(const framework::ExecutionContext& context, Tensor* output,
const Tensor* init_h, const Tensor* init_c, Tensor* last_h,
Tensor* last_c, const Tensor& mask_tensor) {
// in the output, if the mask flag is 0, we will return zero data
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto out =
framework::EigenMatrix<T>::Reshape(*output, output->dims().size() - 1);
auto mask = framework::EigenMatrix<T>::From(
mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
auto pre_h =
framework::EigenMatrix<T>::Reshape(*init_h, init_h->dims().size() - 1);
auto curr_h =
framework::EigenMatrix<T>::Reshape(*last_h, last_h->dims().size() - 1);
auto mask_broadcast =
mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));
curr_h.device(place) = out * mask_broadcast + pre_h * (1 - mask_broadcast);
out.device(place) = out * mask_broadcast;
if (is_lstm(context)) {
auto pre_c = framework::EigenMatrix<T>::Reshape(
*init_c, init_c->dims().size() - 1);
auto curr_c = framework::EigenMatrix<T>::Reshape(
*last_c, last_c->dims().size() - 1);
curr_c.device(place) =
curr_c * mask_broadcast + pre_c * (1 - mask_broadcast);
}
}
// Entry point overridden by SingleLayer / BidirLayer; the base does nothing.
virtual void operator()(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList last_h,
TensorList last_c, Tensor* output,
const int& layer_idx, const int& gate_num,
Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_test) {}
// Inference-mode time loop for one direction. Intermediate h/c states are
// not kept: a few buffers are reused by swapping pointers between steps.
//   vec         - parameters grouped as [W_ih, W_hh, b_ih, b_hh] per
//                 direction; `offset` (0 forward, 1 reverse) selects the group
//   is_bidirect - when true, layer_idx is remapped to 2 * layer_idx + offset
//                 to index the per-direction init/last state lists, and
//                 offset > 0 runs the steps in reverse order
// Writes per-step outputs into `output`, the final hidden state into
// (*last_h_ptr)[layer_idx] and, for LSTM, the final cell state into
// (*last_c_ptr)[layer_idx]. gate_value is used as scratch for the input
// projection; cell_value / cell_act_value are not used here.
void RunTestIter(const framework::ExecutionContext& context,
const Tensor* input, const TensorList& vec,
const TensorList& init_h, const TensorList& init_c,
const Tensor* sequence_length, TensorList* last_h_ptr,
TensorList* last_c_ptr, Tensor* output, int layer_idx,
Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_bidirect, int offset) {
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& time_step = input->dims()[0];
this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
vec[3 + offset * 4], gate_value, true);
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
// the reverse direction walks the same buffers from the last step backwards
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
TensorList mask_tensor_list;
// construct the mask matrix for the mask
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
Tensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// Steps with reverse_flag * i >= mask_min_length (see in_mask below) may
// touch padding and go through postprocess(). In reverse order those are
// the first time_step - min_len steps, which this rewrite of the threshold
// expresses as -i >= mask_min_length.
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
bool has_allocate_mem_c = false;
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
// define the init_h holder for the swap
Tensor init_h_temp;
framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx,
&init_h_temp);
Tensor* init_h_holder = &init_h_temp;
// last_h_holder is the buffer postprocess() writes the (masked) hidden state
// into; the cells themselves write the step result into `output`. If step 0
// is not masked it can alias step 0's output slot, otherwise last_h is used.
Tensor* last_h_holder = nullptr;
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
Tensor* init_c_holder = nullptr;
const Tensor* init_c_temp_holder = nullptr;
Tensor init_c_temp;
Tensor* last_c_holder = nullptr;
Tensor last_c_temp;
if (is_lstm(context)) {
last_c_holder = &(*last_c_ptr)[layer_idx];
init_c_temp_holder = &init_c[layer_idx];
} else if (is_gru(context)) {
// GRU has no cell state: this buffer receives the cell's reset-output
// values (GRUCell writes them through last_c)
last_c_temp.Resize(init_h[layer_idx].dims());
last_c_temp.mutable_data<T>(context.GetPlace());
last_c_holder = &last_c_temp;
}
// GRU only: copy of W_hh with its third (candidate) block zeroed, used for
// the fused hidden-to-hidden matmul in GRUCell
Tensor weight_hh_tmp; // for gru
if (is_gru(context)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
weight_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
bool in_mask = (reverse_flag * i) >= mask_min_length;
// From step 1 on, ping-pong the cell buffers: the buffer written by the
// previous step becomes this step's init_c and the other is overwritten.
if (i > 0) {
if (!has_allocate_mem_c) {
if (is_lstm(context) || is_gru(context)) {
init_c_temp.Resize(init_h[layer_idx].dims());
init_c_temp.mutable_data<T>(context.GetPlace());
init_c_holder = &init_c_temp;
}
has_allocate_mem_c = true;
}
SwapPoniter(&init_c_holder, &last_c_holder);
init_c_temp_holder = init_c_holder;
}
cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
init_c_temp_holder, last_h_holder, last_c_holder, nullptr,
&output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(context, &output_tensors[i], init_h_holder,
init_c_temp_holder, last_h_holder, last_c_holder,
mask_tensor_list[i]);
}
// prepare next step: after the swap, init_h_holder points at this step's
// hidden state, and last_h_holder at the buffer the next step writes
// (its output slot if unmasked, otherwise last_h)
if (i + 1 < time_step) {
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
&(*last_h_ptr)[layer_idx]);
}
} else {
framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
dev_ctx, &(*last_h_ptr)[layer_idx]);
}
// With the cell-buffer ping-pong above, step 0 writes into last_c and
// alternate steps into the scratch tensor, so the final cell state is in
// the scratch tensor exactly when time_step is even.
if (time_step % 2 == 0) {
if (is_lstm(context)) {
framework::TensorCopy(*last_c_holder, context.GetPlace(), dev_ctx,
&(*last_c_ptr)[layer_idx]);
}
}
}
// Runs one direction of a layer over all time steps. Inference delegates to
// RunTestIter. In training, the per-step gate inputs go into gate_value, and
// the per-step cell states (LSTM; reset-output values for GRU) and cell
// activations (LSTM) go into cell_value / cell_act_value, so they remain
// available after the call (presumably for the backward pass).
// Parameters have the same meaning as in RunTestIter.
void RunIter(const framework::ExecutionContext& context, const Tensor* input,
const TensorList& vec, const TensorList& init_h,
const TensorList& init_c, const Tensor* sequence_length,
TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output,
int layer_idx, Tensor* gate_value, Tensor* cell_value,
Tensor* cell_act_value, bool is_bidirect, int offset,
bool is_test) {
if (is_test) {
RunTestIter(context, input, vec, init_h, init_c, sequence_length,
last_h_ptr, last_c_ptr, output, layer_idx, gate_value,
cell_value, cell_act_value, is_bidirect, offset);
return;
}
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
const int& time_step = input->dims()[0];
this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
vec[3 + offset * 4], gate_value, is_test);
auto input_tensors = Unbind(*gate_value);
auto output_tensors = Unbind(*output);
if (is_reverse) {
std::reverse(input_tensors.begin(), input_tensors.end());
std::reverse(output_tensors.begin(), output_tensors.end());
}
TensorList mask_tensor_list;
// construct the mask matrix for the mask
bool has_sequence_length = false;
if (sequence_length != nullptr) {
has_sequence_length = true;
}
Tensor mask_matrix;
int mask_min_length = time_step;
if (has_sequence_length) {
mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
&mask_min_length);
mask_tensor_list = Unbind(mask_matrix);
}
// same threshold rewrite as in RunTestIter for the reverse direction
if (is_reverse) {
mask_min_length = mask_min_length - time_step + 1;
}
// define the init_h holder for the swap
bool has_use_last_h_holder = false;
const int& reverse_flag = is_reverse ? -1 : 1;
TensorList cell_value_tensors;
TensorList cell_act_value_tensors;
Tensor init_h_temp;
framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx,
&init_h_temp);
Tensor* init_h_holder = &init_h_temp;
Tensor* last_h_holder = nullptr;
if (0 < mask_min_length) {
last_h_holder = &(output_tensors[0]);
} else {
last_h_holder = &(*last_h_ptr)[layer_idx];
has_use_last_h_holder = true;
}
const Tensor* init_c_holder = nullptr;
Tensor* last_c_holder = nullptr;
Tensor* last_c_act_holder = nullptr;
// view cell_value (and cell_act_value for LSTM) as one row per time step
if (is_lstm(context) || is_gru(context)) {
cell_value->Resize({time_step, cell_value->numel() / time_step});
cell_value_tensors = Unbind(*cell_value);
if (is_lstm(context)) {
cell_act_value->Resize(
{time_step, cell_act_value->numel() / time_step});
cell_act_value_tensors = Unbind(*cell_act_value);
}
}
// GRU only: W_hh copy with its third (candidate) block zeroed
Tensor weight_hh_tmp; // for gru
if (is_gru(context)) {
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
weight_hh_tmp.mutable_data<T>(context.GetPlace());
framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
&weight_hh_tmp);
weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
}
for (int i = 0; i < time_step; i++) {
bool in_mask = (reverse_flag * i) >= mask_min_length;
if (is_lstm(context)) {
// step 0 starts from init_c; later steps read the previous step's
// stored cell state
if (i == 0) {
init_c_holder = &init_c[layer_idx];
} else {
init_c_holder = &cell_value_tensors[i - 1];
}
cell_value_tensors[i].Resize(init_c[layer_idx].dims());
cell_act_value_tensors[i].Resize(init_c[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
last_c_act_holder = &cell_act_value_tensors[i];
} else if (is_gru(context)) {
// GRU: the per-step row stores the reset-output values
cell_value_tensors[i].Resize(init_h[layer_idx].dims());
last_c_holder = &cell_value_tensors[i];
}
cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
init_c_holder, last_h_holder, last_c_holder, last_c_act_holder,
&output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
&weight_hh_tmp);
if (in_mask) {
this->postprocess(context, &output_tensors[i], init_h_holder,
init_c_holder, last_h_holder, last_c_holder,
mask_tensor_list[i]);
}
// prepare next step (same hidden-state pointer juggling as RunTestIter)
if (i + 1 < time_step) {
bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
if (next_step_mask) {
if (!has_use_last_h_holder) {
init_h_holder = &(*last_h_ptr)[layer_idx];
}
} else {
init_h_holder = &(output_tensors[i + 1]);
}
SwapPoniter(&init_h_holder, &last_h_holder);
}
}
if (has_sequence_length) {
if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
&(*last_h_ptr)[layer_idx]);
}
} else {
framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
dev_ctx, &(*last_h_ptr)[layer_idx]);
}
// the final LSTM cell state is the last stored per-step cell value
if (is_lstm(context)) {
framework::TensorCopy(cell_value_tensors[time_step - 1],
context.GetPlace(), dev_ctx,
&(*last_c_ptr)[layer_idx]);
}
}
// Cell functor invoked once per time step
CellType cell_;
};
// A unidirectional layer: runs Layer::RunIter once, in the forward direction,
// with the first (and only) parameter group of `vec`.
template <typename T, typename CellType>
struct SingleLayer : public Layer<T, CellType> {
  explicit SingleLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
  void operator()(const framework::ExecutionContext& context,
                  const Tensor* input, const TensorList& vec,
                  const TensorList& init_h, const TensorList& init_c,
                  const Tensor* sequence_length, TensorList last_h,
                  TensorList last_c, Tensor* output, const int& layer_idx,
                  const int& gate_num, Tensor* gate_value, Tensor* cell_value,
                  Tensor* cell_act_value, bool is_test) override {
    constexpr bool kBidirectional = false;
    constexpr int kForwardOffset = 0;
    this->RunIter(context, input, vec, init_h, init_c, sequence_length,
                  &last_h, &last_c, output, layer_idx, gate_value, cell_value,
                  cell_act_value, kBidirectional, kForwardOffset, is_test);
  }
};
// A bidirectional layer: runs the forward (offset 0) and reverse (offset 1)
// directions into two half-width buffers, then concatenates them along the
// last axis into `output`. In training, gate_value / cell_value /
// cell_act_value are split in half, one half per direction.
template <typename T, typename CellType>
struct BidirLayer : public Layer<T, CellType> {
  explicit BidirLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
  void operator()(const framework::ExecutionContext& context,
                  const Tensor* input, const TensorList& vec,
                  const TensorList& init_h, const TensorList& init_c,
                  const Tensor* sequence_length, TensorList last_h,
                  TensorList last_c, Tensor* output, const int& layer_idx,
                  const int& gate_num, Tensor* gate_value, Tensor* cell_value,
                  Tensor* cell_act_value, bool is_test) override {
    const int time_step = input->dims()[0];
    const int batch_size = input->dims()[1];
    const int hidden_size = output->dims()[2];

    // One output buffer per direction, each with half of the hidden width.
    TensorList output_vec(2);
    for (auto& direction_out : output_vec) {
      direction_out.Resize({time_step, batch_size, hidden_size / 2});
      direction_out.mutable_data<T>(context.GetPlace());
    }

    Tensor fw_gate, fw_cell, fw_cell_act;
    Tensor bw_gate, bw_cell, bw_cell_act;
    if (!is_test) {
      gate_value->Resize({2, gate_value->numel() / 2});
      fw_gate = gate_value->Slice(0, 1);
      bw_gate = gate_value->Slice(1, 2);
      const bool has_cell_state = is_lstm(context) || is_gru(context);
      if (has_cell_state) /* for lstm and gru */ {
        cell_value->Resize({2, cell_value->numel() / 2});
        cell_act_value->Resize({2, cell_act_value->numel() / 2});
        fw_cell = cell_value->Slice(0, 1);
        bw_cell = cell_value->Slice(1, 2);
        if (is_lstm(context)) {
          fw_cell_act = cell_act_value->Slice(0, 1);
          bw_cell_act = cell_act_value->Slice(1, 2);
        }
      }
    }

    constexpr bool kBidirectional = true;
    this->RunIter(context, input, vec, init_h, init_c, sequence_length,
                  &last_h, &last_c, &output_vec[0], layer_idx, &fw_gate,
                  &fw_cell, &fw_cell_act, kBidirectional, 0, is_test);
    this->RunIter(context, input, vec, init_h, init_c, sequence_length,
                  &last_h, &last_c, &output_vec[1], layer_idx, &bw_gate,
                  &bw_cell, &bw_cell_act, kBidirectional, 1, is_test);

    // Join the two directions along the feature axis.
    auto& cpu_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    paddle::operators::math::ConcatFunctor<platform::CPUDeviceContext, T>
        concat;
    concat(cpu_ctx, output_vec, static_cast<int>(2), output);
  }
};
// Carves views out of reserve_data (sliced along its first dimension),
// laid out as
//   [gates: gate_num * num_layers rows]
//   [cell states: num_layers rows]        (LSTM and GRU)
//   [cell activations: num_layers rows]   (LSTM only)
//   [hidden outputs: num_layers - 1 rows] (only when num_layers > 1)
// direction_num, time_step, batch_size and hidden_size are not read here.
template <typename TensorType>
void SplitReserveData(const framework::ExecutionContext& ctx,
                      TensorType* reserve_data, Tensor* gate_data,
                      Tensor* cell_data, Tensor* cell_act_data,
                      Tensor* hidden_data, int direction_num,
                      const int& time_step, const int& batch_size,
                      const int& hidden_size, const int& gate_num,
                      const int& num_layers) {
  const int gate_end = gate_num * num_layers;
  const int cell_end = (gate_num + 1) * num_layers;
  const int cell_act_end = (gate_num + 2) * num_layers;

  *gate_data = reserve_data->Slice(0, gate_end);
  int hidden_begin = gate_end;  // simple rnn: hidden rows follow the gates
  if (is_lstm(ctx)) {
    *cell_data = reserve_data->Slice(gate_end, cell_end);
    *cell_act_data = reserve_data->Slice(cell_end, cell_act_end);
    hidden_begin = cell_act_end;
  } else if (is_gru(ctx)) {
    *cell_data = reserve_data->Slice(gate_end, cell_end);
    hidden_begin = cell_end;
  }

  if (num_layers > 1) {
    const int hidden_end = hidden_begin + (num_layers - 1);
    *hidden_data = reserve_data->Slice(hidden_begin, hidden_end);
  }
}
// Regroups the flat parameter list per layer. The raw sequence is
//   [FWhi, FWhh, BWhi, BWhh] * num_layers + [FBhi, FBhh, BBhi, BBhh] * num_layers
// (backward entries only when bidirectional) and each layer's TensorList
// becomes
//   [FWhi, FWhh, FBhi, FBhh] (+ [BWhi, BWhh, BBhi, BBhh])
// The new tensors share memory with the raw ones; gate_num is not read.
template <typename TensorType>
void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
                            const int& num_layers, const int& gate_num,
                            const bool& is_bidirec,
                            std::vector<TensorList>* params_vec) {
  const int direction_num = is_bidirec ? 2 : 1;
  const int per_layer = 4 * direction_num;
  // biases start right after all the weights (half of all entries)
  const int bias_offset = (num_layers * per_layer) / 2;
  for (int layer = 0; layer < num_layers; ++layer) {
    TensorList tensor_list(per_layer);
    for (int slot = 0; slot < per_layer; ++slot) {
      const int direction = slot / 4;
      const int kind = slot % 4;  // 0: W_ih, 1: W_hh, 2: b_ih, 3: b_hh
      int source = layer * 2 * direction_num + direction * 2 + kind % 2;
      if (kind >= 2) {
        source += bias_offset;
      }
      tensor_list[slot].ShareDataWith(*raw_params_vec[source]);
    }
    params_vec->emplace_back(tensor_list);
  }
}
// Allocates reserve_data as a {rows, block_size} tensor and splits it into
// gate / cell / cell-activation / hidden views via SplitReserveData.
// Each row holds direction_num * time_step * batch_size * hidden_size
// elements; the row count is gate_num per layer, plus one cell-state row
// (GRU, LSTM) and one cell-activation row (LSTM) per layer, plus one hidden
// row for every layer except the last.
template <typename CellType, typename T>
void AllocateReserveData(const framework::ExecutionContext& ctx,
                         Tensor* reserve_data, Tensor* gate_data,
                         Tensor* cell_data, Tensor* cell_act_data,
                         Tensor* hidden_data, const Tensor* input,
                         bool is_bidirec, int num_layers, int gate_num,
                         int hidden_size) {
  const int direction_num = is_bidirec ? 2 : 1;
  const int time_step = input->dims()[0];
  const int batch_size = input->dims()[1];
  const int block_size = direction_num * time_step * batch_size * hidden_size;

  int rows_per_layer = gate_num;
  if (is_lstm(ctx)) {
    rows_per_layer += 2;
  } else if (is_gru(ctx)) {
    rows_per_layer += 1;
  }
  const int row_count = rows_per_layer * num_layers + (num_layers - 1);

  reserve_data->Resize({row_count, block_size});
  reserve_data->mutable_data<T>(ctx.GetPlace());
  SplitReserveData(ctx, reserve_data, gate_data, cell_data, cell_act_data,
                   hidden_data, direction_num, time_step, batch_size,
                   hidden_size, gate_num, num_layers);
}
// Forward pass of a (possibly multi-layer, possibly bidirectional) RNN on CPU.
//
// Layer i reads the previous layer's output and writes its own.
//   - Training: intermediate layer outputs live in reserve_data, only the
//     last layer writes into `output`, and dropout (when dropout_prob != 0)
//     is applied between layers.
//   - Inference: layers ping-pong between `output` and a temporary buffer.
// Parameters:
//   weight_list   - raw parameters, regrouped per layer by
//                   reset_parameter_vector
//   init_h/init_c - initial states; dims()[0] must equal
//                   num_layers * direction count (init_c checked for LSTM only)
//   last_h/last_c - final states, written per layer and direction
//   dropout_mask  - training only; generated once and reused between layers
//   reserve_data  - training only; storage for gates, cell states and
//                   intermediate layer outputs
template <typename CellType, template <typename, typename> class LayerT,
          template <typename, typename> class SingleLayerT,
          template <typename, typename> class BidirLayerT, typename T>
void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input,
             const std::vector<const Tensor*> weight_list, const Tensor* init_h,
             const Tensor* init_c, const Tensor* sequence_length,
             Tensor* last_h, Tensor* last_c, Tensor* output,
             Tensor* dropout_mask, const int& num_layers, const int& gate_num,
             const int& input_size, const int& hidden_size,
             const bool& is_bidirec, const std::string& cell_type,
             const float& dropout_prob, const bool& is_test, const int& seed,
             Tensor* reserve_data) {
  const int direction_num = is_bidirec ? 2 : 1;
  const auto& init_h_dims = init_h->dims();
  PADDLE_ENFORCE_EQ(init_h_dims[0], num_layers * direction_num,
                    platform::errors::InvalidArgument(
                        "The num_layers of in RNN layer must be the same as "
                        "first dim of init hidden, but received"
                        " num_layers:%d, dim:%d",
                        num_layers, init_h_dims[0]));
  if (is_lstm(ctx)) {
    const auto& init_c_dims = init_c->dims();
    // Report init_c's own leading dim (previously init_h's was printed).
    PADDLE_ENFORCE_EQ(init_c_dims[0], num_layers * direction_num,
                      platform::errors::InvalidArgument(
                          "The num_layers of in RNN layer must be the same as "
                          "first dim of cell state hidden, but received"
                          " num_layers:%d, dim:%d",
                          num_layers, init_c_dims[0]));
  }
  CellType cell;
  std::vector<TensorList> parameter_lists;
  parameter_lists.reserve(num_layers);
  reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists);

  Tensor gate_data, cell_data, cell_act_data, hidden_data;
  if (!is_test) {
    AllocateReserveData<CellType, T>(
        ctx, reserve_data, &gate_data, &cell_data, &cell_act_data, &hidden_data,
        input, is_bidirec, num_layers, gate_num, hidden_size);
    // one row per layer (hidden_data: one row per non-final layer)
    gate_data.Resize({num_layers, gate_data.numel() / num_layers});
    cell_data.Resize({num_layers, cell_data.numel() / num_layers});
    cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});
    if (num_layers > 1) {
      hidden_data.Resize(
          {num_layers - 1, hidden_data.numel() / (num_layers - 1)});
    }
  }

  Tensor* input_holder = nullptr;  // set before first use (layer 1)
  Tensor* output_holder = output;
  Tensor temp;
  bool has_allocate_mem = false;
  auto init_h_unbind = Unbind(*init_h);
  auto last_h_unbind = Unbind(*last_h);
  TensorList init_c_unbind, last_c_unbind;
  if (is_lstm(ctx)) {
    init_c_unbind = Unbind(*init_c);
    last_c_unbind = Unbind(*last_c);
  }
  Tensor curr_gate_data, curr_cell_data, curr_cell_act_data;
  Tensor curr_hidden_data, prev_hidden_data;
  bool has_dropout_reset = false;
  for (int i = 0; i < num_layers; i++) {
    if (!is_test) {
      if (cell_data.numel() > 0) /** for lstm, gru **/ {
        curr_cell_data = cell_data.Slice(i, i + 1);
      }
      if (cell_act_data.numel() > 0) /*for lstm*/ {
        curr_cell_act_data = cell_act_data.Slice(i, i + 1);
      }
      curr_gate_data = gate_data.Slice(i, i + 1);
      // non-final layers write into their reserved hidden slot
      output_holder = output;
      if (i < num_layers - 1 && num_layers > 1) {
        curr_hidden_data = hidden_data.Slice(i, i + 1);
        curr_hidden_data.Resize(output->dims());
        output_holder = &curr_hidden_data;
      }
    }
    if (i > 0) {
      if (!has_allocate_mem) {
        temp.Resize(output->dims());
        temp.mutable_data<T>(ctx.GetPlace());
        input_holder = &temp;
        has_allocate_mem = true;
      }
      if (!is_test) {
        prev_hidden_data = hidden_data.Slice(i - 1, i);
        input_holder->Resize(output->dims());
        if (dropout_prob != 0) {
          dropout_cpu_function_inplace<T>(ctx, &prev_hidden_data, input_holder,
                                          dropout_mask, dropout_prob, seed,
                                          is_test, &has_dropout_reset);
        } else {
          input_holder = &prev_hidden_data;
          input_holder->Resize(output->dims());
        }
      } else {
        // inference: the previous output becomes this layer's input
        SwapPoniter(&output_holder, &input_holder);
      }
    }
    const Tensor* input_temp_holder = input;
    if (i > 0) {
      input_temp_holder = input_holder;
    }
    LayerT<T, CellType>* layer;
    SingleLayerT<T, CellType> slayer(cell);
    BidirLayerT<T, CellType> blayer(cell);
    if (is_bidirec) {
      layer = &blayer;
    } else {
      layer = &slayer;
    }
    (*layer)(ctx, input_temp_holder, parameter_lists[i], init_h_unbind,
             init_c_unbind, sequence_length, last_h_unbind, last_c_unbind,
             output_holder, i, gate_num, &curr_gate_data, &curr_cell_data,
             &curr_cell_act_data, is_test);
  }
  // Only the inference ping-pong can leave the last layer's result outside
  // `output` (in `temp`, when num_layers is even). In training the last layer
  // always writes `output` directly, so no self-copy is issued.
  if (output_holder != output) {
    framework::TensorCopy(
        *output_holder, ctx.GetPlace(),
        ctx.template device_context<platform::CPUDeviceContext>(), output);
  }
}
template <typename DeviceContext, typename T>
class RNNCPUKernel : public framework::OpKernel<T> {
 public:
  // Forward CPU kernel of the rnn op.
  //
  // Reads Input, PreState, WeightList and the optional SequenceLength input,
  // allocates Out and State (and DropoutState when it is not initialized
  // yet), then dispatches to RnnFunc with the cell type selected by the
  // "mode" attribute:
  //   LSTM     -> LSTMCell, 4 gates, uses PreState[1] / State[1]
  //   RNN_RELU -> SimpleRNNCell with ReLU, 1 gate
  //   RNN_TANH -> SimpleRNNCell with tanh, 1 gate
  //   GRU      -> GRUCell, 3 gates
  // Any other mode raises InvalidArgument. Previously the kernel silently
  // returned with Out/State allocated but never computed.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("Input");
    auto pre_state = ctx.MultiInput<Tensor>("PreState");
    auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
    auto state = ctx.MultiOutput<Tensor>("State");
    auto* output = ctx.Output<Tensor>("Out");
    auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
    auto* reserve_data = ctx.Output<Tensor>("Reserve");
    const int& num_layers = ctx.Attr<int>("num_layers");
    const bool& is_bidirec = ctx.Attr<bool>("is_bidirec");
    const int& input_size = ctx.Attr<int>("input_size");
    const int& hidden_size = ctx.Attr<int>("hidden_size");
    const float& dropout_prob = ctx.Attr<float>("dropout_prob");
    const std::string& mode = ctx.Attr<std::string>("mode");
    const bool& is_test = ctx.Attr<bool>("is_test");
    const int& seed = ctx.Attr<int>("seed");

    bool has_seq_length = ctx.HasInput("SequenceLength");
    const Tensor* sequence_length = nullptr;
    if (has_seq_length) {
      sequence_length = ctx.Input<Tensor>("SequenceLength");
    }
    if (!dropout_mask->IsInitialized()) {
      // The dropout mask is allocated with the same dims as Out.
      dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
    }

    // init the output and allocate the memory
    output->mutable_data<T>(ctx.GetPlace());
    int gate_num = 4;
    state[0]->mutable_data<T>(ctx.GetPlace());
    if (is_lstm(ctx)) {
      state[1]->mutable_data<T>(ctx.GetPlace());
      RnnFunc<LSTMCell<T>, Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], pre_state[1], sequence_length,
          state[0], state[1], output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_rnn_relu(ctx)) {
      gate_num = 1;
      RnnFunc<
          SimpleRNNCell<T, ReluFunctor, math::detail::ActivationType::kReLU>,
          Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_rnn_tanh(ctx)) {
      gate_num = 1;
      RnnFunc<
          SimpleRNNCell<T, TanhFunctor, math::detail::ActivationType::kTanhV2>,
          Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_gru(ctx)) {
      gate_num = 3;
      RnnFunc<GRUCell<T>, Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else {
      // Fail loudly rather than leaving Out/State uninitialized.
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported mode of RNN op: %s. The mode must be one of LSTM, "
          "GRU, RNN_RELU or RNN_TANH.",
          mode));
    }
  }
};
// Resets the three "check_*" pointers (ig / fg / og) of an LSTM value struct
// to nullptr. Presumably this disables the peephole weights in the LSTM math
// routines — TODO confirm against math::LstmMetaValue.
template <typename T>
void create_lstm_value(math::LstmMetaValue<T>* lstm_value) {
  auto& value = *lstm_value;
  value.check_og = nullptr;
  value.check_fg = nullptr;
  value.check_ig = nullptr;
}
// Resets the three "check_*_grad" pointers (ig / fg / og) of an LSTM gradient
// struct to nullptr. This is the gradient-side counterpart of clearing the
// check_* value pointers; presumably it means no peephole gradients are
// produced — TODO confirm against math::LstmMetaGrad.
template <typename T>
void create_lstm_grad(math::LstmMetaGrad<T>* lstm_grad) {
  auto& grad = *lstm_grad;
  grad.check_og_grad = nullptr;
  grad.check_fg_grad = nullptr;
  grad.check_ig_grad = nullptr;
}
// Fills `dst` with the contents of `v` as a 1-D tensor of v.size() elements.
//
// The storage is allocated on context.GetPlace() and written through a raw
// host pointer. This therefore assumes a host-accessible (CPU) place, as for
// the CPU kernels in this file — TODO confirm before using it elsewhere.
template <typename T>
void create_tensor_by_list(const framework::ExecutionContext& context,
                           Tensor* dst, const std::vector<T>& v) {
  dst->Resize({static_cast<int64_t>(v.size())});
  // Fetch the typed buffer once instead of calling data<T>() per element.
  T* dst_data = dst->mutable_data<T>(context.GetPlace());
  std::copy(v.begin(), v.end(), dst_data);
}
template <typename T, typename GradCellType>
struct GradLayer {
  explicit GradLayer(const GradCellType& cell) : cell_(cell) {}
  virtual ~GradLayer() {}

  // Back-propagates one direction of layer `layer_idx` through all
  // `time_step` steps, walking from the last step to the first. The direction
  // is the forward one, or the reverse one when `is_reverse` is set.
  //
  // Each direction owns 4 consecutive entries in a layer's parameter list,
  // starting at current_reverse_idx * 4. Entry +0 is the input weight (used
  // in postprocess), +1 is the hidden weight given to the cell, +2 is the
  // input bias and +3 is the hidden bias.
  //
  // For every step:
  //   1. The output gradient is folded into the running hidden-state
  //      gradient. With sequence lengths this is masked via mask_preprocess.
  //   2. The grad cell produces the gate gradient and the gradient w.r.t.
  //      the previous hidden/cell state, and updates the hidden-weight
  //      gradient.
  //   3. The "last" and "pre" gradient buffers are swapped.
  // After the loop, postprocess() derives the input-weight, input and bias
  // gradients from the stacked gate gradients. The final hidden/cell
  // gradients are then copied to init_h_grad / init_c_grad when those exist.
  //
  // NOTE: when is_reverse is true, the gate, grad-gate, output and
  // output-grad lists are reversed in place.
  void run_rnn_grad_function(
      const framework::ExecutionContext& context,
      const platform::CPUDeviceContext& device_ctx, const Tensor* input,
      Tensor* input_grad, const Tensor* sequence_length,
      std::vector<Tensor>* init_h_unbind, std::vector<Tensor>* init_c_unbind,
      std::vector<Tensor>* init_h_grad_unbind,
      std::vector<Tensor>* init_c_grad_unbind, Tensor* layer_grad_gate_tensor,
      std::vector<Tensor>* layer_gate_tensor_unbind,
      std::vector<Tensor>* layer_grad_gate_tensor_unbind,
      std::vector<Tensor>* layer_state_tensor_unbind,
      std::vector<Tensor>* layer_act_state_tensor_unbind,
      std::vector<Tensor>* output_tensor_unbind,
      std::vector<Tensor>* output_grad_tensor_unbind,
      const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const std::vector<TensorList>& parameter_lists,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& time_step, const bool& has_sequence_length,
      const bool& is_bidirec, const bool& is_reverse) {
    const int& direction_num = is_bidirec ? 2 : 1;
    const int& current_reverse_idx = is_reverse ? 1 : 0;
    const int& current_layer_idx =
        direction_num * layer_idx + current_reverse_idx;
    // The state lists hold both directions back to back; the reverse
    // direction starts at offset time_step.
    int begin_idx = 0;
    if (is_reverse) {
      begin_idx = time_step;
    }
    Tensor mask_matrix;
    TensorList mask_tensor_list;
    int mask_min_length = time_step;
    if (has_sequence_length) {
      mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
      create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
                            &mask_min_length);
      mask_tensor_list = Unbind(mask_matrix);
    }
    // Placeholder handed to the cell when there is no sequence-length mask
    // (mask_tensor_list is empty in that case).
    Tensor empty_mask;

    // Work on copies of last_h / last_c grads because the buffers are
    // swapped with the "pre" buffers after every step.
    Tensor a, b;
    Tensor* dynamic_grad_last_h = &a;
    Tensor* dynamic_grad_last_c = &b;
    dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims());
    dynamic_grad_last_h->mutable_data<T>(context.GetPlace());
    framework::TensorCopy(last_h_grad_unbind[current_layer_idx],
                          context.GetPlace(), dynamic_grad_last_h);
    if (last_c_grad_unbind.size() > 0) {
      dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims());
      dynamic_grad_last_c->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(last_c_grad_unbind[current_layer_idx],
                            context.GetPlace(), dynamic_grad_last_c);
    } else {
      dynamic_grad_last_c = nullptr;
    }

    // The "pre" buffers share memory with init_h_grad / init_c_grad when
    // those outputs exist, so odd step counts need no final copy.
    Tensor c, d;
    Tensor* dynamic_grad_pre_h = &c;
    Tensor* dynamic_grad_pre_c = &d;
    math::SetConstant<platform::CPUDeviceContext, T> zero;
    if (init_h_grad_unbind->size() > 0) {
      dynamic_grad_pre_h->ShareDataWith(
          (*init_h_grad_unbind)[current_layer_idx]);
    } else {
      dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims());
      dynamic_grad_pre_h->mutable_data<T>(context.GetPlace());
      zero(device_ctx, dynamic_grad_pre_h, static_cast<T>(0.0));
    }
    if (init_c_grad_unbind->size() > 0) {
      dynamic_grad_pre_c->ShareDataWith(
          (*init_c_grad_unbind)[current_layer_idx]);
    } else {
      if (is_lstm(context) || is_gru(context)) {
        dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims());
        dynamic_grad_pre_c->mutable_data<T>(context.GetPlace());
        if (is_gru(context)) {
          // GRU has no cell state; the buffer is passed to the cell as its
          // grad_state argument (used as reset_output_grad by GRUGradCell).
          dynamic_grad_last_c = dynamic_grad_pre_c;
        }
      } else {
        dynamic_grad_pre_c = nullptr;
      }
    }
    if (is_reverse) {
      // must be reverse the input, output, input_grad, output_grad
      // the gate and grad_gate must be reverse
      // The mask list is not reversed here; create_mask_matrix receives
      // is_reverse, presumably so it already emits the mask in reversed
      // time order — TODO confirm.
      std::reverse(layer_gate_tensor_unbind->begin(),
                   layer_gate_tensor_unbind->end());
      std::reverse(layer_grad_gate_tensor_unbind->begin(),
                   layer_grad_gate_tensor_unbind->end());
      std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end());
      std::reverse(output_grad_tensor_unbind->begin(),
                   output_grad_tensor_unbind->end());
    }

    Tensor* weight_grad =
        &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]);
    weight_grad->mutable_data<T>(context.GetPlace());
    zero(device_ctx, weight_grad, static_cast<T>(0.0));

    Tensor* pre_hidden = nullptr;
    Tensor* pre_state = nullptr;
    Tensor* hidden = nullptr;
    if (is_gru(context)) {
      zero(device_ctx,
           &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
           static_cast<T>(0.0));
    }
    for (int i = time_step - 1; i >= 0; --i) {
      if (has_sequence_length) {
        this->mask_preprocess(context, &(*output_grad_tensor_unbind)[i],
                              dynamic_grad_last_h, dynamic_grad_last_c,
                              dynamic_grad_pre_h, dynamic_grad_pre_c,
                              mask_tensor_list[i]);
      } else {
        this->preprocess(context, &(*output_grad_tensor_unbind)[i],
                         dynamic_grad_last_h);
      }
      hidden = &(*output_tensor_unbind)[i];
      if (i == 0) {
        pre_hidden = &(*init_h_unbind)[current_layer_idx];
        if (init_c_unbind->size() > 0) {
          pre_state = &(*init_c_unbind)[current_layer_idx];
        }
      } else {
        pre_hidden = &(*output_tensor_unbind)[i - 1];
        if (layer_state_tensor_unbind->size() > 0) {
          pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1];
        }
      }
      // Some cells have no (activated) state and come with empty state
      // lists: the simple RNN cells have neither, and GRU has no activated
      // state. Indexing an empty std::vector is undefined behavior, so pass
      // nullptr instead. Those cells never read these arguments.
      Tensor* state = layer_state_tensor_unbind->empty()
                          ? nullptr
                          : &(*layer_state_tensor_unbind)[begin_idx + i];
      Tensor* act_state =
          layer_act_state_tensor_unbind->empty()
              ? nullptr
              : &(*layer_act_state_tensor_unbind)[begin_idx + i];
      // Likewise mask_tensor_list is empty without sequence lengths; the
      // cells only consult the mask when has_sequence_length is true.
      const Tensor& mask =
          has_sequence_length ? mask_tensor_list[i] : empty_mask;
      this->cell_(
          context, &(*layer_gate_tensor_unbind)[i], state, act_state, hidden,
          &(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]),
          pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c,
          &(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h,
          dynamic_grad_pre_c,
          &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
          mask, has_sequence_length);
      SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h);
      SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c);
    }
    // postproces for gradient for w_hi, X, bias_hi, bias_hh
    this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad,
                      parameter_lists[layer_idx],
                      &((*weight_list_grad)[layer_idx]), is_reverse);
    // After an even number of swaps the result lives in the private buffers,
    // not in the memory shared with init_h_grad / init_c_grad, so copy it.
    if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
      framework::TensorCopy(*dynamic_grad_last_h, context.GetPlace(),
                            &((*init_h_grad_unbind)[current_layer_idx]));
    }
    if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) {
      framework::TensorCopy(*dynamic_grad_last_c, context.GetPlace(),
                            &((*init_c_grad_unbind)[current_layer_idx]));
    }
  }

  // Default entry point; does nothing. SingleGradLayer and BidirGradLayer
  // declare their own operator() with a different signature.
  virtual void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, const TensorList& init_h_unbind,
      const TensorList& init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      const std::vector<TensorList>& weight_list_grad, const int& layer_idx,
      const int& gate_num) {}

  // Adds this step's output gradient into the running last-hidden gradient:
  //   grad_last_h += grad_output
  void preprocess(const framework::ExecutionContext& context,
                  const Tensor* grad_output, Tensor* grad_last_h) {
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto output_grad = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
    auto last_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
    // the output gradient contribute the gradient to last_h
    last_h_grad.device(place) = last_h_grad + output_grad;
  }

  // Masked variant of preprocess, where mask_tensor holds one value per
  // batch row:
  //   grad_last_h += grad_output * mask
  //   grad_pre_h   = (1 - mask) * grad_last_h   // padded rows pass through
  //   grad_last_h  = mask * grad_last_h
  // For LSTM the cell-state gradient is split the same way when both
  // grad_last_c and grad_pre_c are provided.
  void mask_preprocess(const framework::ExecutionContext& context,
                       const Tensor* grad_output, Tensor* grad_last_h,
                       Tensor* grad_last_c, Tensor* grad_pre_h,
                       Tensor* grad_pre_c, const Tensor& mask_tensor) {
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto mask = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
    auto mask_broadcast =
        mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));
    auto last_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
    auto pre_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_h, grad_pre_h->dims().size() - 1);
    auto output_grad = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
    last_h_grad.device(place) = last_h_grad + output_grad * mask_broadcast;
    pre_h_grad.device(place) = (1 - mask_broadcast) * last_h_grad;
    last_h_grad.device(place) = mask_broadcast * last_h_grad;
    if (grad_last_c && grad_pre_c && is_lstm(context)) {
      auto last_c_grad = framework::EigenMatrix<T>::Reshape(
          *grad_last_c, grad_last_c->dims().size() - 1);
      auto pre_c_grad = framework::EigenMatrix<T>::Reshape(
          *grad_pre_c, grad_pre_c->dims().size() - 1);
      pre_c_grad.device(place) = (1 - mask_broadcast) * last_c_grad;
      last_c_grad.device(place) = mask_broadcast * last_c_grad;
    }
  }

  // Turns the stacked gate gradients of one direction into parameter/input
  // gradients:
  //   grad_w_ih     = grad_gate^T * input            (overwrites, beta 0)
  //   input_grad   += grad_gate * w_ih               (accumulates, beta 1)
  //   grad_bias_ih  = column sums of grad_gate
  //   grad_bias_hh  = column sums of grad_gate       (not for GRU)
  // The reverse direction uses parameter entries starting at index 4.
  void postprocess(const framework::ExecutionContext& context,
                   const Tensor& grad_gate, const Tensor& input,
                   Tensor* input_grad, const TensorList& parameters,
                   TensorList* grad_parameters, const int& is_reverse) {
    // we get the grad_gate step by step, and need to bradocast the grad to the
    // grad_w_hi, grad_bias_hi, grad_bias_hh
    int begin_idx = 0;
    if (is_reverse) {
      begin_idx = 4;
    }
    auto& device_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(device_ctx);

    // calc the gradient for the w_hi
    auto mat_dim_out_grad =
        math::CreateMatrixDescriptor(grad_gate.dims(), 0, true);
    auto mat_dim_input = math::CreateMatrixDescriptor(input.dims(), 0, false);
    mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_;
    mat_dim_out_grad.batch_size_ = 0;
    mat_dim_input.height_ *= mat_dim_input.batch_size_;
    mat_dim_input.batch_size_ = 0;
    blas.MatMul(grad_gate, mat_dim_out_grad, input, mat_dim_input,
                static_cast<T>(1.0), &((*grad_parameters)[begin_idx + 0]),
                T(0));

    // calc the gradient for the X
    auto mat_dim_out_grad_new =
        math::CreateMatrixDescriptor(grad_gate.dims(), 0, false);
    mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_;
    mat_dim_out_grad_new.batch_size_ = 0;
    // Describe the same tensor that is multiplied below.
    auto mat_dim_parameter =
        math::CreateMatrixDescriptor(parameters[begin_idx + 0].dims(), 0, false);
    blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0],
                mat_dim_parameter, static_cast<T>(1.0), input_grad, T(1));

    // calc the gradient of Bias_hi, Bias_hh
    math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
    Tensor tmp_grad_gate;
    tmp_grad_gate.ShareDataWith(grad_gate);
    tmp_grad_gate.Resize(
        {grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]});
    col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2]));
    // Bias_hh
    if (!is_gru(context)) {
      col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3]));
    }
  }

  GradCellType cell_;
};
template <typename T, typename GradCellType>
struct SingleGradLayer : GradLayer<T, GradCellType> {
  explicit SingleGradLayer(const GradCellType& cell)
      : GradLayer<T, GradCellType>(cell) {}
  virtual ~SingleGradLayer() {}

  // Gradient of one unidirectional layer.
  //
  // Zeroes input_grad. It then builds per-time-step views of the layer
  // output, the output gradient, and this layer's slice of the saved gate /
  // state / activated-state buffers, where the state lists may be empty. A
  // fresh gate-gradient buffer shaped like the gate buffer is allocated.
  // Finally it runs the forward-direction backward pass through
  // run_rnn_grad_function.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, std::vector<Tensor>* init_h_unbind,
      std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& gate_num) {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(dev_ctx, input_grad, static_cast<T>(0.0));

    const bool& is_bidirec = context.Attr<bool>("is_bidirec");
    const int steps = input->dims()[0];
    const int batch = input->dims()[1];
    const int directions = is_bidirec ? 2 : 1;
    const int& hidden_size = context.Attr<int>("hidden_size");

    // Per-step views of Out and dOut: [time_step, batch, hidden] -> list.
    TensorList out_steps = Unbind(*output);
    TensorList out_grad_steps = Unbind(*output_grad);

    Tensor gate_buf = gate_tensor_unbind[layer_idx];
    gate_buf.Resize({steps * directions, batch, hidden_size * gate_num});
    TensorList gate_steps = Unbind(gate_buf);

    Tensor grad_gate_buf;
    grad_gate_buf.Resize(gate_buf.dims());
    grad_gate_buf.mutable_data<T>(context.GetPlace());
    TensorList grad_gate_steps = Unbind(grad_gate_buf);

    // Views this layer's entry of an optional per-layer list as per-step
    // tensors. `holder` keeps the reshaped tensor alive while the returned
    // slices are in use. An empty list yields an empty result.
    auto unbind_layer_slice = [&](const TensorList& per_layer,
                                  Tensor* holder) {
      TensorList pieces;
      if (!per_layer.empty()) {
        *holder = per_layer[layer_idx];
        holder->Resize({steps * directions, batch, hidden_size});
        pieces = Unbind(*holder);
      }
      return pieces;
    };
    Tensor state_buf;
    TensorList state_steps =
        unbind_layer_slice(state_tensor_unbind, &state_buf);
    Tensor act_state_buf;
    TensorList act_state_steps =
        unbind_layer_slice(act_state_tensor_unbind, &act_state_buf);

    const bool has_sequence_length = (sequence_length != nullptr);
    this->run_rnn_grad_function(
        context, dev_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &grad_gate_buf,
        &gate_steps, &grad_gate_steps, &state_steps, &act_state_steps,
        &out_steps, &out_grad_steps, last_h_grad_unbind, last_c_grad_unbind,
        parameter_lists, weight_list_grad, layer_idx, steps,
        has_sequence_length, is_bidirec, false);
  }
};
// Splits `output` into two equal halves along `axis`. The halves are written
// to (*output_vec)[0] and (*output_vec)[1], which must both be non-null; they
// are resized and allocated here.
//
// Previously the halves were always sized by halving dim 2 of a rank-3
// tensor, whatever `axis` was. The size of the split dimension is now derived
// from `axis` itself, which is identical for the existing axis == 2 callers.
// An odd size along `axis` cannot be split evenly and is rejected.
template <typename T>
void split_tensor_at_last_dim(const framework::ExecutionContext& context,
                              const platform::CPUDeviceContext& dev_ctx,
                              const Tensor* output,
                              std::vector<Tensor*>* output_vec,
                              const int& axis) {
  framework::DDim half_dims = output->dims();
  PADDLE_ENFORCE_EQ(
      half_dims[axis] % 2, 0,
      platform::errors::InvalidArgument(
          "The size of dimension %d of the tensor to be split into two "
          "halves must be even, but received %d.",
          axis, half_dims[axis]));
  half_dims[axis] /= 2;

  std::vector<const framework::Tensor*> shape_refer;
  shape_refer.reserve(2);
  for (int k = 0; k < 2; ++k) {
    (*output_vec)[k]->Resize(half_dims);
    (*output_vec)[k]->mutable_data<T>(context.GetPlace());
    shape_refer.emplace_back((*output_vec)[k]);
  }
  math::SplitFunctor<platform::CPUDeviceContext, T> functor;
  functor(dev_ctx, *output, shape_refer, axis, output_vec);
}
template <typename T, typename GradCellType>
struct BidirGradLayer : GradLayer<T, GradCellType> {
  explicit BidirGradLayer(const GradCellType& cell)
      : GradLayer<T, GradCellType>(cell) {}
  virtual ~BidirGradLayer() {}

  // Gradient of one bidirectional layer.
  //
  // The layer output (and its gradient) is the concatenation of the forward
  // and backward hidden states along dim 2, so both are split back into
  // halves. The gate and zero-initialised gate-gradient buffers hold
  // 2 * time_step steps: the forward steps first, then the backward ones.
  // run_rnn_grad_function runs once per direction. Both runs accumulate into
  // input_grad, which is zeroed first.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, std::vector<Tensor>* init_h_unbind,
      std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& gate_num) {
    const bool& is_bidirec = context.Attr<bool>("is_bidirec");
    const int steps = input->dims()[0];
    const int batch = input->dims()[1];
    const int directions = is_bidirec ? 2 : 1;
    const int& hidden_size = context.Attr<int>("hidden_size");

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(dev_ctx, input_grad, static_cast<T>(0.0));

    // Out = concat(forward_hidden, backward_hidden) on the last dim.
    Tensor fw_out;
    Tensor bw_out;
    std::vector<Tensor*> out_halves{&fw_out, &bw_out};
    split_tensor_at_last_dim<T>(context, dev_ctx, output, &out_halves, 2);
    TensorList fw_out_steps = Unbind(fw_out);
    TensorList bw_out_steps = Unbind(bw_out);

    // Same split for dOut.
    Tensor fw_out_grad;
    Tensor bw_out_grad;
    std::vector<Tensor*> out_grad_halves{&fw_out_grad, &bw_out_grad};
    split_tensor_at_last_dim<T>(context, dev_ctx, output_grad,
                                &out_grad_halves, 2);
    TensorList fw_out_grad_steps = Unbind(fw_out_grad);
    TensorList bw_out_grad_steps = Unbind(bw_out_grad);

    // Gate buffer: forward steps in [0, steps), backward in [steps, 2*steps).
    Tensor gate_buf = gate_tensor_unbind[layer_idx];
    gate_buf.Resize({steps * 2, batch, hidden_size * gate_num});
    Tensor fw_gate = gate_buf.Slice(0, steps);
    Tensor bw_gate = gate_buf.Slice(steps, 2 * steps);
    TensorList fw_gate_steps = Unbind(fw_gate);
    TensorList bw_gate_steps = Unbind(bw_gate);

    Tensor grad_gate_buf;
    grad_gate_buf.Resize(gate_buf.dims());
    grad_gate_buf.mutable_data<T>(context.GetPlace());
    set_zero(dev_ctx, &grad_gate_buf, static_cast<T>(0.0));
    Tensor fw_grad_gate = grad_gate_buf.Slice(0, steps);
    Tensor bw_grad_gate = grad_gate_buf.Slice(steps, 2 * steps);
    TensorList fw_grad_gate_steps = Unbind(fw_grad_gate);
    TensorList bw_grad_gate_steps = Unbind(bw_grad_gate);

    // Views this layer's entry of an optional per-layer list as per-step
    // tensors. `holder` keeps the reshaped tensor alive while the returned
    // slices are in use. An empty list yields an empty result.
    auto unbind_layer_slice = [&](const TensorList& per_layer,
                                  Tensor* holder) {
      TensorList pieces;
      if (!per_layer.empty()) {
        *holder = per_layer[layer_idx];
        holder->Resize({steps * directions, batch, hidden_size});
        pieces = Unbind(*holder);
      }
      return pieces;
    };
    Tensor state_buf;
    TensorList state_steps =
        unbind_layer_slice(state_tensor_unbind, &state_buf);
    Tensor act_state_buf;
    TensorList act_state_steps =
        unbind_layer_slice(act_state_tensor_unbind, &act_state_buf);

    const bool has_sequence_length = (sequence_length != nullptr);
    // Forward direction.
    this->run_rnn_grad_function(
        context, dev_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &fw_grad_gate,
        &fw_gate_steps, &fw_grad_gate_steps, &state_steps, &act_state_steps,
        &fw_out_steps, &fw_out_grad_steps, last_h_grad_unbind,
        last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
        steps, has_sequence_length, is_bidirec, false);
    // Backward (reverse) direction.
    this->run_rnn_grad_function(
        context, dev_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &bw_grad_gate,
        &bw_gate_steps, &bw_grad_gate_steps, &state_steps, &act_state_steps,
        &bw_out_steps, &bw_out_grad_steps, last_h_grad_unbind,
        last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx,
        steps, has_sequence_length, is_bidirec, true);
  }
};
// Makes `dst` a deep copy of `src`. `dst` gets the same dims and fresh
// storage on the kernel's place. The grad cells use it to save a gradient
// buffer before overwriting it. `src` is only read, so it is taken as const;
// callers passing a non-const Tensor* still compile unchanged.
template <typename T>
void backup_tensor(const framework::ExecutionContext& context, Tensor* dst,
                   const Tensor* src) {
  auto& device_ctx =
      context.template device_context<platform::CPUDeviceContext>();
  dst->Resize(src->dims());
  dst->mutable_data<T>(context.GetPlace());
  framework::TensorCopy(*src, device_ctx.GetPlace(), device_ctx, dst);
}
template <typename T>
struct GradCell {
  virtual ~GradCell() {}

  // Back-propagates a single time step. The base version is a no-op;
  // concrete grad cells override it.
  virtual void operator()(const framework::ExecutionContext& context,
                          Tensor* gate_tensor, Tensor* state_tensor,
                          Tensor* act_state_tensor, Tensor* hidden_tensor,
                          const Tensor* weight_hh, Tensor* pre_hidden,
                          Tensor* pre_state, Tensor* grad_hidden,
                          Tensor* grad_state, Tensor* grad_gate,
                          Tensor* grad_weight_hh, Tensor* grad_pre_hidden,
                          Tensor* grad_pre_state, Tensor* grad_bias_hh,
                          const Tensor& mask_tensor,
                          bool has_sequence_length) const {}

  // With a sequence-length mask (one value per batch row), blends the newly
  // computed previous-step gradients with their backed-up values:
  //   grad = (1 - mask) * grad_bak + grad * mask
  // so rows whose mask is 0 keep the value they had before this step.
  // grad_pre_state is optional; when it is nullptr only the hidden gradient
  // is blended. Without a mask this does nothing.
  void postprocess_pre_hidden_grad(const framework::ExecutionContext& context,
                                   Tensor* grad_pre_hidden,
                                   Tensor* grad_pre_hidden_bak,
                                   Tensor* grad_pre_state,
                                   Tensor* grad_pre_state_bak,
                                   const Tensor& mask_tensor,
                                   bool has_sequence_length) const {
    if (!has_sequence_length) {
      return;
    }
    auto& eigen_place =
        *context.template device_context<platform::CPUDeviceContext>()
             .eigen_device();
    auto mask_col = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
    auto mask_wide = mask_col.broadcast(
        Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));

    auto h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
    auto h_grad_bak = framework::EigenMatrix<T>::Reshape(
        *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
    h_grad.device(eigen_place) =
        (1 - mask_wide) * h_grad_bak + h_grad * mask_wide;

    if (grad_pre_state == nullptr) {
      return;
    }
    auto c_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_state, grad_pre_state->dims().size() - 1);
    auto c_grad_bak = framework::EigenMatrix<T>::Reshape(
        *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
    c_grad.device(eigen_place) =
        (1 - mask_wide) * c_grad_bak + c_grad * mask_wide;
  }

  // grad_pre_hidden = grad_gate * weight_hh (overwrites, beta = 0). The
  // leading batch dimension of grad_gate is folded into its rows. The result
  // is then masked via postprocess_pre_hidden_grad.
  virtual void update_pre_hidden_grad(
      const framework::ExecutionContext& context, Tensor* grad_gate,
      const Tensor* weight_hh, Tensor* grad_pre_hidden,
      Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state,
      Tensor* grad_pre_state_bak, const Tensor& mask_tensor,
      bool has_sequence_length) const {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);

    auto gate_desc = math::CreateMatrixDescriptor(grad_gate->dims(), 0, false);
    gate_desc.height_ *= gate_desc.batch_size_;
    gate_desc.batch_size_ = 0;
    auto weight_desc =
        math::CreateMatrixDescriptor(weight_hh->dims(), 0, false);
    blas.MatMul(*grad_gate, gate_desc, *weight_hh, weight_desc,
                static_cast<T>(1.0), grad_pre_hidden, static_cast<T>(0));

    postprocess_pre_hidden_grad(context, grad_pre_hidden, grad_pre_hidden_bak,
                                grad_pre_state, grad_pre_state_bak, mask_tensor,
                                has_sequence_length);
  }

  // grad_weight_hh += grad_gate^T * pre_hidden (accumulates, beta = 1). The
  // leading batch dimension of each operand is folded into its rows.
  virtual void update_weight_hh_grad(const framework::ExecutionContext& context,
                                     Tensor* grad_gate, Tensor* pre_hidden,
                                     Tensor* grad_weight_hh) const {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);

    auto gate_t_desc = math::CreateMatrixDescriptor(grad_gate->dims(), 0, true);
    gate_t_desc.height_ *= gate_t_desc.batch_size_;
    gate_t_desc.batch_size_ = 0;
    auto hidden_desc =
        math::CreateMatrixDescriptor(pre_hidden->dims(), 0, false);
    hidden_desc.height_ *= hidden_desc.batch_size_;
    hidden_desc.batch_size_ = 0;
    blas.MatMul(*grad_gate, gate_t_desc, *pre_hidden, hidden_desc,
                static_cast<T>(1.0), grad_weight_hh, static_cast<T>(1.0));
  }
};
template <typename T, template <typename> class EigenActivationBackwardFunctor>
struct SimpleRNNGradCell : GradCell<T> {
  // One backward step of a simple RNN cell h = act(z).
  //
  // dz is computed from dh by EigenActivationBackwardFunctor. dz is then
  // propagated to the previous hidden state (masked when sequence lengths
  // are given) and accumulated into the hidden-weight gradient. State
  // tensors and grad_bias_hh are not touched.
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_bias_hh, const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();

    // Keep the incoming previous-hidden gradient so padded rows can be
    // restored after the update.
    Tensor saved_pre_hidden_grad;
    if (has_sequence_length) {
      backup_tensor<T>(context, &saved_pre_hidden_grad, grad_pre_hidden);
    }

    auto dz = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad"));
    auto dh = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad"));
    auto h = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value"));
    // z is not needed mathematically here, but the functor signature takes it.
    auto z = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value"));

    EigenActivationBackwardFunctor<T> act_backward;
    act_backward(*dev_ctx.eigen_device(), z, h, dh, dz);

    this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden,
                                 &saved_pre_hidden_grad, nullptr, nullptr,
                                 mask_tensor, has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
  }
};
template <typename T>
struct GRUGradCell : GradCell<T> {
  // One backward step of a GRU cell, delegated to
  // math::GRUUnitGradFunctorV2 with sigmoid_v2 gate and tanh_v2 node
  // activations.
  //
  // weight_hh and grad_weight_hh are treated as two packed parts: the first
  // 2 * frame_size * frame_size elements are the gate weights, and the rest
  // is the state weight. state_tensor / grad_state hold the reset output and
  // its gradient.
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_bias_hh, const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const size_t frame_size = pre_hidden->dims()[2];
    const size_t batch_size = pre_hidden->dims()[1];
    // Element offset of the state weight inside the packed hidden weight.
    const size_t state_weight_offset = 2 * frame_size * frame_size;

    // Keep the incoming previous-hidden gradient so padded rows can be
    // restored after the update.
    Tensor saved_pre_hidden_grad;
    if (has_sequence_length) {
      backup_tensor<T>(context, &saved_pre_hidden_grad, grad_pre_hidden);
    }

    // Start dh_{t-1} from zero before the functor writes it (presumably the
    // functor accumulates into prev_out_grad — TODO confirm).
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(dev_ctx, grad_pre_hidden, static_cast<T>(0.0));

    math::GRUMetaValue<T> value;
    math::GRUMetaGrad<T> grad;
    value.gate_value = gate_tensor->data<T>();
    value.prev_out_value = pre_hidden->data<T>();
    value.reset_output_value = state_tensor->data<T>();
    value.state_weight = weight_hh->data<T>() + state_weight_offset;
    value.gate_weight = weight_hh->data<T>();

    grad.gate_grad = grad_gate->data<T>();
    grad.reset_output_grad = grad_state->data<T>();
    grad.prev_out_grad = grad_pre_hidden->data<T>();
    grad.output_grad = grad_hidden->data<T>();
    grad.gate_weight_grad = grad_weight_hh->data<T>();
    grad.state_weight_grad = grad_weight_hh->data<T>() + state_weight_offset;
    grad.bias_hh_grad = grad_bias_hh->data<T>();

    auto gate_activation = math::detail::GetActivationType("sigmoid_v2");
    auto node_activation = math::detail::GetActivationType("tanh_v2");
    math::GRUUnitGradFunctorV2<platform::CPUDeviceContext, T>::compute(
        dev_ctx, value, grad, frame_size, batch_size, node_activation,
        gate_activation);

    this->postprocess_pre_hidden_grad(context, grad_pre_hidden,
                                      &saved_pre_hidden_grad, nullptr, nullptr,
                                      mask_tensor, has_sequence_length);
  }
};
template <typename T>
struct LSTMGradCell : GradCell<T> {
  // Backward pass of one LSTM time step on CPU.
  //
  // 1. Fills LstmMetaValue / LstmMetaGrad views and runs
  //    LstmUnitGradFunctor with:
  //    - sigmoid_v2 for the gates and tanh_v2 for both state and candidate;
  //    - cell_clip = 0 (presumably meaning "no clipping"; confirm in
  //      lstm_compute.h);
  //    - a trailing `false` flag.
  //    This yields grad_gate and grad_pre_state.
  // 2. update_pre_hidden_grad propagates grad_gate through weight_hh into
  //    grad_pre_hidden. It also gets the saved backups and mask_tensor when
  //    has_sequence_length is set.
  // 3. update_weight_hh_grad updates grad_weight_hh from grad_gate and
  //    pre_hidden.
  //
  // batch_size / frame_size come from dims()[1] / dims()[2] of state_tensor.
  // hidden_tensor and grad_bias_hh belong to the shared GradCell interface
  // and are not used here.
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_bias_hh, const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const auto& state_dims = state_tensor->dims();
    size_t batch_size = state_dims[1];
    size_t frame_size = state_dims[2];

    // Snapshot the incoming gradients so the masked update can use them.
    Tensor grad_pre_hidden_bak;
    Tensor grad_pre_state_bak;
    if (has_sequence_length) {
      backup_tensor<T>(context, &grad_pre_hidden_bak, grad_pre_hidden);
      backup_tensor<T>(context, &grad_pre_state_bak, grad_pre_state);
    }

    // Forward values saved by the forward pass.
    math::LstmMetaValue<T> value;
    create_lstm_value(&value);
    value.gate_value = gate_tensor->data<T>();
    value.state_value = state_tensor->data<T>();
    value.state_active_value = act_state_tensor->data<T>();
    value.prev_state_value = pre_state->data<T>();
    value.output_value = nullptr;

    // Gradient buffers written or read by the unit functor.
    math::LstmMetaGrad<T> grad;
    create_lstm_grad(&grad);
    grad.gate_grad = grad_gate->data<T>();
    grad.state_grad = grad_state->data<T>();
    grad.output_grad = grad_hidden->data<T>();
    grad.prev_state_grad = grad_pre_state->data<T>();
    grad.state_active_grad = nullptr;

    auto sigmoid_act = math::detail::GetActivationType("sigmoid_v2");
    auto tanh_act = math::detail::GetActivationType("tanh_v2");
    T clip_value = static_cast<T>(0.0);
    math::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
        dev_ctx, value, grad, frame_size, batch_size, clip_value, sigmoid_act,
        tanh_act, tanh_act, false);

    this->update_pre_hidden_grad(context, grad_gate, weight_hh,
                                 grad_pre_hidden, &grad_pre_hidden_bak,
                                 grad_pre_state, &grad_pre_state_bak,
                                 mask_tensor, has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden,
                                grad_weight_hh);
  }
};
// Shared CPU backward driver for the RNN op (LSTM, GRU and simple RNN).
//
// Reads the forward inputs/outputs, the reserve buffer written by the
// forward pass and the incoming gradients from `context`. It then walks the
// layers from the last to the first. Each layer runs a
// BidirGradLayerT<T, GradCellType> when "is_bidirec" is set, otherwise a
// SingleGradLayerT<T, GradCellType>.
//
// Gradients w.r.t. a layer's input are written into layer_input_grad_holder.
// After each layer (except the first) that buffer gets the dropout backward
// applied and is swapped in as the next lower layer's output gradient. The
// first layer always writes into input_grad; a local buffer is used when the
// caller did not request "Input@GRAD".
//
// gate_num is the number of gates per cell as chosen by the kernel (4 LSTM,
// 3 GRU, 1 simple RNN). A value of 0 is treated as 1 when reshaping the
// gate buffer.
template <typename GradCellType,
          template <typename, typename> class SingleGradLayerT,
          template <typename, typename> class BidirGradLayerT, typename T>
void RnnGradFunc(const framework::ExecutionContext& context,
                 const int& gate_num) {
  // ---- forward inputs/outputs and incoming gradients ----
  auto* input = context.Input<Tensor>("Input");
  auto weight_list = context.MultiInput<Tensor>("WeightList");
  auto pre_state = context.MultiInput<Tensor>("PreState");
  const Tensor* init_h = pre_state[0];
  const Tensor* init_c = nullptr;
  if (is_lstm(context)) {
    // Only LSTM carries a second (cell) state.
    init_c = pre_state[1];
  }
  auto* reserve_state = context.Input<Tensor>("Reserve");
  auto* dropout_state = context.Input<Tensor>("DropoutState");
  auto* output = context.Input<Tensor>("Out");
  auto* output_grad = context.Input<Tensor>(framework::GradVarName("Out"));
  auto state_grad = context.MultiInput<Tensor>(framework::GradVarName("State"));
  const Tensor* last_h_grad = state_grad[0];
  const Tensor* last_c_grad = nullptr;
  if (is_lstm(context)) {
    last_c_grad = state_grad[1];
  }
  bool has_seq_length = context.HasInput("SequenceLength");
  const Tensor* sequence_length = nullptr;
  if (has_seq_length) {
    sequence_length = context.Input<Tensor>("SequenceLength");
  }

  // ---- gradient outputs ----
  auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
  auto weight_grad_list = context.MultiOutput<framework::Tensor>(
      framework::GradVarName("WeightList"));
  auto pre_state_grad =
      context.MultiOutput<Tensor>(framework::GradVarName("PreState"));
  Tensor* init_h_grad = nullptr;
  Tensor* init_c_grad = nullptr;
  if (pre_state_grad.size() > 0) {  // initial states need a gradient
    init_h_grad = pre_state_grad[0];
    if (is_lstm(context)) {
      init_c_grad = pre_state_grad[1];
    }
  }

  // ---- attributes and sizes ----
  const int num_layers = context.Attr<int>("num_layers");
  const bool is_bidirec = context.Attr<bool>("is_bidirec");
  const float dropout_prob = context.Attr<float>("dropout_prob");
  const bool is_test = context.Attr<bool>("is_test");
  const int time_step = input->dims()[0];
  const int batch_size = input->dims()[1];
  const int hidden_size = context.Attr<int>("hidden_size");
  const int direction_num = is_bidirec ? 2 : 1;

  // ---- allocate the gradient outputs ----
  // The layer loop below always writes the first layer's input gradient into
  // input_grad, so fall back to a local buffer if it was not requested.
  Tensor input_grad_value;
  if (!input_grad) {
    input_grad = &input_grad_value;
  }
  input_grad->mutable_data<T>(input->dims(), context.GetPlace());
  if (init_h_grad) {
    init_h_grad->mutable_data<T>(init_h->dims(), context.GetPlace());
  }
  if (init_c_grad) {
    init_c_grad->mutable_data<T>(init_c->dims(), context.GetPlace());
  }

  // Regroup the flat weight (and weight-gradient) lists per layer.
  std::vector<TensorList> parameter_lists;
  parameter_lists.reserve(num_layers);
  reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists);
  for (auto* weight_grad : weight_grad_list) {
    weight_grad->mutable_data<T>(context.GetPlace());
  }
  std::vector<TensorList> parameter_lists_grad;
  parameter_lists_grad.reserve(num_layers);
  reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists_grad);

  // ---- split the forward reserve buffer into its parts ----
  Tensor gate_tensor;
  Tensor state_tensor;
  Tensor act_state_tensor;
  Tensor hidden_tensor;
  SplitReserveData(context, reserve_state, &gate_tensor, &state_tensor,
                   &act_state_tensor, &hidden_tensor, direction_num, time_step,
                   batch_size, hidden_size, gate_num, num_layers);
  int gate_num_tmp = gate_num;
  if (gate_num == 0) {
    gate_num_tmp = 1;
  }
  gate_tensor.Resize({num_layers, time_step * direction_num, batch_size,
                      hidden_size * gate_num_tmp});
  if (state_tensor.numel() > 0) {
    state_tensor.Resize(
        {num_layers, time_step * direction_num, batch_size, hidden_size});
  }
  if (act_state_tensor.numel() > 0) {
    act_state_tensor.Resize(
        {num_layers, time_step * direction_num, batch_size, hidden_size});
  }
  if (num_layers > 1) {
    hidden_tensor.Resize(
        {num_layers - 1, time_step, batch_size, hidden_size * direction_num});
  }

  // ---- unbind along the layer dimension ----
  auto last_h_grad_unbind = Unbind(*last_h_grad);
  auto gate_tensor_unbind = Unbind(gate_tensor);
  TensorList last_c_grad_unbind;
  if (last_c_grad) {
    last_c_grad_unbind = Unbind(*last_c_grad);
  }
  TensorList init_h_unbind, init_c_unbind;
  TensorList init_h_grad_unbind, init_c_grad_unbind;
  TensorList state_tensor_unbind, act_state_tensor_unbind;
  TensorList hidden_tensor_unbind;
  init_h_unbind = Unbind(*init_h);
  if (init_c) {
    init_c_unbind = Unbind(*init_c);
  }
  if (init_h_grad != nullptr) {
    init_h_grad_unbind = Unbind(*init_h_grad);
  }
  if (init_c_grad != nullptr) {
    init_c_grad_unbind = Unbind(*init_c_grad);
  }
  if (state_tensor.numel() > 0) {
    state_tensor_unbind = Unbind(state_tensor);
  }
  if (act_state_tensor.numel() > 0) {
    act_state_tensor_unbind = Unbind(act_state_tensor);
  }
  if (num_layers > 1) {
    hidden_tensor_unbind = Unbind(hidden_tensor);
  }
  // Drop the leading (size-1) layer dim of each saved hidden output.
  for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) {
    hidden_tensor_unbind[i].Resize(
        framework::slice_ddim(hidden_tensor_unbind[i].dims(), 1,
                              hidden_tensor_unbind[i].dims().size()));
  }
  // The last layer's output is the op output itself.
  hidden_tensor_unbind.emplace_back();
  hidden_tensor_unbind[num_layers - 1].ShareDataWith(*output);

  // ---- layer-by-layer backward, last layer first ----
  GradCellType cell;
  Tensor layer_input;
  Tensor layer_output;
  Tensor* layer_input_grad_holder = nullptr;
  Tensor tmp_out;
  tmp_out.ShareDataWith(*output_grad);
  Tensor* layer_output_grad_holder = &tmp_out;
  Tensor input_grad_temp;
  Tensor output_grad_temp;
  bool has_allocate_mem = false;
  for (int i = num_layers - 1; i >= 0; --i) {
    // Rebuild this layer's input from data saved in the forward pass: the
    // (dropped-out) output of the layer below, or the op input for layer 0.
    if (i > 0) {
      if (layer_input.numel() == 0) {
        layer_input.Resize(hidden_tensor_unbind[i - 1].dims());
        layer_input.mutable_data<T>(context.GetPlace());
      }
      dropout_helper<T>(context, &hidden_tensor_unbind[i - 1], &layer_input,
                        dropout_state, dropout_prob);
    } else {
      layer_input.ShareDataWith(*input);
    }
    layer_output.ShareDataWith(hidden_tensor_unbind[i]);
    if (num_layers == 1) {
      layer_input_grad_holder = input_grad;
    } else if (i == num_layers - 1) {
      input_grad_temp.Resize(layer_input.dims());
      input_grad_temp.mutable_data<T>(context.GetPlace());
      layer_input_grad_holder = &input_grad_temp;
    }
    if (is_bidirec) {
      BidirGradLayerT<T, GradCellType> layer(cell);
      layer(context, &layer_input, &layer_output, &init_h_unbind,
            &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
            gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
            layer_output_grad_holder, parameter_lists, sequence_length,
            layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
            &parameter_lists_grad, i, gate_num_tmp);
    } else {
      SingleGradLayerT<T, GradCellType> layer(cell);
      layer(context, &layer_input, &layer_output, &init_h_unbind,
            &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
            gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
            layer_output_grad_holder, parameter_lists, sequence_length,
            layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
            &parameter_lists_grad, i, gate_num_tmp);
    }
    if (i == 0) {
      // input_grad now holds the gradient w.r.t. the op input. There is no
      // lower layer, so stop instead of allocating a rotation buffer that
      // would never be read.
      break;
    }
    // Dropout backward for this layer's input gradient; the dropout mask was
    // saved in dropout_state by the forward pass.
    if ((!is_test) && (dropout_prob != 0)) {
      dropout_cpu_grad_function_inplace<T>(context, layer_input_grad_holder,
                                           dropout_state, dropout_prob);
    }
    // Choose where the next (lower) layer writes its input gradient: layer 0
    // writes into input_grad; intermediate layers ping-pong between two
    // scratch buffers.
    if (i == 1) {
      layer_output_grad_holder = input_grad;
    } else if (!has_allocate_mem) {
      output_grad_temp.Resize(layer_input_grad_holder->dims());
      output_grad_temp.mutable_data<T>(context.GetPlace());
      layer_output_grad_holder = &output_grad_temp;
      has_allocate_mem = true;
    }
    // This layer's input gradient becomes the lower layer's output gradient.
    SwapPoniter(&layer_input_grad_holder, &layer_output_grad_holder);
  }
}
template <typename DeviceContext, typename T>
class RNNCPUGradKernel : public framework::OpKernel<T> {
 public:
  // Dispatches the CPU backward pass on the "mode" attribute:
  //   LSTM     -> LSTMGradCell, 4 gates
  //   GRU      -> GRUGradCell, 3 gates
  //   RNN_RELU -> SimpleRNNGradCell<ReluGradFunctor>, 1 gate
  //   RNN_TANH -> SimpleRNNGradCell<TanhGradFunctor>, 1 gate
  // Any other mode raises InvalidArgument. Otherwise the kernel would return
  // without computing any of its gradient outputs.
  void Compute(const framework::ExecutionContext& ctx) const override {
    if (is_lstm(ctx)) {
      RnnGradFunc<LSTMGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
          ctx, /*gate_num=*/4);
    } else if (is_gru(ctx)) {
      RnnGradFunc<GRUGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
          ctx, /*gate_num=*/3);
    } else if (is_rnn_relu(ctx)) {
      RnnGradFunc<SimpleRNNGradCell<T, ReluGradFunctor>, SingleGradLayer,
                  BidirGradLayer, T>(ctx, /*gate_num=*/1);
    } else if (is_rnn_tanh(ctx)) {
      RnnGradFunc<SimpleRNNGradCell<T, TanhGradFunctor>, SingleGradLayer,
                  BidirGradLayer, T>(ctx, /*gate_num=*/1);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported mode %s for rnn grad op, expected one of LSTM, GRU, "
          "RNN_RELU or RNN_TANH.",
          ctx.Attr<std::string>("mode")));
    }
  }
};
} // namespace operators
} // namespace paddle