|
|
|
@ -233,7 +233,8 @@ void InitCostGraph() {
|
|
|
|
|
entire_costgraph->Init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
|
|
|
|
|
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
|
|
|
|
|
StrategyMap *stra_map) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto attrs = prim->attrs();
|
|
|
|
@ -290,7 +291,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|
|
|
|
// If no strategy has been configured for this operator, then candidate strategies are generated for
|
|
|
|
|
// auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
|
|
|
|
|
// If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
|
|
|
|
|
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
|
|
|
|
|
if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt && !is_last_nodes) {
|
|
|
|
|
// Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
|
|
|
|
|
// BatchParallelInfo operator
|
|
|
|
|
operator_info->ComputeBatchSplitFlagList();
|
|
|
|
@ -307,10 +308,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|
|
|
|
} else {
|
|
|
|
|
// In this case, the configured strategy should be extracted to help setting cost
|
|
|
|
|
StrategyPtr strategyPtr;
|
|
|
|
|
if (load_strategy_from_ckpt) {
|
|
|
|
|
strategyPtr = (*stra_map)[strategy_key_name];
|
|
|
|
|
} else {
|
|
|
|
|
if (is_last_nodes) {
|
|
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
|
|
|
|
strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
|
|
|
|
|
if (full_batch) {
|
|
|
|
|
SetLastNodeStrategy(strategyPtr);
|
|
|
|
|
}
|
|
|
|
|
} else if (StrategyFound(attrs)) {
|
|
|
|
|
strategyPtr = parallel::ExtractStrategy(attrs);
|
|
|
|
|
} else {
|
|
|
|
|
strategyPtr = (*stra_map)[strategy_key_name];
|
|
|
|
|
}
|
|
|
|
|
if (strategyPtr != nullptr) {
|
|
|
|
|
if (prim->name() == RESHAPE) {
|
|
|
|
@ -341,8 +348,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Using CNode's UniqueIds to construct nodes
|
|
|
|
|
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
|
|
|
|
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
|
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
|
|
|
|
|
entire_costgraph = std::make_shared<CostGraph>();
|
|
|
|
|
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
|
|
|
|
// The map from CNode's UniqueId to its operatorInfo
|
|
|
|
|
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
|
|
|
|
|
// The operator_infos in a loop
|
|
|
|
@ -356,7 +365,12 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> last_forward_node_ids;
|
|
|
|
|
if (!root->has_flag(TRAINING)) {
|
|
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
|
|
|
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
|
|
|
|
}
|
|
|
|
|
// Step 1
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
// NOTE: we only care about splittable Primitive operators
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
@ -401,7 +415,9 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|
|
|
|
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
|
|
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
|
|
|
|
|
last_forward_node_ids.end();
|
|
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
|
|
|
|
|
if (operator_info == nullptr) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -436,8 +452,10 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Using CNode's UniqueIdThroughCopy to construct nodes
|
|
|
|
|
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
|
|
|
|
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
|
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
|
|
|
|
|
entire_costgraph = std::make_shared<CostGraph>();
|
|
|
|
|
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
|
|
|
|
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
|
|
|
|
|
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
|
|
|
|
|
// The operator_infos in a loop
|
|
|
|
@ -451,6 +469,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> last_forward_node_ids;
|
|
|
|
|
if (!root->has_flag(TRAINING)) {
|
|
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
|
|
|
|
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
|
|
|
|
|
}
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
// NOTE: we only care about splittable Primitive operators
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
@ -496,7 +519,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// In this case, the corresponding OperatorInfo is not created, create the new one.
|
|
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
|
|
|
|
|
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
|
|
|
|
|
last_forward_node_ids.end();
|
|
|
|
|
auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
|
|
|
|
|
if (operator_info == nullptr) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|