!4005 unsortedsegsum grad

Merge pull request !4005 from fangzehua/unsortedsegsum
pull/4005/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0b407dfe78

@ -673,6 +673,16 @@ def _GatherDropNegatives(params,
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
def bprop(x, segment_ids, num_segments, out, dout):
return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return bprop
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin"""

@ -1448,14 +1448,12 @@ test_case_nn_ops = [
'block': P.UnsortedSegmentSum(),
'desc_const': [1280],
'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192, 1024]],
'skip': ['backward']}),
'desc_bprop': [[1280, 1024]]}),
('UnsortedSegmentSum_1', {
'block': P.UnsortedSegmentSum(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]],
'skip': ['backward']}),
'desc_bprop': [[4, 1, 3]]}),
('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(),
'desc_const': [4],

Loading…
Cancel
Save