diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc index 65b576cef9..c017244cfb 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc @@ -201,7 +201,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { right_edge->set_selected_cost(decision->right_edge_cost_); // 'left_node' recovers the strategy. left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); - if (TRIANGLE_STRATEGY_OVERWRITE) { + if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { // 'right_node' recovers the strategy. MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); @@ -225,10 +225,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { MS_EXCEPTION_IF_NULL(succ_nodes[0]); MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); - // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. + // The star is eliminated into 'succ_nodes[0]', so its strategy is recovered first. succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); for (size_t k = 1; k < succ_nodes.size(); ++k) { - succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); + if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + // The strategy and cost of 'succ_nodes[k]' are overwritten. + succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]); + } else { + // In this case, the strategy and cost of 'succ_nodes[k]' are NOT overwritten; only check that its selected strategy is consistent.
+ succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); + } } MS_LOG(INFO) << "Recover starElimination succeeded."; } else { diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h index ec131e519f..3a84db55b8 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h @@ -20,9 +20,9 @@ #include #include #include -#include "ir/value.h" #include "frontend/parallel/auto_parallel/edge_costmodel.h" #include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "ir/value.h" namespace mindspore { namespace parallel { diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h index 9a09021380..b636048723 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h @@ -22,11 +22,11 @@ #include #include #include -#include "utils/ms_utils.h" #include "frontend/parallel/auto_parallel/costmodel.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/tensor_layout/tensor_info.h" #include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "utils/ms_utils.h" namespace mindspore { namespace parallel { diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc index 249adb5659..abf4ed8eef 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -40,7 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; int32_t RUN_PHASE = DEFAULT_RUN_PHASE; -bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; +bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; void CostGraph::SetDeviceMemoryAndCostParameter() { MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); @@ -155,12 +155,12 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { MS_LOG(INFO) << "multi_subgraphs: false."; } - auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite(); - TRIANGLE_STRATEGY_OVERWRITE = overwrite; - if (TRIANGLE_STRATEGY_OVERWRITE) { - MS_LOG(INFO) << "triangle_strategy_overwrite: true."; + auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); + TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite; + if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; } else { - MS_LOG(INFO) << "triangle_strategy_overwrite: false."; + MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; } // RUN_PHASE @@ -1303,7 +1303,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; - if (TRIANGLE_STRATEGY_OVERWRITE) { + if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { new_computation += right_op_cost->computation_cost_; new_memory += right_op_cost->memory_with_reuse_; new_commu_cost += right_op_cost->communication_cost_; @@ -1399,7 +1399,9 @@ OperatorInfoPtr 
CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, } if (!valid) { - MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed."; + MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() + << " failed. It may be caused by " "configuring inconsistent strategies for operators."; } elimi_op->SetNotAlive(); MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; @@ -1440,6 +1442,13 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n commu_cost += succ_edges_costs[i]->communication_cost_; commu_forward += succ_edges_costs[i]->communication_forward_; commu_without += succ_edges_costs[i]->communication_without_parameter_; + if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + computation_cost += succ_nodes_costs[i]->computation_cost_; + memory_cost += succ_nodes_costs[i]->memory_with_reuse_; + commu_cost += succ_nodes_costs[i]->communication_cost_; + commu_forward += succ_nodes_costs[i]->communication_forward_; + commu_without += succ_nodes_costs[i]->communication_without_parameter_; + } } } @@ -1544,7 +1553,9 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo } if (!valid) { - MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; + MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() + << " failed. It may be caused by " "configuring inconsistent strategies for operators."; } merged_op->SetNotAlive(); diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h index 2281f9c7fc..34e56361fc 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h @@ -22,11 +22,11 @@ #include #include #include -#include "utils/ms_utils.h" #include "frontend/parallel/auto_parallel/edge_costmodel.h" #include "frontend/parallel/costmodel_context.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/ops_info/tmp_identity_info.h" +#include "utils/ms_utils.h" namespace mindspore { namespace parallel { @@ -46,7 +46,7 @@ extern bool FULLY_USE_DEVICES; extern bool ELEMENTWISE_OP_STRA_FOLLOW; extern bool MULTI_SUBGRAPHS; extern int32_t RUN_PHASE; -extern bool TRIANGLE_STRATEGY_OVERWRITE; +extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; class CostGraph { // 'CostGraph' consists of Operators and edges between them.
An edge is created between two Operators if they have diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc index 63524ec3fe..2d5460bfcc 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -16,8 +16,8 @@ #include "frontend/parallel/auto_parallel/operator_costmodel.h" -#include #include +#include #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h" diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 3a2084a6e2..57b1f96305 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -19,9 +19,9 @@ #include #include #include +#include #include #include -#include #include "frontend/parallel/device_manager.h" diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index f1981aac33..4e71b0d094 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -18,18 +18,18 @@ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ #include -#include #include +#include #include #include +#include "abstract/abstract_value.h" #include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/status.h" -#include "utils/convert_utils.h" #include "ir/anf.h" #include "ir/func_graph.h" +#include "utils/convert_utils.h" #include "utils/info.h" -#include "abstract/abstract_value.h" namespace mindspore { namespace parallel { diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc index e3383ad58d..030d2201f1 100644 --- a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc @@ -64,7 +64,7 @@ void CostModelContext::ResetAlgoParameters() { tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; - triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; + triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; } void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { @@ -134,7 +134,9 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { elementwise_stra_follow_ = elementwise_follow; } -void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; } +void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { + triangle_star_strategy_overwrite_ = overwrite; +} void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h index e809b53e33..6da7a67ce7 100644 --- a/mindspore/ccsrc/frontend/parallel/costmodel_context.h +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.h @@ -44,7 +44,7 @@ namespace parallel { #define DEFAULT_RUN_PHASE 0 #define TRAINING_PHASE 0 #define INFERENCE_PHASE 1 -#define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true; +#define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true class CostModelContext { public: @@ -135,8 +135,8 @@ class CostModelContext { void 
set_elementwise_stra_follow(bool); bool elementwise_stra_follow() const { return elementwise_stra_follow_; } - void set_triangle_strategy_overwrite(bool); - bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; } + void set_triangle_star_strategy_overwrite(bool); + bool triangle_star_strategy_overwrite() const { return triangle_star_strategy_overwrite_; } void set_run_phase(int32_t); int32_t run_phase() const { return run_phase_; } @@ -172,9 +172,9 @@ class CostModelContext { // MULTI_SUBGRAPHS bool is_multi_subgraphs_; - // In the recovery phase of DP algorithm, when encountering triangle structure, + // In the recovery phase of the DP algorithm, when encountering a triangle or star structure, // whether overwrite the right-node strategy - bool triangle_strategy_overwrite_; + bool triangle_star_strategy_overwrite_; int32_t run_phase_; // 0: 'training', 1: 'inference' diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h index 79e0487c10..09c7d003a6 100644 --- a/mindspore/ccsrc/frontend/parallel/device_manager.h +++ b/mindspore/ccsrc/frontend/parallel/device_manager.h @@ -25,13 +25,13 @@ #include #include -#include "utils/ms_utils.h" #include "frontend/parallel/device.h" #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/group_manager.h" #include "frontend/parallel/status.h" #include "frontend/parallel/strategy.h" #include "utils/convert_utils.h" +#include "utils/ms_utils.h" namespace mindspore { namespace parallel { diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc index e6ec6ff68a..5bd019fede 100644 --- a/mindspore/ccsrc/frontend/parallel/group_manager.cc +++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc @@ -17,8 +17,8 @@ #include "frontend/parallel/group_manager.h" #include #include -#include "frontend/parallel/device_manager.h" #include "backend/session/executor_manager.h" +#include "frontend/parallel/device_manager.h" #include "utils/comm_manager.h" #include "utils/ms_context.h" diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 8c2b37b327..25d6cacf0e 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -26,10 +26,8 @@ #include #include #include +#include -#include "ir/anf.h" -#include "ir/param_info.h" -#include "ir/tensor.h" #include "frontend/optimizer/opt.h" #include "frontend/optimizer/optimizer.h" #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" @@ -39,11 +37,14 @@ #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" #include "frontend/parallel/context.h" -#include "frontend/parallel/ops_info/tmp_identity_info.h" -#include "frontend/parallel/ops_info/reshape_info.h" #include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" #include "frontend/parallel/step_parallel.h" #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" +#include "ir/anf.h" +#include "ir/param_info.h" +#include "ir/tensor.h" namespace mindspore { namespace parallel { diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h index 59eb50c33a..13d96ce334 --- 
a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h @@ -21,9 +21,9 @@ #include #include #include -#include "ir/anf.h" #include "frontend/optimizer/opt.h" #include "frontend/parallel/status.h" +#include "ir/anf.h" #include "pipeline/jit/pipeline.h" namespace mindspore { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index fadd62c2c1..9976e292b4 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -27,8 +27,6 @@ #include #include -#include "ir/tensor.h" -#include "ir/param_info.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/optimizer.h" #include "frontend/parallel/auto_parallel/graph_costmodel.h" @@ -41,9 +39,11 @@ #include "frontend/parallel/node_check.h" #include "frontend/parallel/ops_info/matmul_info.h" #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" +#include "ir/param_info.h" +#include "ir/tensor.h" #include "utils/comm_manager.h" -#include "utils/symbolic.h" #include "utils/ms_context.h" +#include "utils/symbolic.h" using mindspore::tensor::Tensor; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 045cc658bd..c7514b44bd 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -21,10 +21,10 @@ #include #include +#include #include #include #include -#include #include "frontend/optimizer/opt.h" #include "frontend/parallel/strategy.h" diff --git a/mindspore/ccsrc/frontend/parallel/strategy.h b/mindspore/ccsrc/frontend/parallel/strategy.h index 95b09c6cb0..03bee638e7 100644 --- a/mindspore/ccsrc/frontend/parallel/strategy.h +++ b/mindspore/ccsrc/frontend/parallel/strategy.h @@ -23,8 +23,8 @@ #include #include -#include "frontend/parallel/status.h" #include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/status.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/python/parallel/test_auto_parallel_star_partial_strategy.py b/tests/ut/python/parallel/test_auto_parallel_star_partial_strategy.py new file mode 100644 index 0000000000..af3d0ac431 --- /dev/null +++ b/tests/ut/python/parallel/test_auto_parallel_star_partial_strategy.py @@ -0,0 +1,134 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest +import mindspore as ms +import mindspore.nn as nn +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.parallel._utils import _reset_op_id as reset_op_id +from mindspore import context, Tensor, Parameter +from mindspore.parallel import set_algo_parameters +from tests.ut.python.ops.test_math_ops import VirtualLoss + +grad_all = C.GradOperation(get_all=True) + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return grad_all(self.network)(x) + +class Net(nn.Cell): + def __init__(self, strategy_dict=None): + super(Net, self).__init__() + self.mul1 = P.Mul() + self.mul2 = P.Mul() + self.mul3 = P.Mul() + self.mul4 = P.Mul() + self.relu1 = P.ReLU() + self.relu2 = P.ReLU() + self.ba1 = P.BiasAdd() + self.add = P.TensorAdd() + self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight") + self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias") + + if strategy_dict is not None: + self.mul1.shard(strategy_dict["mul1"]) + self.mul2.shard(strategy_dict["mul2"]) + self.relu1.shard(strategy_dict["relu1"]) + self.relu2.shard(strategy_dict["relu2"]) + self.ba1.shard(strategy_dict["bias_add"]) + self.add.shard(strategy_dict["add"]) + + def construct(self, inputs): + x = self.mul1(inputs, self.weight) + y = self.relu1(x) + y = self.mul2(y, self.weight) + z = self.mul3(x, self.weight) + z = self.ba1(z, self.bias) + x = self.add(y, z) + x = self.mul4(x, self.weight) + x = self.relu2(x) + return x + +def test_star_strategy_consistency1(): + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + set_algo_parameters(fully_use_devices=False) + x = Tensor(np.ones([128, 1000]), dtype=ms.float32) + strategy_dict = {"mul1": ((2, 4), (2, 4)), "mul2": None, "relu1": ((4, 1),), "bias_add": ((8, 1), (1,)), + "relu2": ((2, 2),), "add": ((1, 8), (1, 8))} + net = NetWithLoss(Net(strategy_dict)) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + reset_op_id() + _executor.compile(net, x, phase='train') + + +def test_star_strategy_consistency2(): + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + set_algo_parameters(fully_use_devices=False) + x = Tensor(np.ones([128, 1000]), dtype=ms.float32) + strategy_dict = {"mul1": None, "mul2": ((1, 4), (1, 4)), "relu1": ((2, 1),), "bias_add": ((4, 2), (2,)), + "relu2": ((2, 2),), "add": ((8, 1), (8, 1))} + net = NetWithLoss(Net(strategy_dict)) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + reset_op_id() + _executor.compile(net, x, phase='train') + + +def test_star_strategy_consistency3(): + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + set_algo_parameters(fully_use_devices=False) + x = Tensor(np.ones([128, 1000]), dtype=ms.float32) + strategy_dict = {"mul1": None, "mul2": None, "relu1": ((8, 1),), "bias_add": ((1, 4), (4,)), + "relu2": ((4, 1),), "add": ((2, 2), (2, 2))} + net = NetWithLoss(Net(strategy_dict)) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + 
reset_op_id() + _executor.compile(net, x, phase='train') + + +def test_star_strategy_consistency4(): + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + set_algo_parameters(fully_use_devices=False) + x = Tensor(np.ones([128, 1000]), dtype=ms.float32) + strategy_dict = {"mul1": ((1, 8), (1, 8)), "mul2": ((1, 4), (1, 4)), "relu1": None, "bias_add": None, + "relu2": None, "add": None} + net = NetWithLoss(Net(strategy_dict)) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net.set_auto_parallel() + reset_op_id() + with pytest.raises(RuntimeError): + _executor.compile(net, x, phase='train')
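
Reviewer note: the renamed flag takes effect through CostGraph::SetDeviceMemoryAndCostParameter, which copies CostModelContext::GetInstance()->triangle_star_strategy_overwrite() into the global TRIANGLE_STAR_STRATEGY_OVERWRITE before the DP algorithm runs, so it must be configured before the auto-parallel phase. Below is a minimal C++ sketch of toggling the flag through the accessors renamed in this patch; the helper function and its placement are hypothetical illustrations, not part of the change set.

#include "frontend/parallel/costmodel_context.h"
#include "utils/log_adapter.h"  // assumed include for MS_EXCEPTION_IF_NULL, mirroring graph_costmodel.cc

namespace mindspore {
namespace parallel {
// Hypothetical helper, for illustration only. When 'overwrite' is true (the
// default), recovering a triangle or star elimination overwrites the strategy
// and cost of the right-node / 'succ_nodes[k]'; when false, the previously
// selected strategies are only checked for consistency, which can raise the
// "configuring inconsistent strategies" exception extended above.
void SetTriangleStarOverwrite(bool overwrite) {
  auto cost_model_context = CostModelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(cost_model_context);
  cost_model_context->set_triangle_star_strategy_overwrite(overwrite);
}
}  // namespace parallel
}  // namespace mindspore

With the default of true, the first three tests above compile even though only some operators carry user-configured strategies; the fourth test expects compilation to fail, presumably through the extended inconsistent-strategies exception path.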