!6673 Add stage information for ops and strategy

Merge pull request !6673 from huangxinjing/stage_strategy
pull/6673/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9bd34a1b29

@ -63,6 +63,8 @@ void ParallelContext::Reset() {
all_reduce_fusion_split_indices_.clear();
all_reduce_fusion_split_sizes_.clear();
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
stages_.clear();
pipeline_stage_split_num_ = 0;
}
void ParallelContext::set_device_num(int32_t device_num) {
@ -83,6 +85,10 @@ void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient
void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
void ParallelContext::set_pipeline_stage_split_num(const int32_t stage_num) { pipeline_stage_split_num_ = stage_num; }
void ParallelContext::set_stage(const std::vector<int32_t> &stages) { stages_ = stages; }
bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
if (iter == PARALLEL_MODE_LIST.end()) {

@ -67,6 +67,12 @@ class ParallelContext {
void set_device_num(int32_t device_num);
int32_t device_num() const { return device_num_; }
void set_pipeline_stage_split_num(const int32_t stages);
int32_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }
void set_stage(const std::vector<int32_t> &stages);
std::vector<int32_t> stage() const { return stages_; }
void set_global_rank(int32_t global_rank);
int32_t global_rank() const { return global_rank_; }
@ -115,6 +121,8 @@ class ParallelContext {
int32_t global_rank_;
std::string parallel_mode_;
std::string strategy_search_mode_;
std::vector<int32_t> stages_;
int32_t pipeline_stage_split_num_;
bool parameter_broadcast_;
bool device_num_is_set_;
bool global_rank_is_set_;

@ -36,7 +36,8 @@ Stage::Stage(const std::vector<mindspore::parallel::Device> &devices, int num, i
// NOTE: '-1' indicates ERROR
int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); }
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) {
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend,
const std::vector<int32_t> &stage) {
if (device_num <= 0) {
MS_LOG(ERROR) << "'device_num' must be positive.";
return false;
@ -68,7 +69,30 @@ bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &back
devices.push_back(i);
}
stage_map.push_back(device_num);
if (stage.size()) {
int32_t summed_value = 0;
for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
if (*begin <= 0) {
MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
return false;
}
summed_value += *begin;
stage_map.push_back(*begin);
}
if (summed_value != device_num) {
MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
<< device_num;
return false;
}
} else {
stage_map.push_back(device_num);
}
for (auto &y : stage_map) {
MS_LOG(DEBUG) << "Obtained stage id :" << y;
}
g_device_manager = std::make_shared<DeviceManager>();
if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
MS_LOG(INFO) << "Device initialization succeeds.";

@ -70,7 +70,7 @@ class Stage {
// This method is used for initializing the global DeviceManager 'g_device_manager',
// arguments including 'device_num' and 'global_rank'
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend);
bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend, const std::vector<int32_t> &stage);
void CheckGlobalDeviceManager();

@ -126,9 +126,22 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra
}
}
Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_);
// Convert the global rank to the local rank(The index of the array) to compute the coordinate
uint32_t local_rank = 0;
for (auto &tmp_rank : dev_list_) {
Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_);
if (tmp_rank == rank_) {
break;
}
++local_rank;
}
if (local_rank == dev_list_.size()) {
MS_LOG(ERROR) << "Rank id: " << local_rank << "is not in the device list.";
return FAILED;
}
Shape current_rank_coordinate = ConvertRankToCoordinate((int32_t)local_rank, dev_shape_);
for (uint32_t loop_local_rank = 0; loop_local_rank < dev_list_.size(); ++loop_local_rank) {
Shape tmp_rank_coordinate = ConvertRankToCoordinate(loop_local_rank, dev_shape_);
bool matched = true;
for (auto &map : tensor_map) {
if (map == MAP_NONE) {
@ -141,7 +154,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra
}
}
if (matched) {
rank_list->push_back(tmp_rank);
rank_list->push_back(dev_list_[loop_local_rank]);
}
}

@ -43,7 +43,7 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
// dropout don't support repeated calculation
CheckGlobalDeviceManager();
auto input_strategy = strategy->GetInputDim().at(0);
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) != dev_num) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";

@ -196,7 +196,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
// Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) < dev_num) {
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
@ -269,7 +269,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
// param_strategy(axis) != 1, Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
@ -346,7 +346,7 @@ Status GatherV2PInfo::InferDevMatrixShape() {
out_dev_matrix_shape_ = dev_matrix_shape_;
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
if (param_product * index_product < SizeToInt(dev_num)) {
@ -516,10 +516,11 @@ Status GatherV2PInfo::InferGroup() {
if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
dim = (axis_ + 1) % 2;
}
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
int32_t rank = g_device_manager->global_rank();
RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {

@ -162,7 +162,8 @@ class OperatorInfo {
void set_type(const std::string &type) { type_ = type; }
const std::string &type() const { return type_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
int32_t stage_id() const { return stage_id_; }
// Key for user data.
constexpr static char key[] = "OpInfo";
@ -205,6 +206,7 @@ class OperatorInfo {
std::vector<ValuePtr> input_value_;
TypePtr outputs_dtype_;
int32_t stage_id_ = 0;
StrategyPtr strategy_;
std::vector<TensorInfo> inputs_tensor_info_;
std::vector<TensorInfo> outputs_tensor_info_;

@ -55,6 +55,7 @@ constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only";
constexpr char STRATEGY[] = "strategy";
constexpr char STAGE_ATTR[] = "stage";
constexpr char GEN_STRATEGY[] = "gen_strategy";
constexpr char REDUCE_OP_SUM[] = "sum";
constexpr char REDUCE_OP_MAX[] = "max";

@ -133,9 +133,9 @@ Status ReduceMethod::InferTensorMap() {
return SUCCESS;
}
bool IsDataParallelStrategy(const Dimensions &strategy) {
bool IsDataParallelStrategy(const Dimensions &strategy, int32_t stage_id) {
CheckGlobalDeviceManager();
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size();
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
if (strategy.empty()) {
MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty";
}
@ -145,7 +145,7 @@ bool IsDataParallelStrategy(const Dimensions &strategy) {
Status ReduceMethod::InferForwardCommunication() {
Dimensions stra = strategy_->GetInputDim().at(0);
if (cross_batch_ && IsDataParallelStrategy(stra)) {
if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
return SUCCESS;
}
@ -211,7 +211,7 @@ ForwardOp CreatReduceMeanForwardOp(const std::vector<Group> &forward_group, cons
Status ReduceMeanInfo::InferForwardCommunication() {
Dimensions stra = strategy_->GetInputDim().at(0);
if (cross_batch_ && IsDataParallelStrategy(stra)) {
if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
return SUCCESS;
}

@ -998,6 +998,17 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt
StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
StrategyPtr strategyPtr;
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
auto res = attrs.find(STAGE_ATTR);
int32_t stage_id = 0;
if (res != attrs.end()) {
stage_id = GetValue<int>(res->second);
}
if (stage_id && stages.empty()) {
MS_LOG(ERROR) << "Find stage id:" << stage_id << " but the pipeline_stages is 0.";
return nullptr;
}
MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
if (var == nullptr) {
MS_LOG(EXCEPTION) << "Strategy value is nullptr";
@ -1016,13 +1027,13 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
});
strategy.push_back(dim);
} else {
MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue";
MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence";
}
}
if (strategy.empty()) {
MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy";
}
strategyPtr = NewStrategy(0, strategy);
strategyPtr = NewStrategy(stage_id, strategy);
}
return strategyPtr;
@ -1420,6 +1431,30 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
(void)prim->SetAttrs(attrs_temp);
}
}
// This function aims to check the valid rank and stage in the operations
// If the rank is not valid for the given stage, we chose not to init the strategy of the operation
// For example stage is [4, 4], and the group_list [[0,1,2,3],[4,5,6,7]]
// For stage 0, we require the rank_id is in [0,1,2,3]
Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) {
RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage);
int32_t target = global_rank;
if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) {
return Status::SUCCESS;
}
return Status::FAILED;
}
Status ValidStageCheck(const std::vector<int32_t> &stages, int32_t strategy_stage) {
if (stages.size() > 0) {
if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) {
return Status::SUCCESS;
}
return Status::FAILED;
} else {
return Status::SUCCESS;
}
}
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
// load strategy map from checkpoint
@ -1429,6 +1464,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
// Get global rank after the checkpoint?
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
@ -1501,7 +1541,18 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
strategyPtr = ExtractStrategy(attrs);
}
if (strategyPtr != nullptr) {
if (operator_->Init(strategyPtr) == FAILED) {
(*operator_).set_stage_id(strategyPtr->GetInputStage());
MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
MS_LOG(ERROR) << "Find stage " << strategyPtr->GetInputStage() << " for operator " << prim->name()
<< " exceeds the global stage size " << stages.size() << '.';
return;
}
// If the strategy is not valid for the given global rank, then we skip the Init of the strategy
if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) {
MS_LOG(INFO) << "Find global exceeds the range of the stage, skip the strategy init for operator "
<< prim->name();
} else if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
cnode->set_user_data<OperatorInfo>(operator_);
@ -2416,6 +2467,9 @@ Status ParallelInit() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
int32_t device_num = ParallelContext::GetInstance()->device_num();
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
@ -2431,6 +2485,26 @@ Status ParallelInit() {
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
}
if (device_num <= 0) {
MS_LOG(ERROR) << "Invalid device num " << device_num << " , expected a positive device number";
return FAILED;
}
if (split_stage_num > 0) {
if (device_num % split_stage_num != 0) {
MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num
<< " , as we support only extract devision now";
return FAILED;
}
for (int i = 0; i < split_stage_num; i++) {
stages.push_back(device_num / split_stage_num);
}
} else if (split_stage_num < 0) {
MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << " , expected a positive stage number";
return FAILED;
}
ParallelContext::GetInstance()->set_stage(stages);
uint32_t world_rank_size = 0;
if (!ParallelContext::GetInstance()->device_num_is_set()) {
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
@ -2449,7 +2523,12 @@ Status ParallelInit() {
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
}
if (!InitDevice(device_num, global_rank, communication_backend)) {
if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) {
MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
return FAILED;
}
if (!InitDevice(device_num, global_rank, communication_backend, stages)) {
MS_LOG(ERROR) << "Init device failed";
return FAILED;
}
@ -2457,6 +2536,7 @@ Status ParallelInit() {
MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
<< ", backend: " << backend << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
<< ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
return SUCCESS;
}

@ -152,6 +152,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Set strategy checkpoint save file.")
.def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
"Set pipeline stage split num.")
.def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,

@ -331,7 +331,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list)
all_reduce_fusion_config=list, pipeline_stages=int)
def set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
@ -357,6 +357,7 @@ def set_auto_parallel_context(**kwargs):
parallel_mode strategy_ckpt_load_file
all_reduce_fusion_config strategy_ckpt_save_file
full_batch
pipeline_stages
=========================== =========================== =================
Args:
@ -399,6 +400,10 @@ def set_auto_parallel_context(**kwargs):
the fusion is closed.
all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
the devices are distributed alone the pipeline. The total devices will be divided into
'pipeline_stags' stages. This currently could only be used when
parall mode semi_auto_parallel is enabled.
Raises:
ValueError: If input key is not attribute in auto parallel context.
@ -416,10 +421,10 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(full_batch=True)
>>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
>>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
>>> context.set_auto_parallel_context(pipeline_stages=2)
"""
_set_auto_parallel_context(**kwargs)
def get_auto_parallel_context(attr_key):
"""
Gets auto parallel context attribute value according to the key.

@ -102,6 +102,20 @@ class Primitive(Primitive_):
self.add_attr(name, value)
return self
def set_stage(self, stage):
"""
Add stage id to primitive attribute.
Note:
It is valid only in semi auto parallel.
In other parallel modes, please set it to be 0.
Args:
stage (int): The stage id for the current operation
"""
self.add_prim_attr("stage", stage)
return self
def shard(self, strategy):
"""
Add strategies to primitive attribute.

@ -95,6 +95,16 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_global_rank()
def set_pipeline_stages(self, stages):
"""Set the stages of the pipeline"""
self.check_context_handle()
self._context_handle.set_pipeline_stage_split_num(stages)
def get_pipeline_stages(self):
"""Get the stages of the pipeline"""
self.check_context_handle()
return self._context_handle.get_pipeline_stage_split_num()
def set_gradients_mean(self, gradients_mean):
"""
Set gradients_mean flag.
@ -466,6 +476,7 @@ _set_auto_parallel_context_func_map = {
"gradients_mean": auto_parallel_context().set_gradients_mean,
"gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
"pipeline_stages": auto_parallel_context().set_pipeline_stages,
"parallel_mode": auto_parallel_context().set_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
@ -482,6 +493,7 @@ _get_auto_parallel_context_func_map = {
"gradients_mean": auto_parallel_context().get_gradients_mean,
"gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
"pipeline_stages": auto_parallel_context().get_pipeline_stages,
"parallel_mode": auto_parallel_context().get_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
@ -569,7 +581,6 @@ def _get_auto_parallel_context(attr_key):
get_func = _get_auto_parallel_context_func_map[attr_key]
return get_func()
def _reset_auto_parallel_context():
"""
Reset auto parallel context attributes to the default values:
@ -584,5 +595,6 @@ def _reset_auto_parallel_context():
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
- auto_parallel_search_mode: dynamic_programming
- pipeline_stages: 0
"""
auto_parallel_context().reset()

@ -83,6 +83,39 @@ TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) {
EXPECT_THROW({ DeviceMatrix arr(3, dev_list, shape); }, std::runtime_error);
}
TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceOne) {
RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
Shape tensor_map = {-1, 0};
RankList rank_list;
Shape shape = {4, 2};
DeviceMatrix arr(0, dev_list, shape);
arr.GetDevicesByTensorMap(tensor_map, &rank_list);
RankList rank_list_except = {3, 9, 100, 0};
ASSERT_EQ(rank_list, rank_list_except);
}
TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceTwo) {
RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
Shape tensor_map = {1, 0};
RankList rank_list;
Shape shape = {4, 2};
DeviceMatrix arr(0, dev_list, shape);
arr.GetDevicesByTensorMap(tensor_map, &rank_list);
RankList rank_list_except = {0};
ASSERT_EQ(rank_list, rank_list_except);
}
TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapNoramalOrder2D) {
RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
Shape tensor_map = {-1, 0};
RankList rank_list;
Shape shape = {4, 2};
DeviceMatrix arr(6, dev_list, shape);
arr.GetDevicesByTensorMap(tensor_map, &rank_list);
RankList rank_list_except = {0, 2, 4, 6};
ASSERT_EQ(rank_list, rank_list_except);
}
TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) {
// Rank is out of range
RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};

@ -0,0 +1,89 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class Net(nn.Cell):
def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""):
super().__init__()
if shape is None:
shape = [64, 64]
self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
self.mul = P.Mul().shard(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.gatherv2.set_stage(stage1)
self.mul.set_stage(stage2)
self.axis = axis
def construct(self, x, y):
out = self.gatherv2(x, self.index, self.axis)
out = self.mul(out, y)
return out
def test_gatherv2_semi_samestage1():
context.set_auto_parallel_context(device_num=8, global_rank=0, \
parallel_mode="semi_auto_parallel", pipeline_stages=2)
strategy1 = ((1, 2), (1, 1))
strategy2 = ((2, 1, 1), (2, 1, 1))
net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_semi_samestage2():
context.set_auto_parallel_context(device_num=8, global_rank=5, \
parallel_mode="semi_auto_parallel", pipeline_stages=2)
strategy1 = ((1, 2), (1, 1))
strategy2 = ((2, 1, 1), (2, 1, 1))
net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)

@ -81,6 +81,11 @@ def test_set_auto_parallel_context():
assert context.get_auto_parallel_context("enable_parallel_optimizer")
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
def test_pipeline_parallel_context():
context.set_auto_parallel_context(device_num=8, global_rank=4,
parallel_mode="semi_auto_parallel", pipeline_stages=2)
stage = auto_parallel_context().get_pipeline_stages()
assert stage == 2
def test_reset_auto_parallel_context():
context.reset_auto_parallel_context()
@ -92,6 +97,8 @@ def test_reset_auto_parallel_context():
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
device_num_is_set = auto_parallel_context().get_device_num_is_set()
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
stage = auto_parallel_context().get_pipeline_stages()
assert device_num == 1
assert global_rank == 0
assert not gradients_mean
@ -100,3 +107,4 @@ def test_reset_auto_parallel_context():
assert not parameter_broadcast
assert not device_num_is_set
assert not parameter_broadcast_is_set
assert not stage

Loading…
Cancel
Save