From edd090eee22bd03509f9375e983336371c6c4c09 Mon Sep 17 00:00:00 2001 From: Lian Date: Wed, 28 Oct 2020 14:40:49 +0800 Subject: [PATCH] fix pynative mixprecision's int and other type's error cast --- mindspore/nn/cell.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index e309d0f37d..662ee24464 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -261,6 +261,17 @@ class Cell(Cell_): object.__delattr__(self, name) self._attr_synced = False + def _cast_mixed_precision_inputs(self, inputs, dst_type): + res = list() + for item in inputs: + if isinstance(item, tuple): + res.append(self._cast_mixed_precision_inputs(item, dst_type)) + elif item.dtype in {mstype.float16, mstype.float32}: + res.append(cast(item, dst_type)) + else: + res.append(item) + return tuple(res) + def cast_inputs(self, inputs, dst_type): res = list() for item in inputs: @@ -299,9 +310,9 @@ class Cell(Cell_): cast_inputs = list() if hasattr(self, "_mindspore_flags"): if self._mindspore_flags.get('fp16'): - cast_inputs = self.cast_inputs(inputs, mstype.float16) + cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16) if self._mindspore_flags.get('fp32'): - cast_inputs = self.cast_inputs(inputs, mstype.float32) + cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32) if not cast_inputs: cast_inputs = inputs if self.enable_hook: