|
|
|
@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto c = g.slice(c_offsets, extents); // output candidate
|
|
|
|
|
|
|
|
|
|
// calculate final output
|
|
|
|
|
h.device(place) = u * (h_p - c) + c;
|
|
|
|
|
h.device(place) = u * (c - h_p) + h_p;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -185,10 +185,10 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// backward for unactivated update gate
|
|
|
|
|
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
|
|
|
|
|
d_g.slice(u_offsets, extents), d_h * (h_p - c));
|
|
|
|
|
d_g.slice(u_offsets, extents), d_h * (c - h_p));
|
|
|
|
|
// backward for unactivated output candidate
|
|
|
|
|
ActGradCompute(context.Attr<int>("activation"), place, c, c,
|
|
|
|
|
d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
|
|
|
|
|
d_g.slice(c_offsets, extents), d_h * u);
|
|
|
|
|
// backward for reset_hidden_prev
|
|
|
|
|
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
|
|
|
|
|
frame_size, frame_size, 1,
|
|
|
|
@ -210,7 +210,7 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
frame_size, gate_grad_data, frame_size * 3, 0,
|
|
|
|
|
weight_grad_data, frame_size * 2);
|
|
|
|
|
// backward for hidden_prev
|
|
|
|
|
d_h_p.device(place) = d_r_h_p * r + d_h * u;
|
|
|
|
|
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
|
|
|
|
|
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
|
|
|
|
|
frame_size, frame_size * 2, 1, gate_grad_data,
|
|
|
|
|
frame_size * 3, weight_data, frame_size * 2, 1,
|
|
|
|
|