diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 8b13097990..12ba8bb7e3 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1384,12 +1384,14 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & std::string arg_id = GetId(args[i]); auto it = node_abs_map_.find(arg_id); if (it != node_abs_map_.end()) { - cell_id += it->second->ToString(); + cell_id += "_" + it->second->BuildShape()->ToString(); + cell_id += "_" + it->second->BuildType()->ToString(); } else { auto abs = PyAttrValue(args[i])->ToAbstract(); auto config = abstract::AbstractBase::kBroadenTensorOnly; abs = abs->Broaden(config); - cell_id += abs->ToString(); + cell_id += "_" + abs->BuildShape()->ToString(); + cell_id += "_" + abs->BuildType()->ToString(); node_abs_map_[arg_id] = abs; } } diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index b5c575fd3a..83654c147f 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from mindspore._checkparam import Rel, Validator -from mindspore.common.api import ms_function from mindspore import context from ..cell import Cell from .activation import get_activation @@ -413,9 +412,7 @@ class ClipByNorm(Cell): self.expand_dims = P.ExpandDims() self.dtype = P.DType() - @ms_function def construct(self, x, clip_norm): - """add ms_function decorator for pynative mode""" mul_x = F.square(x) l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32) cond = self.greater_(l2sum, 0)