@@ -16,7 +16,10 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
@@ -94,6 +97,7 @@ class GRUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -105,9 +109,27 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.output_value = hidden_t.data<T>();
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
+      if (gru_value.prev_out_value) {
+        blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1,
+                  gru_value.prev_out_value, frame_size, gru_value.gate_weight,
+                  frame_size * 2, 1, gru_value.gate_value, frame_size * 3);
+      }
+
+      math::detail::forward_reset_output(
+          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_gate);
+
+      if (gru_value.prev_out_value) {
+        blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1,
+                  gru_value.reset_output_value, frame_size,
+                  gru_value.state_weight, frame_size, 1,
+                  gru_value.gate_value + frame_size * 2, frame_size * 3);
+      }
+
+      math::detail::forward_final_output(
+          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_node);
+
       gru_value.prev_out_value = gru_value.output_value;
     }
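
Note on the kernel change above: rather than routing every step through math::GRUUnitFunctor<DeviceContext, T>::compute, the forward loop now issues the two recurrent matrix products directly through the cached blas handle and uses the detail kernels only for the element-wise work. The if (gru_value.prev_out_value) guards skip the recurrent terms on the first time step, where there is no previous hidden state. As a sketch of what the two GEMM calls compute (assuming, as the lda/ldb/ldc arguments indicate, that Blas::GEMM takes row-major operands in the order transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc), the helper below is an illustrative stand-in, not a Paddle API:

// Row-major GEMM accumulate: C[m][n] += sum_k A[m][k] * B[k][n],
// mirroring blas.GEMM(false, false, M, N, K, 1, A, lda, B, ldb, 1, C, ldc).
template <typename T>
void GemmAcc(int M, int N, int K, const T* A, int lda, const T* B, int ldb,
             T* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      T acc = 0;
      for (int k = 0; k < K; ++k) acc += A[m * lda + k] * B[k * ldb + n];
      C[m * ldc + n] += acc;  // beta == 1: accumulate onto what C holds
    }
  }
}

// gate_value packs three frame_size-wide column blocks per batch row,
// hence the leading dimension frame_size * 3. The first call adds the
// previous hidden state's projection into the first two (gate) blocks:
//   GemmAcc(cur_batch_size, frame_size * 2, frame_size,
//           prev_out_value, frame_size, gate_weight, frame_size * 2,
//           gate_value, frame_size * 3);
// The second call runs only after forward_reset_output has written the
// reset-gated previous hidden state into reset_output_value, and adds its
// projection into the candidate block, whose start is the pointer offset
// gate_value + frame_size * 2:
//   GemmAcc(cur_batch_size, frame_size, frame_size,
//           reset_output_value, frame_size, state_weight, frame_size,
//           gate_value + frame_size * 2, frame_size * 3);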
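For completeness, the element-wise stages interleaved between the two products: forward_reset_output applies the gate activation (active_gate) to the gate columns and writes the reset-gated previous hidden state into reset_output_value, and forward_final_output applies the node activation (active_node) to the candidate column and blends it with the previous hidden state through the update gate to produce output_value, which the loop then carries forward as prev_out_value for the next batch step. In standard GRU notation (a sketch of the usual convention, not a restatement of the exact formulas inside gru_resetOutput/gru_finalOutput):

  u_t = act_gate(g_u)            // update-gate column block
  r_t = act_gate(g_r)            // reset-gate column block
  reset_output = r_t ⊙ h_{t-1}   // what the second GEMM projects
  c_t = act_node(g_c)            // candidate column block
  h_t = blend(h_{t-1}, c_t; u_t) // combination gated by u_t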