@ -308,11 +308,8 @@ class CenterNetWithLossScaleCell(nn.Cell):
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if overflow:
succ = False
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return ops.depend(ret, succ)
@ -137,7 +137,7 @@ class GatherFeature(nn.Cell):
self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims()
self.gather_nd = ops.GatherND()
self.gather_nd = ops.GatherNd()
def construct(self, feat, ind):
"""gather by specified index"""