!205 add ReverseSequenceGrad

Merge pull request !205 from xutianchun/reverse_sequence_grad
pull/3198/head
mindspore-ci-bot 5 years ago committed by Gitee
commit 0958a695a4

@ -580,3 +580,14 @@ def get_bprop_batch_to_space_nd(self):
dx = batch_to_space_nd_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
"""Generate bprop for ReverseSequence"""
reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)
def bprop(x, seq_lengths, out, dout):
dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths)
return bprop

@ -1378,6 +1378,11 @@ test_case_array_ops = [
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
Tensor(np.array([0, 1, 1]).astype(np.int32))],
'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
('ReverseSequence', {
'block': P.ReverseSequence(1, 0),
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[3, 3]]}),
]
test_case_other_ops = [

Loading…
Cancel
Save