You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
427 lines
15 KiB
427 lines
15 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#pragma once
|
|
#include <type_traits>
|
|
#include "paddle/fluid/operators/math/detail/activation_functions.h"
|
|
#include "paddle/fluid/operators/math/gru_compute.h"
|
|
|
|
namespace paddle {
|
|
namespace operators {
|
|
namespace math {
|
|
namespace detail {
|
|
|
|
#ifndef __NVCC__
|
|
|
|
template <class OpResetOutput, typename T>
|
|
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
|
|
T *gate_value, T *reset_output_value,
|
|
T *prev_output_value, int frame_size,
|
|
ActivationType active_gate) {
|
|
T r_value_update_gate;
|
|
T r_value_reset_gate;
|
|
T r_value_reset_output;
|
|
T r_prev_out = 0;
|
|
T *update_gate = gate_value;
|
|
T *reset_gate = gate_value + frame_size;
|
|
|
|
for (int i = 0; i < frame_size; i++) {
|
|
r_value_update_gate = update_gate[i];
|
|
r_value_reset_gate = reset_gate[i];
|
|
if (prev_output_value) {
|
|
r_prev_out = prev_output_value[i];
|
|
}
|
|
|
|
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
|
|
r_value_reset_output, active_gate);
|
|
|
|
update_gate[i] = r_value_update_gate;
|
|
reset_gate[i] = r_value_reset_gate;
|
|
reset_output_value[i] = r_value_reset_output;
|
|
}
|
|
}
|
|
|
|
template <class OpFinalOutput, typename T>
|
|
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
T *gate_value, T *prev_output_value,
|
|
T *output_value, int frame_size,
|
|
ActivationType active_node) {
|
|
T r_value_update_gate;
|
|
T r_value_frame_state;
|
|
T r_prev_out = 0;
|
|
T r_output;
|
|
T *update_gate = gate_value;
|
|
T *frame_state = gate_value + frame_size * 2;
|
|
|
|
for (int i = 0; i < frame_size; i++) {
|
|
r_value_update_gate = update_gate[i];
|
|
r_value_frame_state = frame_state[i];
|
|
if (prev_output_value) {
|
|
r_prev_out = prev_output_value[i];
|
|
}
|
|
|
|
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
|
|
r_output, active_node);
|
|
|
|
frame_state[i] = r_value_frame_state;
|
|
output_value[i] = r_output;
|
|
}
|
|
}
|
|
|
|
template <class OpResetOutput, typename T>
|
|
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
|
|
T *gate_value, T *reset_output_value,
|
|
T *prev_output_value, int frame_size,
|
|
ActivationType active_gate) {
|
|
#ifdef __AVX__
|
|
__m256 r_value_update_gate;
|
|
__m256 r_value_reset_gate;
|
|
__m256 r_value_reset_output;
|
|
__m256 r_prev_out = _mm256_set1_ps(0.0f);
|
|
__m256 *update_gate = (__m256 *)gate_value;
|
|
__m256 *reset_gate = (__m256 *)(gate_value + frame_size);
|
|
|
|
for (int i = 0; i < frame_size / 8; i++) {
|
|
r_value_update_gate = update_gate[i];
|
|
r_value_reset_gate = reset_gate[i];
|
|
if (prev_output_value) {
|
|
r_prev_out = ((__m256 *)prev_output_value)[i];
|
|
}
|
|
|
|
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
|
|
r_value_reset_output, active_gate);
|
|
|
|
update_gate[i] = r_value_update_gate;
|
|
reset_gate[i] = r_value_reset_gate;
|
|
((__m256 *)reset_output_value)[i] = r_value_reset_output;
|
|
}
|
|
#endif
|
|
}
|
|
|
|
template <class OpFinalOutput, typename T>
|
|
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
|
|
T *gate_value, T *prev_output_value,
|
|
T *output_value, int frame_size,
|
|
ActivationType active_node) {
|
|
#ifdef __AVX__
|
|
__m256 r_value_update_gate;
|
|
__m256 r_value_frame_state;
|
|
__m256 r_prev_out = _mm256_set1_ps(0.0f);
|
|
__m256 r_output;
|
|
__m256 *update_gate = (__m256 *)gate_value;
|
|
__m256 *frame_state = (__m256 *)(gate_value + frame_size * 2);
|
|
|
|
for (int i = 0; i < frame_size / 8; i++) {
|
|
r_value_update_gate = update_gate[i];
|
|
r_value_frame_state = frame_state[i];
|
|
if (prev_output_value) {
|
|
r_prev_out = ((__m256 *)prev_output_value)[i];
|
|
}
|
|
|
|
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
|
|
r_output, active_node);
|
|
|
|
frame_state[i] = r_value_frame_state;
|
|
((__m256 *)output_value)[i] = r_output;
|
|
}
|
|
#endif
|
|
}
|
|
|
|
template <class OpResetOutput, typename T>
|
|
inline void forward_reset_output(OpResetOutput op_reset_output,
|
|
GRUMetaValue<T> value, int frame_size,
|
|
int batch_size, ActivationType active_gate) {
|
|
for (int b = 0; b < batch_size; b++) {
|
|
if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
|
|
hl_avx_gru_forward_reset_output(
|
|
op_reset_output, value.gate_value, value.reset_output_value,
|
|
value.prev_out_value, frame_size, active_gate);
|
|
} else {
|
|
hl_naive_gru_forward_reset_output(
|
|
op_reset_output, value.gate_value, value.reset_output_value,
|
|
value.prev_out_value, frame_size, active_gate);
|
|
}
|
|
|
|
value.gate_value += frame_size * 3;
|
|
value.reset_output_value += frame_size;
|
|
if (value.prev_out_value) {
|
|
value.prev_out_value += frame_size;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class OpFinalOutput, typename T>
|
|
inline void forward_final_output(OpFinalOutput op_final_output,
|
|
GRUMetaValue<T> value, int frame_size,
|
|
int batch_size, ActivationType active_node) {
|
|
for (int b = 0; b < batch_size; b++) {
|
|
if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
|
|
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
|
|
value.prev_out_value, value.output_value,
|
|
frame_size, active_node);
|
|
} else {
|
|
hl_naive_gru_forward_final_output(
|
|
op_final_output, value.gate_value, value.prev_out_value,
|
|
value.output_value, frame_size, active_node);
|
|
}
|
|
|
|
value.gate_value += frame_size * 3;
|
|
value.output_value += frame_size;
|
|
if (value.prev_out_value) {
|
|
value.prev_out_value += frame_size;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class OpStateGrad, typename T>
|
|
void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
|
|
T *gate_grad, T *prev_out_value,
|
|
T *prev_out_grad, T *output_grad,
|
|
int frame_size,
|
|
ActivationType active_node) {
|
|
T r_update_gate_value;
|
|
T r_update_gate_grad;
|
|
T r_frame_state_value;
|
|
T r_frame_state_grad;
|
|
T r_out_grad;
|
|
T r_prev_out_value = 0;
|
|
T r_prev_out_grad = 0;
|
|
T *update_gate_value = gate_value;
|
|
T *update_gate_grad = gate_grad;
|
|
T *frame_state_value = gate_value + frame_size * 2;
|
|
T *frame_state_grad = gate_grad + frame_size * 2;
|
|
|
|
for (int i = 0; i < frame_size; i++) {
|
|
r_update_gate_value = update_gate_value[i];
|
|
r_frame_state_value = frame_state_value[i];
|
|
r_out_grad = output_grad[i];
|
|
if (prev_out_value) {
|
|
r_prev_out_value = prev_out_value[i];
|
|
}
|
|
if (prev_out_grad) {
|
|
r_prev_out_grad = prev_out_grad[i];
|
|
}
|
|
|
|
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
|
|
r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
|
|
r_out_grad, active_node);
|
|
|
|
update_gate_grad[i] = r_update_gate_grad;
|
|
frame_state_grad[i] = r_frame_state_grad;
|
|
if (prev_out_grad) {
|
|
prev_out_grad[i] = r_prev_out_grad;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class OpResetGrad, typename T>
|
|
void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
|
|
T *gate_grad, T *prev_out_value,
|
|
T *prev_out_grad, T *reset_output_grad,
|
|
int frame_size,
|
|
ActivationType active_gate) {
|
|
T r_update_gate_value;
|
|
T r_update_gate_grad;
|
|
T r_reset_gate_value;
|
|
T r_reset_gate_grad;
|
|
T r_reset_output_grad = 0;
|
|
T r_prev_out_value = 0;
|
|
T r_prev_out_grad = 0;
|
|
T *update_gate_value = gate_value;
|
|
T *update_gate_grad = gate_grad;
|
|
T *reset_gate_value = gate_value + frame_size;
|
|
T *reset_gate_grad = gate_grad + frame_size;
|
|
|
|
for (int i = 0; i < frame_size; i++) {
|
|
r_update_gate_value = update_gate_value[i];
|
|
r_update_gate_grad = update_gate_grad[i];
|
|
r_reset_gate_value = reset_gate_value[i];
|
|
|
|
if (prev_out_value && prev_out_grad) {
|
|
r_reset_output_grad = reset_output_grad[i];
|
|
}
|
|
if (prev_out_value) {
|
|
r_prev_out_value = prev_out_value[i];
|
|
}
|
|
if (prev_out_grad) {
|
|
r_prev_out_grad = prev_out_grad[i];
|
|
}
|
|
|
|
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
|
|
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
|
|
r_reset_output_grad, active_gate);
|
|
|
|
update_gate_grad[i] = r_update_gate_grad;
|
|
reset_gate_grad[i] = r_reset_gate_grad;
|
|
if (prev_out_grad) {
|
|
prev_out_grad[i] = r_prev_out_grad;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class OpStateGrad, typename T>
|
|
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
|
|
T *gate_grad, T *prev_out_value,
|
|
T *prev_out_grad, T *output_grad,
|
|
int frame_size,
|
|
ActivationType active_node) {
|
|
#ifdef __AVX__
|
|
__m256 r_update_gate_value;
|
|
__m256 r_update_gate_grad;
|
|
__m256 r_frame_state_value;
|
|
__m256 r_frame_state_grad;
|
|
__m256 r_out_grad;
|
|
__m256 r_prev_out_value = _mm256_set1_ps(0.0f);
|
|
__m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
|
|
__m256 *update_gate_value = (__m256 *)gate_value;
|
|
__m256 *update_gate_grad = (__m256 *)gate_grad;
|
|
__m256 *frame_state_value = (__m256 *)(gate_value + frame_size * 2);
|
|
__m256 *frame_state_grad = (__m256 *)(gate_grad + frame_size * 2);
|
|
|
|
for (int i = 0; i < frame_size / 8; i++) {
|
|
r_update_gate_value = update_gate_value[i];
|
|
r_frame_state_value = frame_state_value[i];
|
|
r_out_grad = ((__m256 *)output_grad)[i];
|
|
if (prev_out_value) {
|
|
r_prev_out_value = ((__m256 *)prev_out_value)[i];
|
|
}
|
|
if (prev_out_grad) {
|
|
r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
|
|
}
|
|
|
|
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
|
|
r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
|
|
r_out_grad, active_node);
|
|
|
|
update_gate_grad[i] = r_update_gate_grad;
|
|
frame_state_grad[i] = r_frame_state_grad;
|
|
if (prev_out_grad) {
|
|
((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
|
|
}
|
|
}
|
|
#endif
|
|
}
|
|
|
|
template <class OpResetGrad, typename T>
|
|
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
|
|
T *gate_grad, T *prev_out_value,
|
|
T *prev_out_grad, T *reset_output_grad,
|
|
int frame_size,
|
|
ActivationType active_gate) {
|
|
#ifdef __AVX__
|
|
__m256 r_update_gate_value;
|
|
__m256 r_update_gate_grad;
|
|
__m256 r_reset_gate_value;
|
|
__m256 r_reset_gate_grad;
|
|
__m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
|
|
__m256 r_prev_out_value = _mm256_set1_ps(0.0f);
|
|
__m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
|
|
__m256 *update_gate_value = (__m256 *)gate_value;
|
|
__m256 *update_gate_grad = (__m256 *)gate_grad;
|
|
__m256 *reset_gate_value = (__m256 *)(gate_value + frame_size);
|
|
__m256 *reset_gate_grad = (__m256 *)(gate_grad + frame_size);
|
|
|
|
for (int i = 0; i < frame_size / 8; i++) {
|
|
r_update_gate_value = update_gate_value[i];
|
|
r_update_gate_grad = update_gate_grad[i];
|
|
r_reset_gate_value = reset_gate_value[i];
|
|
|
|
if (prev_out_value && prev_out_grad) {
|
|
r_reset_output_grad = ((__m256 *)reset_output_grad)[i];
|
|
}
|
|
if (prev_out_value) {
|
|
r_prev_out_value = ((__m256 *)prev_out_value)[i];
|
|
}
|
|
if (prev_out_grad) {
|
|
r_prev_out_grad = ((__m256 *)prev_out_grad)[i];
|
|
}
|
|
|
|
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
|
|
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
|
|
r_reset_output_grad, active_gate);
|
|
|
|
update_gate_grad[i] = r_update_gate_grad;
|
|
reset_gate_grad[i] = r_reset_gate_grad;
|
|
if (prev_out_grad) {
|
|
((__m256 *)prev_out_grad)[i] = r_prev_out_grad;
|
|
}
|
|
}
|
|
#endif
|
|
}
|
|
|
|
template <class OpStateGrad, typename T>
|
|
inline void backward_state_grad(OpStateGrad op_state_grad,
|
|
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
|
|
int frame_size, int batch_size,
|
|
ActivationType active_node) {
|
|
for (int b = 0; b < batch_size; b++) {
|
|
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
|
|
hl_avx_gru_backward_state_grad(
|
|
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
|
|
grad.prev_out_grad, grad.output_grad, frame_size, active_node);
|
|
} else {
|
|
hl_naive_gru_backward_state_grad(
|
|
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
|
|
grad.prev_out_grad, grad.output_grad, frame_size, active_node);
|
|
}
|
|
|
|
value.gate_value += frame_size * 3;
|
|
if (value.prev_out_value) {
|
|
value.prev_out_value += frame_size;
|
|
}
|
|
|
|
grad.gate_grad += frame_size * 3;
|
|
grad.output_grad += frame_size;
|
|
if (grad.prev_out_grad) {
|
|
grad.prev_out_grad += frame_size;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class OpResetGrad, typename T>
|
|
inline void backward_reset_grad(OpResetGrad op_reset_grad,
|
|
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
|
|
int frame_size, int batch_size,
|
|
ActivationType active_gate) {
|
|
for (int b = 0; b < batch_size; b++) {
|
|
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
|
|
hl_avx_gru_backward_reset_grad(
|
|
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
|
|
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
|
|
} else {
|
|
hl_naive_gru_backward_reset_grad(
|
|
op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
|
|
grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
|
|
}
|
|
|
|
value.gate_value += frame_size * 3;
|
|
if (value.prev_out_value) {
|
|
value.prev_out_value += frame_size;
|
|
}
|
|
|
|
grad.gate_grad += frame_size * 3;
|
|
grad.reset_output_grad += frame_size;
|
|
if (grad.prev_out_grad) {
|
|
grad.prev_out_grad += frame_size;
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif
|
|
|
|
} // namespace detail
|
|
} // namespace math
|
|
} // namespace operators
|
|
} // namespace paddle
|