|
|
|
@ -20,11 +20,11 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/math/MathFunctions.h"
|
|
|
|
|
|
|
|
|
|
// CBLAS_GEMM selects the matrix-multiply routine used by the CBLAS_GEMM(...)
// call sites later in this file (the GRU forward and backward code):
//   - PADDLE_TYPE_DOUBLE not defined -> paddle::gemm<float>
//   - PADDLE_TYPE_DOUBLE defined     -> paddle::gemm<double>
// NOTE(review): paddle::gemm is presumably declared in
// paddle/math/MathFunctions.h, included above - confirm.
#ifndef PADDLE_TYPE_DOUBLE
|
|
|
|
|
// Single-precision variant.
#define CBLAS_GEMM paddle::gemm<float>
|
|
|
|
|
#else
|
|
|
|
|
// Double-precision variant.
#define CBLAS_GEMM paddle::gemm<double>
|
|
|
|
|
#endif
|
|
|
|
|
// #ifndef PADDLE_TYPE_DOUBLE
|
|
|
|
|
// #define CBLAS_GEMM paddle::gemm<float>
|
|
|
|
|
// #else
|
|
|
|
|
// #define CBLAS_GEMM paddle::gemm<double>
|
|
|
|
|
// #endif
|
|
|
|
|
|
|
|
|
|
template<class OpResetOutput>
|
|
|
|
|
void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
|
|
|
|
@ -219,37 +219,37 @@ void hl_cpu_gru_forward(OpResetOutput opResetOutput,
|
|
|
|
|
hl_activation_mode_t active_node,
|
|
|
|
|
hl_activation_mode_t active_gate) {
|
|
|
|
|
if (value.prevOutValue) {
|
|
|
|
|
CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
batchSize,
|
|
|
|
|
2 * frameSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
1,
|
|
|
|
|
value.prevOutValue,
|
|
|
|
|
frameSize,
|
|
|
|
|
value.gateWeight,
|
|
|
|
|
frameSize * 2,
|
|
|
|
|
1,
|
|
|
|
|
value.gateValue,
|
|
|
|
|
frameSize * 3);
|
|
|
|
|
// CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
// CblasNoTrans,
|
|
|
|
|
// batchSize,
|
|
|
|
|
// 2 * frameSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// 1,
|
|
|
|
|
// value.prevOutValue,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// value.gateWeight,
|
|
|
|
|
// frameSize * 2,
|
|
|
|
|
// 1,
|
|
|
|
|
// value.gateValue,
|
|
|
|
|
// frameSize * 3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forward_reset_output(opResetOutput, value, frameSize, batchSize, active_gate);
|
|
|
|
|
|
|
|
|
|
if (value.prevOutValue) {
|
|
|
|
|
CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
batchSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
1,
|
|
|
|
|
value.resetOutputValue,
|
|
|
|
|
frameSize,
|
|
|
|
|
value.stateWeight,
|
|
|
|
|
frameSize,
|
|
|
|
|
1,
|
|
|
|
|
value.gateValue + frameSize * 2,
|
|
|
|
|
frameSize * 3);
|
|
|
|
|
// CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
// CblasNoTrans,
|
|
|
|
|
// batchSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// 1,
|
|
|
|
|
// value.resetOutputValue,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// value.stateWeight,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// 1,
|
|
|
|
|
// value.gateValue + frameSize * 2,
|
|
|
|
|
// frameSize * 3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forward_final_output(opFinalOutput, value, frameSize, batchSize, active_node);
|
|
|
|
@ -538,34 +538,34 @@ void hl_cpu_gru_backward(OpStateGrad opStateGrad,
|
|
|
|
|
frameSize, batchSize, active_node);
|
|
|
|
|
|
|
|
|
|
if (value.prevOutValue && grad.prevOutGrad) {
|
|
|
|
|
CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
CblasTrans,
|
|
|
|
|
batchSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
1,
|
|
|
|
|
grad.gateGrad + frameSize * 2,
|
|
|
|
|
frameSize * 3,
|
|
|
|
|
value.stateWeight,
|
|
|
|
|
frameSize,
|
|
|
|
|
0,
|
|
|
|
|
grad.resetOutputGrad,
|
|
|
|
|
frameSize);
|
|
|
|
|
// CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
// CblasTrans,
|
|
|
|
|
// batchSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// 1,
|
|
|
|
|
// grad.gateGrad + frameSize * 2,
|
|
|
|
|
// frameSize * 3,
|
|
|
|
|
// value.stateWeight,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// 0,
|
|
|
|
|
// grad.resetOutputGrad,
|
|
|
|
|
// frameSize);
|
|
|
|
|
|
|
|
|
|
if (grad.stateWeightGrad) {
|
|
|
|
|
CBLAS_GEMM(CblasTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
frameSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
batchSize,
|
|
|
|
|
1,
|
|
|
|
|
value.resetOutputValue,
|
|
|
|
|
frameSize,
|
|
|
|
|
grad.gateGrad + frameSize * 2,
|
|
|
|
|
frameSize * 3,
|
|
|
|
|
1,
|
|
|
|
|
grad.stateWeightGrad,
|
|
|
|
|
frameSize);
|
|
|
|
|
// CBLAS_GEMM(CblasTrans,
|
|
|
|
|
// CblasNoTrans,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// batchSize,
|
|
|
|
|
// 1,
|
|
|
|
|
// value.resetOutputValue,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// grad.gateGrad + frameSize * 2,
|
|
|
|
|
// frameSize * 3,
|
|
|
|
|
// 1,
|
|
|
|
|
// grad.stateWeightGrad,
|
|
|
|
|
// frameSize);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -573,34 +573,34 @@ void hl_cpu_gru_backward(OpStateGrad opStateGrad,
|
|
|
|
|
frameSize, batchSize, active_gate);
|
|
|
|
|
|
|
|
|
|
if (grad.prevOutGrad && value.prevOutValue) {
|
|
|
|
|
CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
CblasTrans,
|
|
|
|
|
batchSize,
|
|
|
|
|
frameSize,
|
|
|
|
|
frameSize * 2,
|
|
|
|
|
1,
|
|
|
|
|
grad.gateGrad,
|
|
|
|
|
frameSize * 3,
|
|
|
|
|
value.gateWeight,
|
|
|
|
|
frameSize * 2,
|
|
|
|
|
1,
|
|
|
|
|
grad.prevOutGrad,
|
|
|
|
|
frameSize);
|
|
|
|
|
// CBLAS_GEMM(CblasNoTrans,
|
|
|
|
|
// CblasTrans,
|
|
|
|
|
// batchSize,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// frameSize * 2,
|
|
|
|
|
// 1,
|
|
|
|
|
// grad.gateGrad,
|
|
|
|
|
// frameSize * 3,
|
|
|
|
|
// value.gateWeight,
|
|
|
|
|
// frameSize * 2,
|
|
|
|
|
// 1,
|
|
|
|
|
// grad.prevOutGrad,
|
|
|
|
|
// frameSize);
|
|
|
|
|
|
|
|
|
|
if (grad.gateWeightGrad) {
|
|
|
|
|
CBLAS_GEMM(CblasTrans,
|
|
|
|
|
CblasNoTrans,
|
|
|
|
|
frameSize,
|
|
|
|
|
frameSize * 2,
|
|
|
|
|
batchSize,
|
|
|
|
|
1,
|
|
|
|
|
value.prevOutValue,
|
|
|
|
|
frameSize,
|
|
|
|
|
grad.gateGrad,
|
|
|
|
|
frameSize * 3,
|
|
|
|
|
1,
|
|
|
|
|
grad.gateWeightGrad,
|
|
|
|
|
frameSize * 2);
|
|
|
|
|
// CBLAS_GEMM(CblasTrans,
|
|
|
|
|
// CblasNoTrans,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// frameSize * 2,
|
|
|
|
|
// batchSize,
|
|
|
|
|
// 1,
|
|
|
|
|
// value.prevOutValue,
|
|
|
|
|
// frameSize,
|
|
|
|
|
// grad.gateGrad,
|
|
|
|
|
// frameSize * 3,
|
|
|
|
|
// 1,
|
|
|
|
|
// grad.gateWeightGrad,
|
|
|
|
|
// frameSize * 2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|