You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
142 lines
5.6 KiB
142 lines
5.6 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#pragma once
|
|
#include "paddle/framework/op_registry.h"
|
|
#include "paddle/operators/math/lstm_compute.h"
|
|
#include "paddle/operators/math/math_function.h"
|
|
#include "paddle/operators/math/sequence2batch.h"
|
|
|
|
namespace paddle {
|
|
namespace operators {
|
|
|
|
using framework::LoDTensor;
|
|
using framework::Tensor;
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
typename IndexType = Eigen::DenseIndex>
|
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
|
|
|
template <typename Place, typename T>
|
|
class LSTMKernel : public framework::OpKernel<T> {
|
|
public:
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
auto* input = ctx.Input<framework::LoDTensor>("Input");
|
|
auto* weight = ctx.Input<framework::Tensor>("Weight");
|
|
auto* bias = ctx.Input<framework::Tensor>("Bias");
|
|
|
|
auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate");
|
|
batch_gate->mutable_data<T>(ctx.GetPlace());
|
|
auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden");
|
|
hidden_out->mutable_data<T>(ctx.GetPlace());
|
|
auto* cell_out = ctx.Output<framework::LoDTensor>("Cell");
|
|
cell_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
// Now the function ShareLoD in InferShape is not implemented.
|
|
// So copy LoD here.
|
|
ctx.ShareLoD("Input", "Hidden");
|
|
ctx.ShareLoD("Input", "Cell");
|
|
|
|
bool is_reverse = ctx.Attr<bool>("isReverse");
|
|
math::LoDTensor2BatchFunctor<Place, T> to_batch;
|
|
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
|
|
|
|
auto in_dims = input->dims();
|
|
int frame_size = static_cast<int>(in_dims[1] / 4);
|
|
framework::DDim dims({in_dims[0], frame_size});
|
|
|
|
if (bias) {
|
|
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
|
|
Eigen::array<int, 2> offsets({{0, 0}});
|
|
auto b = EigenMatrix<T>::From(*bias);
|
|
auto gate = EigenMatrix<T>::From(*batch_gate);
|
|
gate.device(ctx.GetEigenDevice<Place>()) =
|
|
gate +
|
|
b.slice(offsets, extents)
|
|
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
|
|
.broadcast(
|
|
Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
|
|
}
|
|
|
|
math::LstmMetaValue<T> lstm_value;
|
|
T* bias_data = const_cast<T*>(bias->data<T>());
|
|
// the code style in LstmMetaValue will be updated later.
|
|
lstm_value.checkIg = bias_data + 4 * frame_size;
|
|
lstm_value.checkFg = lstm_value.checkIg + frame_size;
|
|
lstm_value.checkOg = lstm_value.checkFg + frame_size;
|
|
lstm_value.prevStateValue = nullptr;
|
|
|
|
framework::LoDTensor batch_out;
|
|
batch_out.mutable_data<T>(dims, ctx.GetPlace());
|
|
framework::LoDTensor batch_cell;
|
|
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
|
|
framework::LoDTensor batch_cell_pre_act;
|
|
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
|
|
|
|
auto& batch_starts = batch_gate->lod()[0];
|
|
size_t num_batch = batch_starts.size() - 1;
|
|
auto gate_act = ctx.Attr<std::string>("gateActivation");
|
|
auto cell_act = ctx.Attr<std::string>("cellActivation");
|
|
auto cand_act = ctx.Attr<std::string>("candidateActivation");
|
|
|
|
for (size_t n = 0; n < num_batch; n++) {
|
|
int bstart = static_cast<int>(batch_starts[n]);
|
|
int bend = static_cast<int>(batch_starts[n + 1]);
|
|
|
|
Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
|
|
Tensor out_t = batch_out.Slice<T>(bstart, bend);
|
|
Tensor cell_t = batch_cell.Slice<T>(bstart, bend);
|
|
Tensor cell_pre_act_t = batch_cell_pre_act.Slice<T>(bstart, bend);
|
|
|
|
int cur_batch_size = bend - bstart;
|
|
|
|
if (n != 0) {
|
|
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
|
|
int pre_h_end = pre_h_start + cur_batch_size;
|
|
auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end);
|
|
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
|
|
*weight, false, static_cast<T>(1.0), &gate_t,
|
|
static_cast<T>(1.0));
|
|
}
|
|
// else if : FIXME support the initial hidden and cell
|
|
|
|
lstm_value.gateValue = gate_t.data<T>();
|
|
lstm_value.outputValue = out_t.data<T>();
|
|
lstm_value.stateValue = cell_t.data<T>();
|
|
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
|
|
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
|
|
frame_size, cur_batch_size,
|
|
gate_act, cell_act, cand_act);
|
|
lstm_value.prevStateValue = lstm_value.stateValue;
|
|
}
|
|
|
|
math::Batch2LoDTensorFunctor<Place, T> to_seq;
|
|
batch_out.set_lod(batch_gate->lod());
|
|
// restore the output hidden in LoDTensor from the batch hidden
|
|
to_seq(ctx.device_context(), batch_out, *hidden_out);
|
|
|
|
batch_cell.set_lod(batch_gate->lod());
|
|
// restore the output cell state in LoDTensor from the batch cell
|
|
to_seq(ctx.device_context(), batch_cell, *cell_out);
|
|
}
|
|
};
|
|
|
|
template <typename Place, typename T>
|
|
class LSTMGradKernel : public framework::OpKernel<T> {
|
|
public:
|
|
void Compute(const framework::ExecutionContext& ctx) const override {}
|
|
};
|
|
|
|
} // namespace operators
|
|
} // namespace paddle
|