|
|
|
@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* weight_grad =
|
|
|
|
|
context.Output<Tensor>(framework::GradVarName("Weight"));
|
|
|
|
|
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
|
|
|
|
|
input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
hidden_prev_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
weight_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
Tensor gate_grad;
|
|
|
|
|
gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
|
|
|
|
|
Tensor reset_hidden_prev_grad;
|
|
|
|
|
reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int batch_size = input->dims()[0];
|
|
|
|
|
int frame_size = hidden_prev->dims()[1];
|
|
|
|
|
|
|
|
|
|
const T* hidden_prev_data = hidden_prev->data<T>();
|
|
|
|
|
T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
|
|
|
|
|
const T* weight_data = weight->data<T>();
|
|
|
|
|
T* weight_grad_data = weight_grad->data<T>();
|
|
|
|
|
T* gate_grad_data = gate_grad.data<T>();
|
|
|
|
|
T* gate_grad_data =
|
|
|
|
|
gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
|
|
|
|
|
const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
|
|
|
|
|
T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
|
|
|
|
|
T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data<T>(
|
|
|
|
|
reset_hidden_prev->dims(), context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto h_p = EigenMatrix<T>::From(*hidden_prev);
|
|
|
|
|
auto g = EigenMatrix<T>::From(*gate);
|
|
|
|
|
auto d_h = EigenMatrix<T>::From(*hidden_grad);
|
|
|
|
|
auto d_x = EigenMatrix<T>::From(*input_grad);
|
|
|
|
|
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
|
|
|
|
|
auto d_g = EigenMatrix<T>::From(gate_grad);
|
|
|
|
|
auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
|
|
|
|
|
auto place = context.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
int batch_size = input->dims()[0];
|
|
|
|
|
int frame_size = hidden_prev->dims()[1];
|
|
|
|
|
|
|
|
|
|
Eigen::array<int, 2> extents({{batch_size, frame_size}});
|
|
|
|
|
Eigen::array<int, 2> u_offsets({{0, 0}});
|
|
|
|
|
auto u = g.slice(u_offsets, extents); // update gate
|
|
|
|
@ -195,28 +187,42 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
gate_grad_data + frame_size * 2, frame_size * 3,
|
|
|
|
|
weight_data + frame_size * frame_size * 2, frame_size,
|
|
|
|
|
0, reset_hidden_prev_grad_data, frame_size);
|
|
|
|
|
// backward for state_weight
|
|
|
|
|
math::gemm<Place, T>(
|
|
|
|
|
context.device_context(), true, false, frame_size, frame_size,
|
|
|
|
|
batch_size, 1, reset_hidden_prev_data, frame_size,
|
|
|
|
|
gate_grad_data + frame_size * 2, frame_size * 3, 0,
|
|
|
|
|
weight_grad_data + frame_size * frame_size * 2, frame_size);
|
|
|
|
|
// backward for unactivated reset gate
|
|
|
|
|
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
|
|
|
|
|
d_g.slice(r_offsets, extents), d_r_h_p * h_p);
|
|
|
|
|
// backward for update_gate_weight and reset_gate_weight
|
|
|
|
|
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
|
|
|
|
|
frame_size * 2, batch_size, 1, hidden_prev_data,
|
|
|
|
|
frame_size, gate_grad_data, frame_size * 3, 0,
|
|
|
|
|
weight_grad_data, frame_size * 2);
|
|
|
|
|
// backward for weight
|
|
|
|
|
if (weight_grad) {
|
|
|
|
|
T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
// backward for state_weight
|
|
|
|
|
math::gemm<Place, T>(
|
|
|
|
|
context.device_context(), true, false, frame_size, frame_size,
|
|
|
|
|
batch_size, 1, reset_hidden_prev_data, frame_size,
|
|
|
|
|
gate_grad_data + frame_size * 2, frame_size * 3, 0,
|
|
|
|
|
weight_grad_data + frame_size * frame_size * 2, frame_size);
|
|
|
|
|
|
|
|
|
|
// backward for update_gate_weight and reset_gate_weight
|
|
|
|
|
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
|
|
|
|
|
frame_size * 2, batch_size, 1, hidden_prev_data,
|
|
|
|
|
frame_size, gate_grad_data, frame_size * 3, 0,
|
|
|
|
|
weight_grad_data, frame_size * 2);
|
|
|
|
|
}
|
|
|
|
|
// backward for hidden_prev
|
|
|
|
|
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
|
|
|
|
|
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
|
|
|
|
|
frame_size, frame_size * 2, 1, gate_grad_data,
|
|
|
|
|
frame_size * 3, weight_data, frame_size * 2, 1,
|
|
|
|
|
hidden_prev_grad_data, frame_size);
|
|
|
|
|
if (hidden_prev_grad) {
|
|
|
|
|
T* hidden_prev_grad_data =
|
|
|
|
|
hidden_prev_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
|
|
|
|
|
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
|
|
|
|
|
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
|
|
|
|
|
frame_size, frame_size * 2, 1, gate_grad_data,
|
|
|
|
|
frame_size * 3, weight_data, frame_size * 2, 1,
|
|
|
|
|
hidden_prev_grad_data, frame_size);
|
|
|
|
|
}
|
|
|
|
|
// backward for input
|
|
|
|
|
d_x.device(place) = d_g;
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
input_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto d_x = EigenMatrix<T>::From(*input_grad);
|
|
|
|
|
d_x.device(place) = d_g;
|
|
|
|
|
}
|
|
|
|
|
// backward for bias
|
|
|
|
|
if (bias_grad) {
|
|
|
|
|
bias_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|