diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc
index 253e271e52..a21cfd30bf 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.cc
+++ b/mindspore/ccsrc/pybind_api/export_flags.cc
@@ -29,7 +29,6 @@ const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
 // flag names
 const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
 const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
-const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
 const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
 const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h
index 6ea584e66d..b84efda770 100644
--- a/mindspore/ccsrc/pybind_api/export_flags.h
+++ b/mindspore/ccsrc/pybind_api/export_flags.h
@@ -30,7 +30,6 @@ extern const char PYTHON_DATACLASS_FIELDS[];
 
 extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
 extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
-extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
 extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
 extern const char GRAPH_FLAG_RANDOM_EFFECT[];
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index ecff453fab..dbe4056844 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -286,7 +286,6 @@ class ClipByNorm(Cell):
         self.select_ = P.Select()
         self.greater_ = P.Greater()
         self.cast = P.Cast()
-        self.zero = Tensor(np.array([0.0]).astype(np.float32))
         self.sqrt = P.Sqrt()
         self.max_op = P.Maximum()
         self.shape = P.Shape()
@@ -300,7 +299,7 @@ class ClipByNorm(Cell):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
         l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
-        cond = self.greater_(l2sum, self.zero)
+        cond = self.greater_(l2sum, 0)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
 
@@ -407,11 +406,13 @@ class OneHot(Cell):
         super(OneHot, self).__init__()
         self.onehot = P.OneHot(axis)
         self.depth = depth
-        self.on_value = Tensor(on_value, dtype)
-        self.off_value = Tensor(off_value, dtype)
+        self.dtype = dtype
+        self.on_value = on_value
+        self.off_value = off_value
 
     def construct(self, indices):
-        return self.onehot(indices, self.depth, self.on_value, self.off_value)
+        return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))
+
 
 class Pad(Cell):
diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py
index 71c2920850..c640f89557 100755
--- a/mindspore/nn/layer/lstm.py
+++ b/mindspore/nn/layer/lstm.py
@@ -133,7 +133,8 @@ class LSTM(Cell):
         self.transpose2 = P.Transpose()
         num_directions = 2 if self.bidirectional else 1
         self.cpu_target = False
-        if context.get_context("device_target") == "CPU":
+        enable_debug = context.get_context("enable_debug_runtime")
+        if context.get_context("device_target") == "CPU" and not enable_debug:
             self.cpu_target = True
         if not self.cpu_target:
             self.lstm = P.LSTM(input_size=self.input_size,
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 868b2a4d99..4dcff00edf 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -141,7 +141,7 @@ class Optimizer(Cell):
         if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
         else:
-            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+            self.learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name="learning_rate")
 
         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
diff --git a/model_zoo/bert/src/bert_model.py b/model_zoo/bert/src/bert_model.py
index 5cd90ab84b..8f972f8cec 100644
--- a/model_zoo/bert/src/bert_model.py
+++ b/model_zoo/bert/src/bert_model.py
@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1
 
         self.tile = P.Tile()
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)
@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max
 
-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = float(np.finfo(np_type).min)
+        self.tensor_max_type = float(np.finfo(np_type).max)
 
         self.min_op = P.Minimum()
         self.max_op = P.Maximum()
@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)
@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
        self.batch_num = batch_size * num_attention_heads
 
         self.matmul = P.BatchMatMul()
diff --git a/model_zoo/deeplabv3/src/deeplabv3.py b/model_zoo/deeplabv3/src/deeplabv3.py
index 03bb03ad14..bbfc4dceb3 100644
--- a/model_zoo/deeplabv3/src/deeplabv3.py
+++ b/model_zoo/deeplabv3/src/deeplabv3.py
@@ -17,7 +17,6 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore.ops.composite import add_flags
 from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
     DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
 
 
@@ -122,7 +121,6 @@ class ASPP(nn.Cell):
         self.feature_shape = feature_shape
         self.concat = P.Concat(axis=1)
 
-    @add_flags(loop_can_unroll=True)
     def construct(self, x, scale_index=0):
         aspp0 = self.aspp0(x)
         aspp1 = self.global_poolings[scale_index](x)
diff --git a/model_zoo/mass/src/transformer/transformer_for_infer.py b/model_zoo/mass/src/transformer/transformer_for_infer.py
index 8b1a1c4667..99d56ba3a1 100644
--- a/model_zoo/mass/src/transformer/transformer_for_infer.py
+++ b/model_zoo/mass/src/transformer/transformer_for_infer.py
@@ -275,8 +275,6 @@ class TransformerInferModel(nn.Cell):
             length_penalty_weight=config.length_penalty_weight,
             max_decode_length=config.max_decode_length)
 
-        self.decoder.add_flags(loop_can_unroll=True)
-
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
diff --git a/model_zoo/official/nlp/transformer/src/transformer_model.py b/model_zoo/official/nlp/transformer/src/transformer_model.py
index 409f8965eb..fb33f526da 100644
--- a/model_zoo/official/nlp/transformer/src/transformer_model.py
+++ b/model_zoo/official/nlp/transformer/src/transformer_model.py
@@ -1104,7 +1104,6 @@ class TransformerModel(nn.Cell):
                                                       beam_width=config.beam_width,
                                                       length_penalty_weight=config.length_penalty_weight,
                                                       max_decode_length=config.max_decode_length)
-            self.tfm_decoder.add_flags(loop_can_unroll=True)
 
         self.cast = P.Cast()
         self.dtype = config.dtype
diff --git a/tests/mindspore_test_framework/apps/bert_attention_submodules.py b/tests/mindspore_test_framework/apps/bert_attention_submodules.py
index 4ce72ffc84..83729d9e70 100644
--- a/tests/mindspore_test_framework/apps/bert_attention_submodules.py
+++ b/tests/mindspore_test_framework/apps/bert_attention_submodules.py
@@ -108,7 +108,7 @@ class BertAttentionRelativePositionKeys(nn.Cell):
         self.trans_shape_position = (1, 2, 0, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.multiply = P.Mul()
 
@@ -301,7 +301,7 @@ class BertAttentionRelativePositionValues(nn.Cell):
         self.trans_shape_position = (1, 2, 0, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.trans_shape = (0, 2, 1, 3)
         self.reshape = P.Reshape()
 
diff --git a/tests/st/networks/models/deeplabv3/src/deeplabv3.py b/tests/st/networks/models/deeplabv3/src/deeplabv3.py
index 906a207302..bbfc4dceb3 100644
--- a/tests/st/networks/models/deeplabv3/src/deeplabv3.py
+++ b/tests/st/networks/models/deeplabv3/src/deeplabv3.py
@@ -276,7 +276,7 @@ class SingleDeepLabV3(nn.Cell):
                          atrous_rates=atrous_rates,
                          output_stride=output_stride,
                          fine_tune_batch_norm=fine_tune_batch_norm)
-        self.aspp.add_flags(loop_can_unroll=True)
+
         atrous_rates_len = 0
         if atrous_rates is not None:
             atrous_rates_len = len(atrous_rates)
diff --git a/tests/ut/python/nn/test_distribution.py b/tests/ut/python/nn/test_distribution.py
index 845c64a110..b779814fd5 100644
--- a/tests/ut/python/nn/test_distribution.py
+++ b/tests/ut/python/nn/test_distribution.py
@@ -259,7 +259,7 @@ class NormalKl(nn.Cell):
     """
     def __init__(self):
         super(NormalKl, self).__init__()
-        self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
+        self.n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32)
 
     def construct(self, x_, y_):
         return self.n('kl_loss', 'Normal', x_, y_)
diff --git a/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py b/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py
index 5cc2ce35cc..35e0687b9d 100644
--- a/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py
+++ b/tests/ut/python/pynative_mode/test_remove_unnecessary_phi.py
@@ -20,7 +20,6 @@ from numpy.random import normal
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common.api import ms_function
-from mindspore.ops.composite import core
 
 
 def setup_module(module):
@@ -34,7 +33,6 @@ def test_remove_phi_and_fv():
     """ test_remove_phi_and_fv """
 
     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x, input_data):
         def fv_func(y):
             return x * y
@@ -60,7 +58,6 @@ def test_remove_multiple_phi():
    """ test_remove_multiple_phi """
 
     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x):
         def mul(a, b):
             return a * b
@@ -83,7 +80,6 @@ def test_remove_multiple_phi_recursive():
     """ test_remove_multiple_phi_recursive """
 
     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x):
         def mul(a, b):
             return a * b