update get rank in parallel ops

pull/8556/head
yangzhenzhang 4 years ago
parent 8612885167
commit 0c2c76d037

@ -113,7 +113,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
size_t type_length, TypePtr type, CostPtr *cost) {
MS_EXCEPTION_IF_NULL(prev_op_);
MS_EXCEPTION_IF_NULL(cost);
RankList dev_list = prev_op_->global_device_list();
RankList dev_list = prev_op_->stage_device_list();
TensorRedistribution tensor_redistribution(false);
// Init TensorRedistribution

@ -140,7 +140,12 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
const std::string &backend) {
if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
MS_LOG(ERROR) << "Invalid backend: " << backend;
return Status::FAILED;
return FAILED;
}
if (stage_map.empty() || devices.empty()) {
MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
return FAILED;
}
for (auto &dev : devices) {
@ -153,11 +158,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
int64_t num_device = stage;
if (num_device > MAX_DEVICE_NUM) {
MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
return Status::FAILED;
return FAILED;
}
if (num_device <= 0) {
MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
return Status::FAILED;
return FAILED;
}
RankList curr_dev_list;
for (int64_t i = 0; i < num_device; ++i) {
@ -170,10 +175,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
device_ = dev;
set_global_rank(global_device_rank);
set_stage_num(static_cast<const int64_t>(stage_map.size()));
int64_t stage_id = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
set_stage_id(stage_id);
global_rank_ = global_device_rank;
stage_num_ = static_cast<const int64_t>(stage_map.size());
stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
backend_ = backend;
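
The new stage bookkeeping is plain integer arithmetic over the flat device list. A minimal standalone sketch of the same computation with hypothetical values (8 devices, 2 equal stages, global rank 5); the variable names only mirror the members set in Init:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical setup: 8 devices, 2 equal stages, this process is global rank 5.
  std::vector<int64_t> devices = {0, 1, 2, 3, 4, 5, 6, 7};
  int64_t stage_num = 2;
  int64_t global_rank = 5;

  // Same arithmetic as DeviceManager::Init above.
  int64_t stage_device_num = static_cast<int64_t>(devices.size()) / stage_num;    // 4
  int64_t stage_id = global_rank / stage_device_num;                              // 1
  int64_t rank_index_in_stage = global_rank - stage_id * stage_device_num;        // 1

  std::cout << stage_id << " " << rank_index_in_stage << " " << stage_device_num << "\n";
  return 0;
}
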
@ -185,10 +191,13 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
gm_.set_world_group(UNDEFINED_WORLD_GROUP);
}
MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
<< ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
return Status::SUCCESS;
<< ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
<< ", the rank index in stage is: " << rank_index_in_stage_;
return SUCCESS;
}
RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
if (LongToSize(stage_id) >= stage_devices_.size())
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
@ -204,49 +213,6 @@ RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
return res;
}
RankList DeviceManager::global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const {
RankList res;
if (split_num <= 0) {
return res;
}
if (LongToSize(stage_id) >= stage_devices_.size()) {
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
<< ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
return res;
}
RankList global_list = GetDeviceListByStageId(stage_id);
if (global_list.size() % LongToSize(split_num)) {
MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id;
return res;
}
std::vector<int64_t> dev_list;
(void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list));
size_t index = 0;
size_t slice_size = dev_list.size() / LongToSize(split_num);
for (int64_t i = 0; i < split_num; ++i) {
bool found = false;
index = slice_size * LongToSize(i);
for (size_t j = 0; j < slice_size; ++j) {
if (dev_list[index + j] == rank) {
found = true;
break;
}
}
if (found) {
break;
}
}
for (size_t k = 0; k < slice_size; ++k) {
res.push_back(dev_list[index + k]);
}
return res;
}
Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
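
With the per-rank global_device_list slicing removed, callers obtain devices through the per-stage helpers only. A rough standalone sketch of the lookup GetDeviceListByStageId provides, assuming the contiguous, equally sized stages that Init builds (DeviceListByStage is a hypothetical name, not the real method):

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring what GetDeviceListByStageId returns for the
// contiguous layout built in DeviceManager::Init: stage s owns the s-th
// equally sized slice of the flat device list.
std::vector<int64_t> DeviceListByStage(const std::vector<int64_t> &devices,
                                       int64_t stage_num, int64_t stage_id) {
  std::vector<int64_t> res;
  int64_t per_stage = static_cast<int64_t>(devices.size()) / stage_num;
  for (int64_t i = stage_id * per_stage; i < (stage_id + 1) * per_stage; ++i) {
    res.push_back(devices[i]);
  }
  return res;
}

// Example: devices {0..7}, 2 stages -> DeviceListByStage(devices, 2, 1) is {4, 5, 6, 7}.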

@ -57,14 +57,14 @@ std::string HashName(const std::string &rank_list_name);
class DeviceManager {
// This class is used to manage the abstract devices, including group-related and stage-related management.
public:
DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(1), stage_id_(0) { gm_ = GroupManager(); }
DeviceManager() { gm_ = GroupManager(); }
~DeviceManager() = default;
Status Init(const RankList &devices, int64_t local_device, const RankList &stage_map, const std::string &backend);
static DeviceManager &GetInstance();
RankList GetDeviceListByStageId(int64_t stage_id) const;
RankList global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const;
RankList GetDeviceListInThisStage() const;
Device CreateNewDeviceByRank(int64_t rank) const;
std::vector<Device> CreateDeviceListByRankList(RankList ranks);
@ -74,17 +74,11 @@ class DeviceManager {
Group CreateGroup(const RankList &dev_ranks);
size_t DeviceNum() const { return devices_.size(); }
int64_t stage_num() const { return stage_num_; }
void set_stage_num(int64_t num) { stage_num_ = num; }
int64_t stage_id() const { return stage_id_; }
void set_stage_id(int64_t id) { stage_id_ = id; }
std::string backend() const { return backend_; }
int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
int64_t global_rank() const { return global_rank_; }
void set_global_rank(int64_t global_rank) { global_rank_ = global_rank; }
std::string backend() const { return backend_; }
void Clear();
std::string world_group() const { return gm_.world_group(); }
@ -102,10 +96,11 @@ class DeviceManager {
std::map<std::string, std::string> rank_to_group_; // the key is rank list, value is hash name
std::map<std::string, std::string> group_to_rank_; // the key is hash name, value is rank list
int64_t local_rank_;
int64_t global_rank_;
int64_t stage_num_;
int64_t stage_id_;
int64_t global_rank_ = 0; // the real rank in all devices
int64_t stage_num_ = 0; // the stage num
int64_t stage_id_ = 0; // the stage id of the global_rank_
int64_t rank_index_in_stage_ = 0; // the index of this rank in its stage
int64_t stage_device_num_ = 0; // the device num of one stage
};
} // namespace parallel
} // namespace mindspore
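
The defaulted DeviceManager constructor now relies on the in-class initializers declared for the counters above. A minimal illustration of that equivalence with a hypothetical class:

#include <cstdint>

// Two hypothetical classes with identical behaviour: an explicit constructor
// initializer list versus in-class default member initializers.
class WithInitList {
 public:
  WithInitList() : global_rank_(0), stage_num_(0) {}
 private:
  int64_t global_rank_;
  int64_t stage_num_;
};

class WithInClassDefaults {
 public:
  WithInClassDefaults() = default;  // members already carry their defaults below
 private:
  int64_t global_rank_ = 0;
  int64_t stage_num_ = 0;
};

int main() { WithInitList a; WithInClassDefaults b; (void)a; (void)b; return 0; }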

@ -232,7 +232,7 @@ Status GatherV2Info::InferTensorSubOps() {
MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
}
int64_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
int64_t rank = g_device_manager->global_rank();
int64_t rank = g_device_manager->rank_index_in_stage();
int64_t mod_rank = rank % mod_p;
mod_rank = static_cast<int64_t>(mod_rank / mod_n);
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
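
The mod_rank computation in the hunk above only depends on the rank inside the current stage. A small worked sketch with hypothetical values (mod_n = 2, dev_matrix_shape_[axis_] = 4, rank 5 within the stage):

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical values; in the real code mod_n and the dev-matrix dimension
  // come from the operator's sharding strategy.
  int64_t mod_n = 2;
  int64_t dev_matrix_dim = 4;   // stands in for dev_matrix_shape_.at(axis_)
  int64_t rank_in_stage = 5;    // what rank_index_in_stage() would return

  int64_t mod_p = mod_n * dev_matrix_dim;    // 8
  int64_t mod_rank = rank_in_stage % mod_p;  // 5
  mod_rank = mod_rank / mod_n;               // 2
  std::cout << mod_rank << "\n";
  return 0;
}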

@ -451,7 +451,7 @@ Status GatherV2PInfo::InferTensorInfo() {
Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
int64_t rank = g_device_manager->global_rank();
int64_t rank = g_device_manager->rank_index_in_stage();
// infer tensor layout
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
if (manual_split_) {
@ -481,7 +481,7 @@ Status GatherV2PInfo::InferTensorInfo() {
Status GatherV2PInfo::InferBias() {
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
int64_t rank = g_device_manager->rank_index_in_stage();
auto input_shape = inputs_shape_.at(0);
auto params_strategy = strategy_->GetInputDim().at(0);
// axis don't split
@ -513,7 +513,7 @@ Status GatherV2PInfo::InferBias() {
Status GatherV2PInfo::InferOffset() {
CheckGlobalDeviceManager();
size_t rank = g_device_manager->global_rank();
size_t rank = g_device_manager->rank_index_in_stage();
MS_EXCEPTION_IF_NULL(strategy_);
auto param_strategy = strategy_->GetInputDim()[0];

@ -134,7 +134,7 @@ Status OneHotInfo::InferTensorInfo() {
Status OneHotInfo::ExtractInputInfo() {
CheckGlobalDeviceManager();
rank_ = g_device_manager->global_rank();
rank_ = g_device_manager->rank_index_in_stage();
mod_rank_ = rank_ % old_dev_matrix_back_;
if (!cnode_) {
MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";

@ -116,7 +116,6 @@ void OperatorInfo::ResetQueueMember() {
replace_op_.clear();
replace_op_info_.clear();
virtual_div_op_.clear();
global_device_list_.clear();
}
Status OperatorInfo::InferAttrs() {
@ -131,14 +130,8 @@ Status OperatorInfo::InferAttrs() {
return SUCCESS;
}
void OperatorInfo::SetDeviceListByStrategy() {
int64_t stage = strategy_->GetInputStage();
CheckGlobalDeviceManager();
global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
}
Status OperatorInfo::InferRepeatedCalcInfo() {
int64_t g_dev_list_size = SizeToLong(global_device_list_.size());
int64_t g_dev_list_size = stage_device_size_;
int64_t dev_matrix_size =
std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
if (dev_matrix_size == 0) {
@ -155,12 +148,6 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
<< dev_matrix_size;
return FAILED;
}
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
int64_t stage = strategy_->GetInputStage();
local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_);
return SUCCESS;
}
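
InferRepeatedCalcInfo now measures replication against the stage's device count rather than a cached global list. A rough sketch of that relationship, assuming the repeated-calculation factor is simply the quotient of the two sizes (RepeatedCalcNum is a hypothetical helper, not the real method):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: how many times the dev-matrix is replicated inside one
// stage. Returns -1 when the sizes are inconsistent, mirroring the FAILED
// branches in the hunk above.
int64_t RepeatedCalcNum(int64_t stage_device_size, const std::vector<int64_t> &dev_matrix_shape) {
  int64_t dev_matrix_size =
      std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), int64_t(1), std::multiplies<int64_t>());
  if (dev_matrix_size == 0 || stage_device_size % dev_matrix_size != 0) {
    return -1;
  }
  return stage_device_size / dev_matrix_size;
}

// Example: 16 devices in the stage, dev-matrix [2, 4] -> replicated 2 times.
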
@ -331,7 +318,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
}
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
return FAILED;
@ -354,7 +341,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
}
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
return FAILED;
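
Both group-creation paths now build the device matrix from the stage-local device list. A standalone sketch of the grouping idea, assuming a row-major mapping of the stage's devices onto a 2-D dev-matrix (GroupAlongDim is a hypothetical helper, not MindSpore's DeviceMatrix):

#include <cstdint>
#include <vector>

// Hypothetical helper: the stage's devices laid out row-major over a 2-D
// dev-matrix [rows, cols]; return the group of `rank` along `dim` (0 or 1).
std::vector<int64_t> GroupAlongDim(const std::vector<int64_t> &stage_devices,
                                   int64_t rows, int64_t cols, int64_t rank, int dim) {
  // Position of `rank` inside the stage list.
  int64_t idx = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(stage_devices.size()); ++i) {
    if (stage_devices[i] == rank) idx = i;
  }
  int64_t r = idx / cols, c = idx % cols;
  std::vector<int64_t> group;
  if (dim == 0) {  // vary the row, keep the column
    for (int64_t i = 0; i < rows; ++i) group.push_back(stage_devices[i * cols + c]);
  } else {         // vary the column, keep the row
    for (int64_t j = 0; j < cols; ++j) group.push_back(stage_devices[r * cols + j]);
  }
  return group;
}

// Example: stage devices {8..15}, dev-matrix [2, 4], rank 13 (r = 1, c = 1):
//   dim 0 -> {9, 13},  dim 1 -> {12, 13, 14, 15}.
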
@ -469,7 +456,6 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
ResetQueueMember();
strategy_ = strategy;
SetDeviceListByStrategy();
if (InferDevMatrixShape() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@ -526,7 +512,6 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &str
ResetQueueMember();
strategy_ = strategy;
SetDeviceListByStrategy();
if (InferDevMatrixShape() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@ -1325,7 +1310,7 @@ Status OperatorInfo::InferAsLossDivisor() {
}
if (outputs_tensor_map_[0].empty()) {
as_loss_divisor_ = SizeToLong(global_device_list_.size());
as_loss_divisor_ = stage_device_size_;
MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
return SUCCESS;
}
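
An empty tensor map on the first output means the output is replicated on every device of the stage, so the whole stage size becomes the loss divisor. A minimal numeric illustration under that assumption (a hypothetical stage of 8 devices):

#include <cstdint>
#include <vector>

int main() {
  // Hypothetical stage of 8 devices with a scalar (fully replicated) output.
  int64_t stage_device_size = 8;
  std::vector<int64_t> output_tensor_map;  // empty tensor map -> scalar output
  int64_t as_loss_divisor = output_tensor_map.empty() ? stage_device_size : 0;
  return as_loss_divisor == 8 ? 0 : 1;
}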

@ -64,6 +64,8 @@ class OperatorInfo {
std::vector<bool> not_parameteter(inputs_shape_.size(), false);
is_parameter_ = not_parameteter;
refkey_parameter_name_ = "";
stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
stage_device_size_ = SizeToLong(stage_device_list_.size());
}
virtual ~OperatorInfo() = default;
@ -119,7 +121,7 @@ class OperatorInfo {
std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
const std::string &name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }
RankList global_device_list() const { return global_device_list_; }
RankList stage_device_list() const { return stage_device_list_; }
void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
@ -187,7 +189,6 @@ class OperatorInfo {
virtual Status InferTensorInfo() = 0;
virtual Status InferDevMatrixShape() = 0;
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetDeviceListByStrategy();
void SetRepeatedCalcDevMatrix();
void ResetTensorMapIfRepeatedCalc();
Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
@ -231,8 +232,8 @@ class OperatorInfo {
ReplaceGraphPtr replace_graph_;
MirrorOps mirror_ops_;
VirtualDivOp virtual_div_op_;
RankList global_device_list_; // the size of global_device_list equal to the size of stageID
RankList local_device_list_; // the size equal to global_device_list_.size() / repeated_calc_num_
RankList stage_device_list_; // the device list in this stage
int64_t stage_device_size_ = 0;
bool infer_attrs_completed_ = false;
bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel

@ -136,7 +136,7 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {
Status RangeInfo::InferNewAttr() {
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
int64_t rank = g_device_manager->rank_index_in_stage();
// If repeated calculation and repeated num as the last dimension of dev-matrix,
// the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_

@ -531,7 +531,7 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() {
MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
if (outputs_tensor_map_[0].empty()) {
as_loss_divisor_ = SizeToLong(global_device_list_.size());
as_loss_divisor_ = stage_device_size_;
MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor.";
return SUCCESS;
}

@ -172,7 +172,7 @@ Status ReLUV2Info::InferAsLossDivisor() {
}
if (outputs_tensor_map_[0].empty()) {
as_loss_divisor_ = SizeToInt(global_device_list_.size());
as_loss_divisor_ = stage_device_size_;
MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
return SUCCESS;
}

@ -113,7 +113,7 @@ Status ReshapeInfo::GetParameterInput() {
}
Status ReshapeInfo::ComputeReplaceOp() {
RankList dev_list = global_device_list();
RankList dev_list = stage_device_list();
TensorRedistribution tensor_redistribution(!is_generating_costs_, true);
if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
if (is_generating_costs_) {
@ -289,13 +289,7 @@ void ReshapeInfo::InferTensorInfoByLayout() {
Status ReshapeInfo::GetAttrs() { return GetParameterInput(); }
void ReshapeInfo::device_number(const StrategyPtr &strategy) {
int64_t stage = 0;
if (strategy != nullptr) {
stage = strategy->GetInputStage();
}
CheckGlobalDeviceManager();
global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
dev_num_ = SizeToLong(global_device_list_.size());
dev_num_ = stage_device_size_;
MS_ASSERT(dev_num_ > 0);
}

@ -260,7 +260,7 @@ Status SplitInfo::InferAsLossDivisor() {
}
if (outputs_tensor_map_[0].empty()) {
as_loss_divisor_ = SizeToInt(global_device_list_.size());
as_loss_divisor_ = stage_device_size_;
MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
return SUCCESS;
}

@ -325,7 +325,7 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
if (next_distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
}
RankList dev_list = distribute_operator->global_device_list();
RankList dev_list = distribute_operator->stage_device_list();
std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();

@ -161,6 +161,8 @@ class EmbeddingLookup(Cell):
Examples:
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> out = nn.EmbeddingLookup(4,2)(input_indices)
>>> out.shape
(2, 2, 2)
"""
BATCH_SLICE = "batch_slice"
FIELD_SLICE = "field_slice"

@ -135,6 +135,8 @@ TEST_F(TestDeviceManager, test_StageID) {
ASSERT_EQ(dm_.DeviceNum(), 4);
ASSERT_EQ(dm_.stage_num(), 2);
ASSERT_EQ(dm_.stage_id(), 1);
ASSERT_EQ(dm_.rank_index_in_stage(), 0);
ASSERT_EQ(dm_.GetDeviceListInThisStage().back(), 3);
RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
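
These assertions are consistent with a 4-device, 2-stage fixture in which this process is global rank 2: stage 1 owns devices {2, 3}, so the rank index inside the stage is 0 and the last device of the stage is 3. A quick standalone check of the same arithmetic (the Init configuration itself is assumed, not shown in this hunk):

#include <cassert>
#include <cstdint>

int main() {
  // Assumed fixture: devices {0, 1, 2, 3}, stage_map {2, 2}, global rank 2.
  int64_t device_num = 4, stage_num = 2, global_rank = 2;
  int64_t stage_device_num = device_num / stage_num;                        // 2
  int64_t stage_id = global_rank / stage_device_num;                        // 1
  int64_t rank_index_in_stage = global_rank - stage_id * stage_device_num;  // 0
  int64_t last_device_in_stage = (stage_id + 1) * stage_device_num - 1;     // 3
  assert(stage_id == 1 && rank_index_in_stage == 0 && last_device_in_stage == 3);
  return 0;
}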

@ -171,7 +171,7 @@ TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy);
RankList dev_list = log_softmax->global_device_list();
RankList dev_list = log_softmax->stage_device_list();
ASSERT_EQ(dev_list.size(), 128);
}
