From e4167da1b1bf7180d1bcd156b09085f8b8cd6df5 Mon Sep 17 00:00:00 2001 From: simson Date: Mon, 30 Nov 2020 14:07:16 +0800 Subject: [PATCH] fix bug for upgrading python3.8 --- mindspore/common/parameter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index d40a4b4abe..4456782378 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -128,6 +128,12 @@ class Parameter(MetaTensor_): self.init_in_server = False self._unique = False self.is_in_parallel = _is_in_parallel_mode() + if isinstance(default_input, (MetaTensor, Tensor)): + MetaTensor_.__init__(self, default_input.dtype, default_input.shape) + elif isinstance(default_input, int): + MetaTensor_.__init__(self, mstype.int64, ()) + elif isinstance(default_input, float): + MetaTensor_.__init__(self, mstype.float32, ()) @staticmethod def _get_base_class(input_class):