|
|
|
@ -22,9 +22,6 @@
|
|
|
|
|
#include "graph/common/omg_util.h"
|
|
|
|
|
#include "graph/utils/type_utils.h"
|
|
|
|
|
|
|
|
|
|
using std::string;
|
|
|
|
|
using std::vector;
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
|
Status MultiBatchPass::Run(ComputeGraphPtr graph) {
|
|
|
|
|
GELOGD("MultiBatchPass Enter");
|
|
|
|
@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::vector<int64_t>> batch_shape;
|
|
|
|
|
vector<vector<int64_t>> combined_batch;
|
|
|
|
|
std::vector<std::vector<int64_t>> combined_batch;
|
|
|
|
|
if (!CheckSwitchN(batch_shape, combined_batch)) {
|
|
|
|
|
GELOGE(FAILED, "CheckSwitchN failed.");
|
|
|
|
|
return FAILED;
|
|
|
|
@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() {
|
|
|
|
|
///
|
|
|
|
|
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
|
|
|
|
|
const auto &func_desc = case_node->GetOpDesc();
|
|
|
|
|
GE_CHECK_NOTNULL(func_desc);
|
|
|
|
|
if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
|
|
|
|
|
GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr
|
|
|
|
|
const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
|
|
|
|
|
GE_CHECK_NOTNULL(subgraph);
|
|
|
|
|
|
|
|
|
|
const string batch_label = "Batch_" + std::to_string(i);
|
|
|
|
|
const std::string batch_label = "Batch_" + std::to_string(i);
|
|
|
|
|
for (const auto &node : subgraph->GetDirectNode()) {
|
|
|
|
|
(void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
|
|
|
|
|
}
|
|
|
|
@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
|
|
|
|
|
const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
|
|
|
|
|
if (in_data_anchor == nullptr) {
|
|
|
|
|
GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor();
|
|
|
|
|
const auto &pred_input = in_data_anchor->GetPeerOutAnchor();
|
|
|
|
|
if (pred_input == nullptr) {
|
|
|
|
|
GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor
|
|
|
|
|
/// @return Status
|
|
|
|
|
///
|
|
|
|
|
Status MultiBatchPass::GetDynamicType() {
|
|
|
|
|
for (const auto &switchn : switch_n_nodes_) {
|
|
|
|
|
auto switchn_desc = switchn->GetOpDesc();
|
|
|
|
|
GE_CHECK_NOTNULL(switchn_desc);
|
|
|
|
|
for (const auto &switch_n : switch_n_nodes_) {
|
|
|
|
|
int32_t dynamic_type = static_cast<int32_t>(FIXED);
|
|
|
|
|
if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) {
|
|
|
|
|
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str());
|
|
|
|
|
if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) {
|
|
|
|
|
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (dynamic_type == static_cast<int32_t>(FIXED)) {
|
|
|
|
@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
|
|
|
|
|
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.",
|
|
|
|
|
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.",
|
|
|
|
|
dynamic_type, dynamic_type_);
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() {
|
|
|
|
|
Status MultiBatchPass::GetUserDesignateShape() {
|
|
|
|
|
data_name_order_.clear();
|
|
|
|
|
bool first_check = true;
|
|
|
|
|
for (const auto &switchn : switch_n_nodes_) {
|
|
|
|
|
auto switchn_desc = switchn->GetOpDesc();
|
|
|
|
|
GE_CHECK_NOTNULL(switchn_desc);
|
|
|
|
|
vector<string> cur_switchn_data_name_order;
|
|
|
|
|
if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) {
|
|
|
|
|
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str());
|
|
|
|
|
for (const auto &switch_n : switch_n_nodes_) {
|
|
|
|
|
std::vector<std::string> cur_data_name_order;
|
|
|
|
|
if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) {
|
|
|
|
|
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (first_check) {
|
|
|
|
|
data_name_order_ = cur_switchn_data_name_order;
|
|
|
|
|
data_name_order_ = cur_data_name_order;
|
|
|
|
|
first_check = false;
|
|
|
|
|
} else {
|
|
|
|
|
if (data_name_order_ != cur_switchn_data_name_order) {
|
|
|
|
|
if (data_name_order_ != cur_data_name_order) {
|
|
|
|
|
GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.",
|
|
|
|
|
switchn->GetName().c_str());
|
|
|
|
|
switch_n->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() {
|
|
|
|
|
/// @param [out] combined_batch
|
|
|
|
|
/// @return bool
|
|
|
|
|
///
|
|
|
|
|
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) {
|
|
|
|
|
bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape,
|
|
|
|
|
std::vector<std::vector<int64_t>> &combined_batch) {
|
|
|
|
|
// Check if output_num of different SwitchN is same
|
|
|
|
|
uint32_t batch_num = 0;
|
|
|
|
|
for (const NodePtr &node : switch_n_nodes_) {
|
|
|
|
@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v
|
|
|
|
|
}
|
|
|
|
|
size_t tmp_combined_dim_num = combined_batch[i].size();
|
|
|
|
|
if (combined_dim_num != tmp_combined_dim_num) {
|
|
|
|
|
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
|
|
|
|
|
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.",
|
|
|
|
|
combined_dim_num, i, tmp_combined_dim_num);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v
|
|
|
|
|
/// @param [out] combined_batch
|
|
|
|
|
/// @return bool
|
|
|
|
|
///
|
|
|
|
|
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape,
|
|
|
|
|
vector<vector<int64_t>> &combined_batch) {
|
|
|
|
|
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape,
|
|
|
|
|
std::vector<std::vector<int64_t>> &combined_batch) {
|
|
|
|
|
// Check if output_shape of different SwitchN is same
|
|
|
|
|
vector<vector<int64_t>> idx_batch_shape;
|
|
|
|
|
vector<vector<int64_t>> idx_combined_batch;
|
|
|
|
|
std::vector<std::vector<int64_t>> idx_batch_shape;
|
|
|
|
|
std::vector<std::vector<int64_t>> idx_combined_batch;
|
|
|
|
|
for (uint32_t i = 0; i < batch_num; i++) {
|
|
|
|
|
idx_batch_shape.clear();
|
|
|
|
|
idx_combined_batch.clear();
|
|
|
|
@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b
|
|
|
|
|
GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
vector<int64_t> output_dims;
|
|
|
|
|
std::vector<int64_t> output_dims;
|
|
|
|
|
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
|
|
|
|
|
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
|
|
|
|
|
return false;
|
|
|
|
@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
|
|
|
|
|
/// @return Status
|
|
|
|
|
///
|
|
|
|
|
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
|
|
|
|
|
const vector<vector<int64_t>> &batch_shape,
|
|
|
|
|
const vector<vector<int64_t>> &combined_batch) {
|
|
|
|
|
const std::vector<std::vector<int64_t>> &batch_shape,
|
|
|
|
|
const std::vector<std::vector<int64_t>> &combined_batch) {
|
|
|
|
|
NodePtr pred_value_node = pred_value->GetOwnerNode();
|
|
|
|
|
// Create SwitchCase node
|
|
|
|
|
const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
|
|
|
|
@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t num = output_shape.size();
|
|
|
|
|
size_t dim_num = output_shape[0].size();
|
|
|
|
|
for (size_t i = 1; i < num; i++) {
|
|
|
|
|
size_t tmp_dim_num = output_shape[i].size();
|
|
|
|
|
if (dim_num != tmp_dim_num) {
|
|
|
|
|
GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
|
|
|
|
|
for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) {
|
|
|
|
|
if (output_shape[0] != *iter) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dim_num == 0) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < dim_num; i++) {
|
|
|
|
|
int64_t dim_value = output_shape[0][i];
|
|
|
|
|
for (size_t j = 1; j < num; j++) {
|
|
|
|
|
int64_t tmp_dim_value = output_shape[j][i];
|
|
|
|
|
if (dim_value != tmp_dim_value) {
|
|
|
|
|
GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
|
|
|
|
|
dim_value, j, tmp_dim_value);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
|
|
|
|
|
///
|
|
|
|
|
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
|
|
|
|
|
const OutDataAnchorPtr &pred_value,
|
|
|
|
|
const vector<vector<int64_t>> &batch_shape,
|
|
|
|
|
const vector<vector<int64_t>> &combined_batch) {
|
|
|
|
|
const std::vector<std::vector<int64_t>> &batch_shape,
|
|
|
|
|
const std::vector<std::vector<int64_t>> &combined_batch) {
|
|
|
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
|
|
|
|
|
if (op_desc == nullptr) {
|
|
|
|
|
GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
|
|
|
|
@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const
|
|
|
|
|
GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
|
|
|
|
|
const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
|
|
|
|
|
if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
|
|
|
|
|
GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
|
|
|
|
|
return nullptr;
|
|
|
|
|