@@ -261,6 +261,17 @@ class Cell(Cell_):
             object.__delattr__(self, name)
         self._attr_synced = False
 
+    def _cast_mixed_precision_inputs(self, inputs, dst_type):
+        res = list()
+        for item in inputs:
+            if isinstance(item, tuple):
+                res.append(self._cast_mixed_precision_inputs(item, dst_type))
+            elif item.dtype in {mstype.float16, mstype.float32}:
+                res.append(cast(item, dst_type))
+            else:
+                res.append(item)
+        return tuple(res)
+
     def cast_inputs(self, inputs, dst_type):
         res = list()
         for item in inputs:
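For context, the new _cast_mixed_precision_inputs helper recurses into nested tuples, casts only float16/float32 items to the destination type, and passes everything else through unchanged. A minimal standalone sketch of that behaviour, with FakeTensor and the local cast() as illustrative stand-ins rather than MindSpore APIs:

from dataclasses import dataclass

# Illustrative stand-in for a tensor: only the dtype matters for this sketch.
@dataclass
class FakeTensor:
    dtype: str

def cast(item, dst_type):
    # Stand-in for the cast primitive used in the diff above.
    return FakeTensor(dst_type)

def cast_mixed_precision_inputs(inputs, dst_type):
    """Mirror of _cast_mixed_precision_inputs: cast float16/float32 items,
    recurse into tuples, and pass everything else through unchanged."""
    res = []
    for item in inputs:
        if isinstance(item, tuple):
            res.append(cast_mixed_precision_inputs(item, dst_type))
        elif item.dtype in {"float16", "float32"}:
            res.append(cast(item, dst_type))
        else:
            res.append(item)  # e.g. an int32 index tensor stays untouched
    return tuple(res)

inputs = (FakeTensor("float32"), FakeTensor("int32"), (FakeTensor("float16"),))
print(cast_mixed_precision_inputs(inputs, "float16"))
# (FakeTensor(dtype='float16'), FakeTensor(dtype='int32'), (FakeTensor(dtype='float16'),))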
@@ -299,9 +310,9 @@ class Cell(Cell_):
         cast_inputs = list()
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                cast_inputs = self.cast_inputs(inputs, mstype.float16)
+                cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
             if self._mindspore_flags.get('fp32'):
-                cast_inputs = self.cast_inputs(inputs, mstype.float32)
+                cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
         if not cast_inputs:
             cast_inputs = inputs
         if self.enable_hook:
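For orientation on the call-site change: the fp16/fp32 flags checked above are typically set on a cell through Cell.to_float (or an amp training wrapper). A hedged usage sketch, assuming a float16-capable backend such as GPU or Ascend:

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

# to_float(ms.float16) marks the cell with the fp16 flag, so the wrapper in the
# diff casts float inputs to float16 before construct() runs; with the new
# helper, integer inputs would now pass through unchanged.
net = nn.Dense(4, 2).to_float(ms.float16)
x = Tensor(np.ones((3, 4), dtype=np.float32))
out = net(x)
print(out.dtype)  # expected: Float16 on a float16-capable backend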