diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index e666a0b3d7..73d6086707 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; // define the parse constant const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1; const char CUSTOM_BPROP_NAME[] = "bprop"; -const char STAGE_NAME[] = "stage"; +const char STAGE_NAME[] = "pipeline_stage"; // define the Namespace name const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 68e18d7eee..f458e4cd7e 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -471,6 +471,9 @@ class Receive(PrimitiveWithInfer): self.shape = shape self.dtype = dtype self.group = group + valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] + args = {"dtype": dtype} + validator.check_scalar_or_tensor_types_same(args, valid_type, self.name) def infer_shape(self, x_shape=None): return self.shape diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py index abc09fb44e..957586ddcf 100644 --- a/tests/ut/python/parallel/test_pipeline_split.py +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -77,7 +77,7 @@ class Net(nn.Cell): self.block = nn.CellList() for i in range(2): cell = MatMulCell(strategy1, strategy2, param) - cell.stage = i + cell.pipeline_stage = i self.block.append(cell) def construct(self, x):