@@ -36,6 +36,9 @@ class LSTMKernel : public framework::OpKernel<T> {
     auto* weight = ctx.Input<Tensor>("Weight");
     auto* bias = ctx.Input<Tensor>("Bias");
 
+    auto* hidden_t0 = ctx.Input<Tensor>("H0");
+    auto* cell_t0 = ctx.Input<Tensor>("C0");
+
     auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
     batch_gate->mutable_data<T>(ctx.GetPlace());
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
@@ -43,12 +46,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     auto* cell_out = ctx.Output<LoDTensor>("Cell");
     cell_out->mutable_data<T>(ctx.GetPlace());
 
-    // Now the function ShareLoD in InferShape is not implemented.
-    // So copy LoD here.
-    ctx.ShareLoD("Input", "Hidden");
-    ctx.ShareLoD("Input", "Cell");
-
-    bool is_reverse = ctx.Attr<bool>("isReverse");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
     math::LoDTensor2BatchFunctor<Place, T> to_batch;
     auto& device_ctx = ctx.device_context();
     to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
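
Context for the `to_batch` call above (annotation, not part of the patch): `math::LoDTensor2BatchFunctor` reorders the LoD-packed sequences so that step k of every sequence still active at step k forms one contiguous block, letting each recurrence step below run as a single GEMM over the whole mini-batch. A minimal CPU sketch of the index computation, assuming `lod` is the usual offset vector (sequence i spans rows `lod[i]` to `lod[i+1]`); the helper name `ToBatchIndex` is invented for illustration:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Map each time step to the input rows that participate in it, longest
// sequences first, so every step's batch is a prefix of the sorted order.
std::vector<std::vector<std::size_t>> ToBatchIndex(
    const std::vector<std::size_t>& lod) {
  const std::size_t num_seq = lod.size() - 1;
  std::vector<std::size_t> order(num_seq);
  for (std::size_t i = 0; i < num_seq; ++i) order[i] = i;
  std::sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
    return lod[a + 1] - lod[a] > lod[b + 1] - lod[b];
  });
  const std::size_t max_len = lod[order[0] + 1] - lod[order[0]];
  std::vector<std::vector<std::size_t>> batches(max_len);
  for (std::size_t step = 0; step < max_len; ++step) {
    for (std::size_t seq : order) {
      if (step >= lod[seq + 1] - lod[seq]) break;  // the rest are shorter
      batches[step].push_back(lod[seq] + step);    // row index in the input
    }
  }
  return batches;
}
```

The sorted sequence order is also what the hunks below read back as `batch_gate->lod()[2]`, which is why H0 and C0 have to be row-shuffled into the same order before they can seed the recurrence.
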
@@ -84,6 +82,13 @@ class LSTMKernel : public framework::OpKernel<T> {
       lstm_value.checkOg = nullptr;
     }
     lstm_value.prevStateValue = nullptr;
+    Tensor ordered_c0;
+    if (cell_t0) {
+      math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+      const size_t* order = batch_gate->lod()[2].data();
+      row_shuffle(device_ctx, *cell_t0, order, ordered_c0, true);
+      lstm_value.prevStateValue = ordered_c0.data<T>();
+    }
 
     // Use the local variable as here.
     LoDTensor batch_hidden, batch_cell;
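
The `row_shuffle` call just added permutes the rows of C0 into that sorted sequence order. A minimal sketch of the gather it performs when the last argument is `true`; the free-function signature is invented (Paddle's `CopyMatrixRowsFunctor` operates on Tensors and also supports the inverse scatter used later for the gradients):

```cpp
#include <algorithm>
#include <cstddef>

// dst row i <- src row order[i]. Reversing the roles of the index
// (dst[order[i]] <- src[i]) gives the scatter used at the end of the
// backward kernel to write h0_g / c0_g back in the caller's order.
template <typename T>
void GatherRows(const T* src, const std::size_t* order, std::size_t num_rows,
                std::size_t width, T* dst) {
  for (std::size_t i = 0; i < num_rows; ++i) {
    const T* src_row = src + order[i] * width;
    std::copy(src_row, src_row + width, dst + i * width);
  }
}
```
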
@@ -94,9 +99,9 @@ class LSTMKernel : public framework::OpKernel<T> {
 
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
-    auto gate_act = ctx.Attr<std::string>("gateActivation");
-    auto cell_act = ctx.Attr<std::string>("cellActivation");
-    auto cand_act = ctx.Attr<std::string>("candidateActivation");
+    auto gate_act = ctx.Attr<std::string>("gate_activation");
+    auto cell_act = ctx.Attr<std::string>("cell_activation");
+    auto cand_act = ctx.Attr<std::string>("candidate_activation");
 
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
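
The renamed attributes are plain strings naming the gate, cell, and candidate activations; they are resolved inside the math functors, not in this kernel. As an illustration only, a string-to-activation dispatch could look like this (the table and function are a sketch, not Paddle's resolver):

```cpp
#include <cmath>
#include <functional>
#include <map>
#include <string>

std::function<float(float)> GetActivation(const std::string& name) {
  static const std::map<std::string, std::function<float(float)>> table = {
      {"sigmoid", [](float x) { return 1.f / (1.f + std::exp(-x)); }},
      {"tanh", [](float x) { return std::tanh(x); }},
      {"relu", [](float x) { return x > 0.f ? x : 0.f; }},
      {"identity", [](float x) { return x; }}};
  return table.at(name);  // throws std::out_of_range for unknown names
}
```
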
@@ -109,15 +114,22 @@ class LSTMKernel : public framework::OpKernel<T> {
       int cur_batch_size = bend - bstart;
 
-      if (n != 0) {
+      if (n > 0) {
         int pre_h_start = static_cast<int>(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
         math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
                                static_cast<T>(1.0), &gate_t,
                                static_cast<T>(1.0));
+      } else if (hidden_t0) {
+        math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+        Tensor ordered_h0;
+        const size_t* order = batch_gate->lod()[2].data();
+        row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
+        math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
+                               static_cast<T>(1.0), &gate_t,
+                               static_cast<T>(1.0));
       }
-      // else if : FIXME support the initial hidden and cell
 
       lstm_value.gateValue = gate_t.data<T>();
       lstm_value.outputValue = out_t.data<T>();
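
Both `matmul` calls in this hunk pass beta = 1.0, so the recurrent projection is accumulated into `gate_t`, which already holds the input projection computed earlier in the kernel (not shown in this diff); a single [batch, frame_size] x [frame_size, 4 * frame_size] product produces the pre-activations of all four gates at once. A plain-loop reference for the accumulate semantics (illustrative, not the library call):

```cpp
// gate += prev_hidden * weight, with weight shaped
// [frame_size, 4 * frame_size]: one GEMM covers all four gates.
void AccumulateRecurrent(const float* prev_hidden, const float* weight,
                         float* gate, int batch, int frame_size) {
  const int out_width = 4 * frame_size;
  for (int b = 0; b < batch; ++b)
    for (int j = 0; j < out_width; ++j)
      for (int k = 0; k < frame_size; ++k)
        gate[b * out_width + j] +=
            prev_hidden[b * frame_size + k] * weight[k * out_width + j];
}
```
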
@@ -160,6 +172,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
     auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
 
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
+
+    auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
+    auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));
+
     auto& device_ctx = ctx.device_context();
     math::SetConstant<Place, T> zero;
     if (weight_g) {
@@ -167,6 +185,14 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       zero(device_ctx, weight_g, static_cast<T>(0.0));
     }
 
+    Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
+    math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (c0) {
+      ordered_c0.mutable_data<T>(c0->dims(), ctx.GetPlace());
+      row_shuffle(device_ctx, *c0, order, ordered_c0, true);
+    }
+
     auto in_dims = input->dims();
     auto out_dims = hidden_g->dims();
     int frame_size = static_cast<int>(in_dims[1] / 4);
@@ -226,9 +252,9 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
     batch_gate_g.set_lod(batch_gate->lod());
 
-    auto gate_act = ctx.Attr<std::string>("gateActivation");
-    auto cell_act = ctx.Attr<std::string>("cellActivation");
-    auto cand_act = ctx.Attr<std::string>("candidateActivation");
+    auto gate_act = ctx.Attr<std::string>("gate_activation");
+    auto cell_act = ctx.Attr<std::string>("cell_activation");
+    auto cand_act = ctx.Attr<std::string>("candidate_activation");
 
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
@@ -250,15 +276,24 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       lstm_grad.gateGrad = gate_g.data<T>();
       lstm_grad.outputGrad = out_g.data<T>();
 
-      if (n) {
+      if (n > 0) {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
         Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
         lstm_value.prevStateValue = cell_pre.data<T>();
         lstm_grad.prevStateGrad = cell_pre_g.data<T>();
       } else {
-        lstm_value.prevStateValue = nullptr;
-        lstm_grad.prevStateGrad = nullptr;
+        if (c0) {
+          lstm_value.prevStateValue = ordered_c0.data<T>();
+        } else {
+          lstm_value.prevStateValue = nullptr;
+        }
+        if (c0 && c0_g) {
+          ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
+          lstm_grad.prevStateGrad = ordered_c0_g.data<T>();
+        } else {
+          lstm_grad.prevStateGrad = nullptr;
+        }
       }
 
       int cur_batch_size = bend - bstart;
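
The backward loop mirrors the forward wiring: batches are visited in reverse, and at n == 0 the previous cell state is the re-ordered C0 (when given), with its gradient staged in `ordered_c0_g` and only scattered back into `c0_g` after the loop. A schematic of the selection, with invented names:

```cpp
// Returns the buffer playing the role of C_{t-1} for batch n; nullptr
// means "no previous state", matching lstm_value.prevStateValue above.
const float* PrevCellState(int n, const float* batch_cell_prev_slice,
                           const float* ordered_c0_or_null) {
  return n > 0 ? batch_cell_prev_slice : ordered_c0_or_null;
}
```
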
@@ -266,7 +301,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
           device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
           gate_act, cell_act, cand_act);
 
-      if (n != 0) {
+      if (n > 0) {
         int pre_h_start = static_cast<int>(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
@@ -280,6 +315,20 @@ class LSTMGradKernel : public framework::OpKernel<T> {
                                  static_cast<T>(1.0), weight_g,
                                  static_cast<T>(1.0));
         }
+      } else {
+        if (h0 && weight_g) {
+          ordered_h0.mutable_data<T>(h0->dims(), ctx.GetPlace());
+          row_shuffle(device_ctx, *h0, order, ordered_h0, true);
+          math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false,
+                                 static_cast<T>(1.0), weight_g,
+                                 static_cast<T>(1.0));
+        }
+        if (h0 && h0_g) {
+          ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
+          math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
+                                 static_cast<T>(1.0), &ordered_h0_g,
+                                 static_cast<T>(0.0));
+        }
       }
     }
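
For the two GEMMs just added: with G the gate-gradient block of shape [batch, 4 * frame_size] and H the re-ordered H0 of shape [batch, frame_size], the first call accumulates weight_g += H^T * G (beta = 1.0) and the second overwrites the H0 gradient with G * W^T (beta = 0.0). A plain-loop reference for the first product (illustrative only):

```cpp
// w_g[k][j] += sum over b of h[b][k] * g[b][j], i.e. w_g += H^T * G.
void WeightGradAccum(const float* h, const float* g, float* w_g, int batch,
                     int frame_size) {
  const int out_w = 4 * frame_size;
  for (int k = 0; k < frame_size; ++k)
    for (int j = 0; j < out_w; ++j)
      for (int b = 0; b < batch; ++b)
        w_g[k * out_w + j] += h[b * frame_size + k] * g[b * out_w + j];
}
```
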
@@ -302,6 +351,15 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(),
                            ones.data<T>(), 0., bias_g->data<T>());
     }
+
+    if (h0 && h0_g) {
+      h0_g->mutable_data<T>(ctx.GetPlace());
+      row_shuffle(device_ctx, ordered_h0_g, order, *h0_g, false);
+    }
+    if (c0 && c0_g) {
+      c0_g->mutable_data<T>(ctx.GetPlace());
+      row_shuffle(device_ctx, ordered_c0_g, order, *c0_g, false);
+    }
   }
 };
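
A closing note on the bias gradient visible at the top of this last hunk: the `gemv` call with the transpose flag and a vector of ones computes bias_g = batch_gate_g^T * 1, i.e. each bias element receives the column sum of the gate gradients over all rows (every step of every sequence). Loop equivalent (sketch):

```cpp
// bias_g[j] = sum over rows i of gate_g[i][j]; m rows, n = 4 * frame_size.
void BiasGrad(const float* gate_g, float* bias_g, int m, int n) {
  for (int j = 0; j < n; ++j) bias_g[j] = 0.f;
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) bias_g[j] += gate_g[i * n + j];
}
```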