|
|
|
@ -55,9 +55,7 @@ class MultiBatchGraphCopyer {
|
|
|
|
|
data_name_order_.push_back(item.first);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void SetDataToDynamicInfo(const map<string, vector<vector<int64_t>>> &designate_shape) {
|
|
|
|
|
data_to_dynamic_info_ = designate_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDynamicType(const DynamicType dynamic_type) {
|
|
|
|
|
dynamic_type_ = dynamic_type;
|
|
|
|
|
}
|
|
|
|
@ -69,15 +67,26 @@ class MultiBatchGraphCopyer {
|
|
|
|
|
|
|
|
|
|
// label status for origin_all_nodes_
|
|
|
|
|
Status LabelStatus();
|
|
|
|
|
Status LabelInBatchBranchStatus();
|
|
|
|
|
void LabelStatusForData(const NodePtr &data);
|
|
|
|
|
void LabelStatusForGetNextSink(const NodePtr &data);
|
|
|
|
|
// add nodes functions
|
|
|
|
|
Status CreateNewNodes();
|
|
|
|
|
|
|
|
|
|
NodePtr InsertShapeDataNode();
|
|
|
|
|
Status InsertSwitchNForData(const NodePtr &data);
|
|
|
|
|
NodePtr InsertGetDynamicDimsNode();
|
|
|
|
|
|
|
|
|
|
Status InsertSwitchNAndUpdateMaxShape(const NodePtr &node);
|
|
|
|
|
|
|
|
|
|
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
|
|
|
|
|
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);
|
|
|
|
|
|
|
|
|
|
Status InsertIdentityAfterSwitchN();
|
|
|
|
|
Status UpdateMaxShapeToData(const NodePtr &data);
|
|
|
|
|
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
|
|
|
|
|
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);
|
|
|
|
|
|
|
|
|
|
Status InsertMergeForEdgeNode(const NodePtr &node);
|
|
|
|
|
Status LinkGetDynamicDimsToNetOutput(const NodePtr &node);
|
|
|
|
|
|
|
|
|
|
/// Insert a merge node for src node `node` on output index `index`. The merge node will be used to merge all nodes
|
|
|
|
|
/// in batch-branch to one output to the node out of the batch-branch.
|
|
|
|
@ -95,12 +104,16 @@ class MultiBatchGraphCopyer {
|
|
|
|
|
|
|
|
|
|
// link edges functions
|
|
|
|
|
Status LinkEdges();
|
|
|
|
|
Status LinkDataToSwitchN(const NodePtr &data);
|
|
|
|
|
Status AddAttrForGetDynamicDims(const NodePtr &node);
|
|
|
|
|
Status AddLinkForGetDynamicDims(const NodePtr &node);
|
|
|
|
|
|
|
|
|
|
Status LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index);
|
|
|
|
|
Status LinkToMerge(const NodePtr &node);
|
|
|
|
|
Status LinkToNodeInBranch(const NodePtr &node);
|
|
|
|
|
Status LinkToNodeOutBranch(const NodePtr &node);
|
|
|
|
|
Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge);
|
|
|
|
|
Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn);
|
|
|
|
|
Status LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge);
|
|
|
|
|
NodePtr FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor, const NodePtr &origin_node);
|
|
|
|
|
Status CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr ©ed_node);
|
|
|
|
|
Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node);
|
|
|
|
|
Status CheckAndParseDynamicData();
|
|
|
|
@ -127,6 +140,11 @@ class MultiBatchGraphCopyer {
|
|
|
|
|
// the data nodes, and the SwitchN nodes inserted after it
|
|
|
|
|
std::map<Node *, NodePtr> data_nodes_to_switchn_;
|
|
|
|
|
|
|
|
|
|
// the getnext_sink nodes, and the SwitchN nodes inserted after it
|
|
|
|
|
std::vector<std::vector<std::pair<Node *, NodePtr>>> getnext_nodes_to_switchn_;
|
|
|
|
|
std::vector<std::vector<std::pair<int, int>>> outidx_inidx_mappings_;
|
|
|
|
|
std::vector<std::pair<int, int>> outidx_inidx_mapping_;
|
|
|
|
|
|
|
|
|
|
// the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it
|
|
|
|
|
std::map<Node *, std::vector<NodePtr>> nodes_to_merge_nodes_;
|
|
|
|
|
|
|
|
|
@ -142,6 +160,9 @@ class MultiBatchGraphCopyer {
|
|
|
|
|
|
|
|
|
|
// dynamic type : dynamic batch,, dynamic image size, dynamic dims.
|
|
|
|
|
DynamicType dynamic_type_ = DynamicType::kDynamicUnknown;
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<size_t, size_t>> getnext_sink_dynamic_out_mapping_;
|
|
|
|
|
bool getnext_sink_dynamic_dims_ = false;
|
|
|
|
|
};
|
|
|
|
|
} // namespace multibatch
|
|
|
|
|
} // namespace ge
|
|
|
|
|