|
|
|
@ -779,7 +779,8 @@ def get_bprop_unsorted_segment_sum(self):
|
|
|
|
|
"""Generate bprop for UnsortedSegmentSum"""
|
|
|
|
|
|
|
|
|
|
def bprop(x, segment_ids, num_segments, out, dout):
|
|
|
|
|
return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
|
|
|
|
|
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
|
|
|
|
|
zeros_like(num_segments)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
@ -827,7 +828,7 @@ def get_bprop_unsorted_segment_prod(self):
|
|
|
|
|
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
|
|
|
|
|
prod_divided_by_x = gathered_prod / x
|
|
|
|
|
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
|
|
|
|
|
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices)
|
|
|
|
|
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None)
|
|
|
|
|
dx = gathered_grad * partial_derivative
|
|
|
|
|
return dx, zeros_like(segment_ids), zeros_like(num_segments)
|
|
|
|
|
|
|
|
|
|