From 90967215d31d0b64557567399a6ec90c9cdf74d4 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Wed, 9 Dec 2020 15:41:09 +0800 Subject: [PATCH] change cell id for pynative --- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 6 ++++-- mindspore/nn/layer/basic.py | 3 --- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index ece0b50b3f..ba79a39f84 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1384,12 +1384,14 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & std::string arg_id = GetId(args[i]); auto it = node_abs_map_.find(arg_id); if (it != node_abs_map_.end()) { - cell_id += it->second->ToString(); + cell_id += "_" + it->second->BuildShape()->ToString(); + cell_id += "_" + it->second->BuildType()->ToString(); } else { auto abs = PyAttrValue(args[i])->ToAbstract(); auto config = abstract::AbstractBase::kBroadenTensorOnly; abs = abs->Broaden(config); - cell_id += abs->ToString(); + cell_id += "_" + abs->BuildShape()->ToString(); + cell_id += "_" + abs->BuildType()->ToString(); node_abs_map_[arg_id] = abs; } } diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index a290878995..34dac1ae9a 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from mindspore._checkparam import Rel, Validator -from mindspore.common.api import ms_function from mindspore import context from ..cell import Cell from .activation import get_activation @@ -408,9 +407,7 @@ class ClipByNorm(Cell): self.expand_dims = P.ExpandDims() self.dtype = P.DType() - @ms_function def construct(self, x, clip_norm): - """add ms_function decorator for pynative mode""" mul_x = F.square(x) l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32) cond = self.greater_(l2sum, 0)