!7901 fix bug of pynative's mixprecision

Merge pull request !7901 from lianliguang/master
pull/7901/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b27d16c0a5

@ -261,6 +261,17 @@ class Cell(Cell_):
object.__delattr__(self, name)
self._attr_synced = False
def _cast_mixed_precision_inputs(self, inputs, dst_type):
res = list()
for item in inputs:
if isinstance(item, tuple):
res.append(self._cast_mixed_precision_inputs(item, dst_type))
elif item.dtype in {mstype.float16, mstype.float32}:
res.append(cast(item, dst_type))
else:
res.append(item)
return tuple(res)
def cast_inputs(self, inputs, dst_type):
res = list()
for item in inputs:
@ -299,9 +310,9 @@ class Cell(Cell_):
cast_inputs = list()
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
cast_inputs = self.cast_inputs(inputs, mstype.float16)
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
if self._mindspore_flags.get('fp32'):
cast_inputs = self.cast_inputs(inputs, mstype.float32)
cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
if not cast_inputs:
cast_inputs = inputs
if self.enable_hook:

Loading…
Cancel
Save