@@ -143,7 +143,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
     auto proj_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("proj_activation"));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-
+    auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -160,9 +160,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
         int pre_h_start = static_cast<int>(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end);
-        math::matmul<DeviceContext, T>(device_ctx, pre_proj_t, false, *weight,
-                                       false, static_cast<T>(1.0), &gate_t,
-                                       static_cast<T>(1.0));
+        blas.MatMul(pre_proj_t, false, *weight, false, static_cast<T>(1.0),
+                    &gate_t, static_cast<T>(1.0));
       } else if (hidden_t0) {
         // If n == 0 and there is no initialized hidden state, that is to say
         // the H0 is zeros, the calculation W_h * H0 will be skipped.
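(As the comment says: when no initial hidden state is supplied, H0 is treated as all zeros, so the recurrent contribution W_h * H0 to the gate preactivation is zero and the GEMM can simply be skipped. The branch below handles the case where a real H0 is given.)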
@@ -176,16 +175,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
         ordered_proj0->mutable_data<T>(ctx.GetPlace());
         ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                            &ordered_h0, true);
-        math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
-                                       *proj_weight, false, static_cast<T>(1.0),
-                                       ordered_proj0, static_cast<T>(0.0));
+        blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0),
+                    ordered_proj0, static_cast<T>(0.0));
         if (proj_act != math::detail::ActivationType::kIdentity) {
           auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
           ActCompute(cell_act, place, proj0_dev, proj0_dev);
         }
-        math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false,
-                                       *weight, false, static_cast<T>(1.0),
-                                       &gate_t, static_cast<T>(1.0));
+        blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
+                    &gate_t, static_cast<T>(1.0));
       }
 
       lstmp_value.gate_value = gate_t.data<T>();
@@ -196,9 +193,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
           device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
           cell_act, cand_act);
       lstmp_value.prev_state_value = lstmp_value.state_value;
-      math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
-                                     false, static_cast<T>(1.0), &proj_t,
-                                     static_cast<T>(0.0));
+      blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
+                  &proj_t, static_cast<T>(0.0));
       if (proj_act != math::detail::ActivationType::kIdentity) {
         auto proj_t_dev = EigenMatrix<T>::From(proj_t);
         ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
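The replacements above are mechanical. Both helpers implement the GEMM contract C = alpha * op(A) * op(B) + beta * C; the only difference is that the device context is captured once in the blas handle instead of being threaded through every call. Schematically (A, B, C, trans_a, trans_b, alpha, beta are placeholder names for the arguments visible at each call site):

    math::matmul<DeviceContext, T>(device_ctx, A, trans_a, B, trans_b, alpha, &C, beta);
    // becomes
    blas.MatMul(A, trans_a, B, trans_b, alpha, &C, beta);

Note that each beta carries over unchanged: 1.0 accumulates into gate_t, which already holds the input contribution, while 0.0 overwrites proj_t.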
@@ -361,6 +357,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
 
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
+    auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
     for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -375,15 +372,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       }
       /* hidden state backward */
       Tensor out_g = batch_hidden_g.Slice(bstart, bend);
-      math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
-                                     true, static_cast<T>(1.0), &out_g,
-                                     static_cast<T>(0.0));
+      blas.MatMul(proj_g, false, *proj_weight, true, static_cast<T>(1.0),
+                  &out_g, static_cast<T>(0.0));
       /* projection weight backward */
       if (proj_weight_g) {
         Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-        math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
-                                       false, static_cast<T>(1.0),
-                                       proj_weight_g, static_cast<T>(1.0));
+        blas.MatMul(hidden_t, true, proj_g, false, static_cast<T>(1.0),
+                    proj_weight_g, static_cast<T>(1.0));
       }
 
       Tensor gate = batch_gate->Slice(bstart, bend);
@@ -419,24 +414,21 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
         int pre_h_start = static_cast<int>(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_proj_g = batch_proj_g.Slice(pre_h_start, pre_h_end);
-        math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
-                                       static_cast<T>(1.0), &pre_proj_g,
-                                       static_cast<T>(1.0));
+        blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+                    &pre_proj_g, static_cast<T>(1.0));
         if (weight_g) {
           /* weight backward */
           auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
-          math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
-                                         false, static_cast<T>(1.0), weight_g,
-                                         static_cast<T>(1.0));
+          blas.MatMul(pre_proj, true, gate_g, false, static_cast<T>(1.0),
+                      weight_g, static_cast<T>(1.0));
         }
       } else {
         if (h0 && weight_g) {
           ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                              &ordered_h0, true);
           if (weight_g) {
-            math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, true,
-                                           gate_g, false, static_cast<T>(1.0),
-                                           weight_g, static_cast<T>(1.0));
+            blas.MatMul(*ordered_proj0, true, gate_g, false,
+                        static_cast<T>(1.0), weight_g, static_cast<T>(1.0));
           }
         }
         if (h0 && (h0_g || proj_weight_g)) {
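One detail worth noting in this backward pass: every GEMM that writes into a weight-gradient buffer (weight_g, proj_weight_g) passes beta = 1.0, so contributions from each step of the reversed batch loop accumulate into the same buffer, while GEMMs producing per-step activation gradients (out_g above, proj0_g and ordered_h0_g below) pass beta = 0.0 and overwrite their output. The Blas rewrite preserves each original beta exactly.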
@@ -444,9 +436,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
           Tensor proj0_g;
           proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
           proj0_g.mutable_data<T>(ctx.GetPlace());
-          math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
-                                         true, static_cast<T>(1.0), &proj0_g,
-                                         static_cast<T>(0.0));
+          blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+                      &proj0_g, static_cast<T>(0.0));
           if (proj_act != math::detail::ActivationType::kIdentity) {
             auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
             auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
@@ -454,14 +445,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                            proj0_g_dev);
           }
           if (h0_g) {
-            math::matmul<DeviceContext, T>(
-                device_ctx, proj0_g, false, *proj_weight, true,
-                static_cast<T>(1.0), &ordered_h0_g, static_cast<T>(0.0));
+            blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
+                        &ordered_h0_g, static_cast<T>(0.0));
           }
           if (proj_weight_g) {
-            math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true,
-                                           proj0_g, false, static_cast<T>(1.0),
-                                           proj_weight_g, static_cast<T>(1.0));
+            blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
+                        proj_weight_g, static_cast<T>(1.0));
           }
         }
       }
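For reference, here is a minimal self-contained sketch of the GEMM contract shared by math::matmul and blas.MatMul, assuming row-major storage; naive_gemm is a hypothetical illustration, not a Paddle API. It demonstrates the beta = 0.0 (overwrite) versus beta = 1.0 (accumulate) behavior the calls in this patch depend on:

#include <cassert>
#include <cstdio>
#include <vector>

// naive_gemm: a hypothetical stand-in for the shared contract
//   C = alpha * op(A) * op(B) + beta * C,  with op(X) = X or X^T.
// op(A) is m x k, op(B) is k x n, C is m x n; all buffers row-major.
void naive_gemm(const std::vector<float>& A, bool trans_a,
                const std::vector<float>& B, bool trans_b, float alpha,
                std::vector<float>* C, float beta, int m, int n, int k) {
  assert(static_cast<int>(C->size()) == m * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        // Read op(A)(i, p) and op(B)(p, j) from the untransposed storage.
        float a = trans_a ? A[p * m + i] : A[i * k + p];
        float b = trans_b ? B[j * k + p] : B[p * n + j];
        acc += a * b;
      }
      (*C)[i * n + j] = alpha * acc + beta * (*C)[i * n + j];
    }
  }
}

int main() {
  // [1 2] (1x2) times [3 4]^T (2x1) = [11].
  std::vector<float> A = {1.f, 2.f}, B = {3.f, 4.f}, C = {100.f};
  naive_gemm(A, false, B, false, 1.f, &C, 0.f, 1, 1, 2);  // beta = 0 overwrites
  std::printf("overwrite:  %g\n", C[0]);                  // prints 11
  naive_gemm(A, false, B, false, 1.f, &C, 1.f, 1, 1, 2);  // beta = 1 accumulates
  std::printf("accumulate: %g\n", C[0]);                  // prints 22
  return 0;
}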