!13670 fix centernet loss error

From: @caojian05
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
pull/13670/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 26e39a9692

@ -308,11 +308,8 @@ class CenterNetWithLossScaleCell(nn.Cell):
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if overflow:
succ = False
else:
succ = self.optimizer(grads)
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return ops.depend(ret, succ)

@ -137,7 +137,7 @@ class GatherFeature(nn.Cell):
self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims()
else:
self.gather_nd = ops.GatherND()
self.gather_nd = ops.GatherNd()
def construct(self, feat, ind):
"""gather by specified index"""

Loading…
Cancel
Save