Fix CPPLint issues in /math/detail/gru_kernel.h (#10390)

* Fix CPPLint issues in gru_kernel.h

* Fix CPPLint issues in gru_kernel.h

* Fix Compile error
simplify_fluid_api_recognize_digit
Abhinav Arora 7 years ago committed by GitHub
parent 20fa848076
commit c9f55dfafc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,8 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
r_value_reset_output, active_gate); &r_value_reset_output, active_gate);
update_gate[i] = r_value_update_gate; update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate; reset_gate[i] = r_value_reset_gate;
@ -71,8 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
r_output, active_node); &r_output, active_node);
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
output_value[i] = r_output; output_value[i] = r_output;
@ -99,8 +99,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i]; r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
} }
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
r_value_reset_output, active_gate); &r_value_reset_output, active_gate);
update_gate[i] = r_value_update_gate; update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate; reset_gate[i] = r_value_reset_gate;
@ -129,8 +129,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i]; r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
} }
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
r_output, active_node); &r_output, active_node);
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
(reinterpret_cast<__m256 *>(output_value))[i] = r_output; (reinterpret_cast<__m256 *>(output_value))[i] = r_output;
@ -213,9 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
r_prev_out_grad = prev_out_grad[i]; r_prev_out_grad = prev_out_grad[i];
} }
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
r_frame_state_grad, r_prev_out_value, r_prev_out_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
@ -258,9 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
r_prev_out_grad = prev_out_grad[i]; r_prev_out_grad = prev_out_grad[i];
} }
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
r_reset_output_grad, active_gate); &r_prev_out_grad, &r_reset_output_grad, active_gate);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad; reset_gate_grad[i] = r_reset_gate_grad;
@ -302,9 +302,9 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
} }
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
r_frame_state_grad, r_prev_out_value, r_prev_out_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
@ -350,9 +350,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
} }
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
r_reset_output_grad, active_gate); &r_prev_out_grad, &r_reset_output_grad, active_gate);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad; reset_gate_grad[i] = r_reset_gate_grad;

@ -55,8 +55,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
r_prev_out = prev_output_value[frame_idx]; r_prev_out = prev_output_value[frame_idx];
} }
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
r_value_reset_output, active_gate); &r_value_reset_output, active_gate);
gate_value[frame_idx + frame_size * 0] = r_value_update_gate; gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
gate_value[frame_idx + frame_size * 1] = r_value_reset_gate; gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
@ -93,8 +93,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
r_prev_out = prev_output_value[frame_idx]; r_prev_out = prev_output_value[frame_idx];
} }
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
r_output, active_node); &r_output, active_node);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state; gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output; output_value[frame_idx] = r_output;
@ -137,9 +137,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
r_prev_out_grad = prev_out_grad[frame_idx]; r_prev_out_grad = prev_out_grad[frame_idx];
} }
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
r_frame_state_grad, r_prev_out_value, r_prev_out_grad, &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
r_out_grad, active_node); &r_out_grad, active_node);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad; gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
@ -185,9 +185,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
r_reset_output_grad = reset_output_grad[frame_idx]; r_reset_output_grad = reset_output_grad[frame_idx];
} }
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, &r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
r_reset_output_grad, active_gate); &r_reset_output_grad, active_gate);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad; gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;

@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#include <type_traits>
// TODO(guosheng): refine code style in gru_kernel // TODO(guosheng): refine code style in gru_kernel
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -28,25 +28,25 @@ namespace forward {
template <typename T> template <typename T>
class gru_resetOutput { class gru_resetOutput {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate, HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate,
T &prev_out, T &value_reset_output, T *prev_out, T *value_reset_output,
ActivationType act_gate) { ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate); *value_update_gate = activation(*value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate);
value_reset_output = prev_out * value_reset_gate; *value_reset_output = (*prev_out) * (*value_reset_gate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &value_reset_gate, __m256 &prev_out, __m256 *value_reset_gate, __m256 *prev_out,
__m256 &value_reset_output, __m256 *value_reset_output,
ActivationType act_gate) { ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate); *value_update_gate = activation(*value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate);
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate); *value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate);
} }
#endif #endif
#endif #endif
@ -55,25 +55,25 @@ class gru_resetOutput {
template <typename T> template <typename T>
class gru_finalOutput { class gru_finalOutput {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state, HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T &prev_out, T &value_output, T *prev_out, T *value_output,
ActivationType act_input) { ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
value_output = prev_out - (value_update_gate * prev_out) + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
(value_update_gate * value_frame_state); ((*value_update_gate) * (*value_frame_state));
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &value_frame_state, __m256 &prev_out, __m256 *value_frame_state, __m256 *prev_out,
__m256 &value_output, ActivationType act_input) { __m256 *value_output, ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
value_output = _mm256_add_ps( *value_output = _mm256_add_ps(
_mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)), _mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(value_update_gate, value_frame_state)); _mm256_mul_ps(*value_update_gate, *value_frame_state));
} }
#endif #endif
#endif #endif
@ -85,37 +85,38 @@ namespace backward {
template <typename T> template <typename T>
class gru_stateGrad { class gru_stateGrad {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T &value_frame_state, T &grad_frame_state, T *value_frame_state, T *grad_frame_state,
T &value_prev_out, T &grad_prev_out, T *value_prev_out, T *grad_prev_out,
T &grad_output, ActivationType act_input) { T *grad_output, ActivationType act_input) {
grad_update_gate = (grad_output * value_frame_state); *grad_update_gate = (*grad_output * (*value_frame_state));
grad_update_gate -= (grad_output * value_prev_out); *grad_update_gate -= (*grad_output * (*value_prev_out));
grad_prev_out -= (grad_output * value_update_gate); *grad_prev_out -= (*grad_output * (*value_update_gate));
grad_prev_out += grad_output; *grad_prev_out += *grad_output;
grad_frame_state = activation(grad_output * value_update_gate, *grad_frame_state = activation(*grad_output * (*value_update_gate),
value_frame_state, act_input); *value_frame_state, act_input);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &grad_update_gate, __m256 *grad_update_gate,
__m256 &value_frame_state, __m256 *value_frame_state,
__m256 &grad_frame_state, __m256 &value_prev_out, __m256 *grad_frame_state, __m256 *value_prev_out,
__m256 &grad_prev_out, __m256 &grad_output, __m256 *grad_prev_out, __m256 *grad_output,
ActivationType act_input) { ActivationType act_input) {
grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state); *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
grad_update_gate = _mm256_sub_ps( *grad_update_gate = _mm256_sub_ps(
grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out)); *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
grad_prev_out = _mm256_add_ps( *grad_prev_out = _mm256_add_ps(
_mm256_sub_ps(grad_prev_out, _mm256_sub_ps(*grad_prev_out,
_mm256_mul_ps(grad_output, value_update_gate)), _mm256_mul_ps(*grad_output, *value_update_gate)),
grad_output); *grad_output);
grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate), *grad_frame_state =
value_frame_state, act_input); activation(_mm256_mul_ps(*grad_output, *value_update_gate),
*value_frame_state, act_input);
} }
#endif #endif
#endif #endif
@ -124,32 +125,34 @@ class gru_stateGrad {
template <typename T> template <typename T>
class gru_resetGrad { class gru_resetGrad {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T &value_reset_gate, T &grad_reset_gate, T *value_reset_gate, T *grad_reset_gate,
T &value_prev_out, T &grad_prev_out, T *value_prev_out, T *grad_prev_out,
T &grad_reset_output, ActivationType act_gate) { T *grad_reset_output, ActivationType act_gate) {
grad_reset_gate = (grad_reset_output * value_prev_out); *grad_reset_gate = (*grad_reset_output * (*value_prev_out));
grad_prev_out += (grad_reset_output * value_reset_gate); *grad_prev_out += (*grad_reset_output * (*value_reset_gate));
grad_update_gate = *grad_update_gate =
activation(grad_update_gate, value_update_gate, act_gate); activation(*grad_update_gate, *value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate); *grad_reset_gate =
activation(*grad_reset_gate, *value_reset_gate, act_gate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &grad_update_gate, __m256 &value_reset_gate, __m256 *grad_update_gate, __m256 *value_reset_gate,
__m256 &grad_reset_gate, __m256 &value_prev_out, __m256 *grad_reset_gate, __m256 *value_prev_out,
__m256 &grad_prev_out, __m256 &grad_reset_output, __m256 *grad_prev_out, __m256 *grad_reset_output,
ActivationType act_gate) { ActivationType act_gate) {
grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out); *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
grad_prev_out = _mm256_add_ps( *grad_prev_out = _mm256_add_ps(
grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate)); *grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate));
grad_update_gate = *grad_update_gate =
activation(grad_update_gate, value_update_gate, act_gate); activation(*grad_update_gate, *value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate); *grad_reset_gate =
activation(*grad_reset_gate, *value_reset_gate, act_gate);
} }
#endif #endif
#endif #endif

Loading…
Cancel
Save