|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/activation_op.h"
|
|
|
|
@ -21,17 +22,50 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/math/detail/activation_functions.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/lstm_compute.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/sequence2batch.h"
|
|
|
|
|
#include "paddle/fluid/platform/transform.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using platform::Transform;
|
|
|
|
|
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class _ClipFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit _ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
|
|
|
|
|
HOSTDEVICE T operator()(const T& x) const {
|
|
|
|
|
if (x < min_)
|
|
|
|
|
return min_;
|
|
|
|
|
else if (x > max_)
|
|
|
|
|
return max_;
|
|
|
|
|
else
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T min_;
|
|
|
|
|
T max_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class _ClipGradFunctor {
|
|
|
|
|
public:
|
|
|
|
|
explicit _ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
|
|
|
|
|
HOSTDEVICE T operator()(const T& x, const T& y) const {
|
|
|
|
|
return (y > min_ && y < max_) ? x : 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T min_;
|
|
|
|
|
T max_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
inline void ReorderInitState(const DeviceContext& ctx,
|
|
|
|
|
const framework::Tensor& src,
|
|
|
|
@ -60,6 +94,25 @@ class LSTMPKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_THROW("unsupported activation type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Print(const Tensor& t, std::string name) const {
|
|
|
|
|
VLOG(1) << name << "size = " << t.numel();
|
|
|
|
|
size_t size = t.numel();
|
|
|
|
|
T* d = t.data<T>();
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
std::vector<T> vec;
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(t.place())->Wait();
|
|
|
|
|
if (platform::is_gpu_place(t.place())) {
|
|
|
|
|
vec.resize(size);
|
|
|
|
|
cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
|
|
|
|
|
d = vec.data();
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
VLOG(1) << name << " data_ptr = " << static_cast<void*>(d);
|
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
|
VLOG(1) << d[i] << ",";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* input = ctx.Input<LoDTensor>("Input");
|
|
|
|
|
auto* weight = ctx.Input<Tensor>("Weight");
|
|
|
|
@ -67,9 +120,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* bias = ctx.Input<Tensor>("Bias");
|
|
|
|
|
|
|
|
|
|
auto* hidden_t0 = ctx.Input<Tensor>("H0");
|
|
|
|
|
auto* ordered_proj0 = ctx.Output<Tensor>("OrderedP0");
|
|
|
|
|
auto* cell_t0 = ctx.Input<Tensor>("C0");
|
|
|
|
|
|
|
|
|
|
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
|
|
|
|
|
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
|
|
|
|
|
|
|
|
|
|
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
|
|
|
|
|
batch_gate->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* proj_out = ctx.Output<LoDTensor>("Projection");
|
|
|
|
@ -110,6 +165,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
lstmp_value.prev_state_value = nullptr;
|
|
|
|
|
Tensor ordered_c0;
|
|
|
|
|
Tensor ordered_h0;
|
|
|
|
|
|
|
|
|
|
framework::Vector<size_t> order(batch_gate->lod()[2]);
|
|
|
|
|
|
|
|
|
@ -169,18 +225,10 @@ class LSTMPKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Since the batch computation for LSTMP reorders the input sequences
// according to their lengths, the initial hidden state also needs
// to be reordered.
|
|
|
|
|
|
|
|
|
|
Tensor ordered_h0;
|
|
|
|
|
ordered_proj0->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
VLOG(1) << "qxz h0 used";
|
|
|
|
|
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
|
|
|
|
|
&ordered_h0, true);
|
|
|
|
|
blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0),
|
|
|
|
|
ordered_proj0, static_cast<T>(0.0));
|
|
|
|
|
if (proj_act != math::detail::ActivationType::kIdentity) {
|
|
|
|
|
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
|
|
|
|
|
ActCompute(cell_act, place, proj0_dev, proj0_dev);
|
|
|
|
|
}
|
|
|
|
|
blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
|
|
|
|
|
blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
|
|
|
|
|
&gate_t, static_cast<T>(1.0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -189,8 +237,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
|
|
|
|
|
lstmp_value.state_value = cell_t.data<T>();
|
|
|
|
|
lstmp_value.state_active_value = cell_pre_act_t.data<T>();
|
|
|
|
|
math::LstmUnitFunctor<DeviceContext, T>::compute(
|
|
|
|
|
device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
|
|
|
|
|
cell_act, cand_act);
|
|
|
|
|
device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip,
|
|
|
|
|
gate_act, cell_act, cand_act);
|
|
|
|
|
lstmp_value.prev_state_value = lstmp_value.state_value;
|
|
|
|
|
blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
|
|
|
|
|
&proj_t, static_cast<T>(0.0));
|
|
|
|
@ -198,6 +246,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto proj_t_dev = EigenMatrix<T>::From(proj_t);
|
|
|
|
|
ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
|
|
|
|
|
}
|
|
|
|
|
if (proj_clip && proj_clip > 0.0) {
|
|
|
|
|
T* x_data = proj_t.data<T>();
|
|
|
|
|
int64_t numel = proj_t.numel();
|
|
|
|
|
Transform<DeviceContext> trans;
|
|
|
|
|
trans(ctx.template device_context<DeviceContext>(), x_data,
|
|
|
|
|
x_data + numel, x_data,
|
|
|
|
|
_ClipFunctor<T>(-1.0 * proj_clip, proj_clip));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
|
|
|
|
@ -239,6 +295,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* proj_out = ctx.Input<LoDTensor>("Projection");
|
|
|
|
|
auto* cell_out = ctx.Input<LoDTensor>("Cell");
|
|
|
|
|
|
|
|
|
|
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
|
|
|
|
|
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
|
|
|
|
|
|
|
|
|
|
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
|
|
|
|
|
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
|
|
|
|
|
auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden");
|
|
|
|
@ -253,7 +312,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
|
|
|
|
|
|
|
|
|
|
auto* h0 = ctx.Input<Tensor>("H0");
|
|
|
|
|
auto* ordered_proj0 = ctx.Input<Tensor>("OrderedP0");
|
|
|
|
|
auto* c0 = ctx.Input<Tensor>("C0");
|
|
|
|
|
|
|
|
|
|
auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
|
|
|
|
@ -363,6 +421,17 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
Tensor cur_proj = batch_proj.Slice(bstart, bend);
|
|
|
|
|
Tensor proj_g = batch_proj_g.Slice(bstart, bend);
|
|
|
|
|
|
|
|
|
|
if (proj_clip && proj_clip > 0.0) {
|
|
|
|
|
T* dx_data = proj_g.data<T>();
|
|
|
|
|
T* x_data = cur_proj.data<T>();
|
|
|
|
|
int64_t numel = proj_g.numel();
|
|
|
|
|
Transform<DeviceContext> trans;
|
|
|
|
|
trans(ctx.template device_context<DeviceContext>(), dx_data,
|
|
|
|
|
dx_data + numel, x_data, dx_data,
|
|
|
|
|
_ClipGradFunctor<T>(-1.0 * proj_clip, proj_clip));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (proj_act != math::detail::ActivationType::kIdentity) {
|
|
|
|
|
auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
|
|
|
|
|
auto proj_g_dev = EigenMatrix<T>::From(proj_g);
|
|
|
|
@ -407,7 +476,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
int cur_batch_size = bend - bstart;
|
|
|
|
|
math::LstmUnitGradFunctor<DeviceContext, T>::compute(
|
|
|
|
|
device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size,
|
|
|
|
|
gate_act, cell_act, cand_act);
|
|
|
|
|
cell_clip, gate_act, cell_act, cand_act);
|
|
|
|
|
|
|
|
|
|
if (n > 0) {
|
|
|
|
|
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
|
|
|
|
@ -426,31 +495,14 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
|
|
|
|
|
&ordered_h0, true);
|
|
|
|
|
if (weight_g) {
|
|
|
|
|
blas.MatMul(*ordered_proj0, true, gate_g, false,
|
|
|
|
|
static_cast<T>(1.0), weight_g, static_cast<T>(1.0));
|
|
|
|
|
blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
|
|
|
|
|
weight_g, static_cast<T>(1.0));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (h0 && (h0_g || proj_weight_g)) {
|
|
|
|
|
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
|
|
|
|
|
Tensor proj0_g;
|
|
|
|
|
proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
|
|
|
|
|
proj0_g.mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
|
|
|
|
|
&proj0_g, static_cast<T>(0.0));
|
|
|
|
|
if (proj_act != math::detail::ActivationType::kIdentity) {
|
|
|
|
|
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
|
|
|
|
|
auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
|
|
|
|
|
ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
|
|
|
|
|
proj0_g_dev);
|
|
|
|
|
}
|
|
|
|
|
if (h0_g) {
|
|
|
|
|
blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
|
|
|
|
|
&ordered_h0_g, static_cast<T>(0.0));
|
|
|
|
|
}
|
|
|
|
|
if (proj_weight_g) {
|
|
|
|
|
blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
|
|
|
|
|
proj_weight_g, static_cast<T>(1.0));
|
|
|
|
|
}
|
|
|
|
|
&ordered_h0_g, static_cast<T>(0.0));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|