change switchn to case and add ut

pull/862/head
zhou_lili 4 years ago
parent 8d35443bb9
commit dd6996e2e9

(File diff suppressed because it is too large)

@@ -864,11 +864,13 @@ class DavinciModel {
   void ParseDynamicOutShape(const vector<string> &str_info, vector<vector<int64_t>> &vec_info);
   bool IsGetNextSinkDynamic(const OpDescPtr &op_desc);
   Status InitRealSizeAndShapeInfo(const ComputeGraphPtr &compute_graph, const NodePtr &node);
   void GetAllGearsInfo(const NodePtr &node);
   Status GetGetDynamicDimsNodeInfo(const NodePtr &node);
-  Status GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node);
-  Status GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node);
-  Status GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc);
+  Status GetGearAndRealOutSizeInfo(const ComputeGraphPtr &graph, const NodePtr &node);
+  Status GetRealOutputSizeOfCase(const ComputeGraphPtr &graph, size_t input_index, const NodePtr &case_node);
+  Status GetGearAndRealOutShapeInfo(const ComputeGraphPtr &graph, const NodePtr &node);

   bool is_weight_mem_has_inited_;
   bool is_feature_map_mem_has_inited_;
@@ -1021,15 +1023,15 @@ class DavinciModel {
   bool is_new_model_desc_{false};
   bool is_online_infer_dynamic_ = false;
   bool is_getnext_sink_dynamic_ = false;
-  vector<int64_t> cur_dynamic_dims_;
+  vector<int32_t> cur_dynamic_dims_;
   void *netoutput_last_input_addr_ = nullptr;
   int64_t netoutput_last_input_size_ = 0;
   size_t shape_of_cur_dynamic_dims_ = 0;
   // key: input_index: input is merge node; value: each gear info and each output size
-  map<size_t, map<vector<int64_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_;
+  map<size_t, map<vector<int32_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_;
   // key: input_index: input is merge node; value: each gear info and each output shape
-  map<size_t, map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_;
-  vector<vector<int64_t>> all_gears_info_;
+  map<size_t, map<vector<int32_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_;
+  vector<vector<int32_t>> all_gears_info_;
   multimap<uint32_t, uint32_t> op_id_map_;
   vector<ProfileInfo> profile_list_;
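
Note: this hunk changes the element type of every gear-related container from int64_t to int32_t in lockstep with cur_dynamic_dims_. A minimal sketch (not GE source; names hypothetical) of why the key type must follow the gather type:

    // The runtime gathers the current dynamic dims as vector<int32_t> and uses
    // that vector verbatim as the map key, so a map keyed on vector<int64_t>
    // would never produce a hit once the gather side switches to int32_t.
    #include <cstdint>
    #include <map>
    #include <vector>

    using Gear = std::vector<int32_t>;

    int64_t LookupRealOutputSize(const std::map<Gear, int64_t> &gear_to_size,
                                 const Gear &cur_dynamic_dims,
                                 int64_t default_size) {
      const auto it = gear_to_size.find(cur_dynamic_dims);
      return (it == gear_to_size.end()) ? default_size : it->second;
    }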

@@ -460,8 +460,8 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d
 Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims,
                                        const vector<pair<string, vector<int64_t>>> &user_input_dims,
-                                       vector<int64_t> &cur_dynamic_dims) {
-  GELOGD(" Start get cur dynamic dims.");
+                                       vector<int32_t> &cur_dynamic_dims) {
+  GELOGD("Start get cur dynamic dims.");
   if (user_real_input_dims.size() != user_input_dims.size()) {
     GELOGE(INTERNAL_ERROR,
            "The input count of user: %zu should be equal to the data count of graph: %zu",
@@ -478,7 +478,7 @@ Status ModelManager::GetCurDynamicDims(const vector<vector<int64_t>> &user_real_
     }
     for (size_t j = 0; j < user_input_dims.at(i).second.size(); ++j) {
       if (user_input_dims.at(i).second.at(j) < 0) {
-        cur_dynamic_dims.emplace_back(user_real_input_dims[i][j]);
+        cur_dynamic_dims.emplace_back(static_cast<int32_t>(user_real_input_dims[i][j]));
       }
     }
   }
@@ -523,7 +523,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT
     input_data.blobs.push_back(data);
   }
   if (!GetLocalOmgContext().user_input_dims.empty() && GetLocalOmgContext().need_multi_batch) {
-    std::vector<int64_t> cur_dynamic_dims;
+    std::vector<int32_t> cur_dynamic_dims;
     if (!GetLocalOmgContext().user_real_input_dims.empty()) {
       if (GetCurDynamicDims(GetLocalOmgContext().user_real_input_dims, GetLocalOmgContext().user_input_dims,
                             cur_dynamic_dims) != SUCCESS) {
@@ -531,9 +531,9 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT
         return INTERNAL_ERROR;
       }
       DataBuffer data;
-      data.data = new(std::nothrow) int64_t[cur_dynamic_dims.size()];
+      data.data = new(std::nothrow) int32_t[cur_dynamic_dims.size()];
       GE_CHECK_NOTNULL(data.data);
-      uint64_t length = static_cast<uint64_t>(cur_dynamic_dims.size() * sizeof(int64_t));
+      uint32_t length = static_cast<uint32_t>(cur_dynamic_dims.size() * sizeof(int32_t));
       GE_CHK_BOOL_EXEC(memcpy_s(data.data, length, cur_dynamic_dims.data(), length) == EOK, return INTERNAL_ERROR,
                        "Failed to memcpy data.");
       data.length = length;
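
Note: the two hunks above move the dynamic-dims blob from int64_t to int32_t end to end: the gather in GetCurDynamicDims, the heap buffer, and the byte length all change together, otherwise the consumer would read wrongly sized elements. A minimal self-contained sketch of the same flow (simplified and hypothetical: GE itself takes name/dims pairs and uses memcpy_s/EOK from the securec library plus GE_CHECK_NOTNULL):

    #include <cstdint>
    #include <cstring>
    #include <new>
    #include <vector>

    // Gather the real value of every dim the user declared as -1, then pack
    // the result into a freshly allocated raw buffer for the data blob.
    bool PackDynamicDims(const std::vector<std::vector<int64_t>> &real_dims,
                         const std::vector<std::vector<int64_t>> &option_dims,
                         void *&buffer, uint32_t &length) {
      std::vector<int32_t> cur_dynamic_dims;
      for (size_t i = 0; i < option_dims.size(); ++i) {
        for (size_t j = 0; j < option_dims[i].size(); ++j) {
          if (option_dims[i][j] < 0) {  // -1 marks a dynamic dim
            cur_dynamic_dims.emplace_back(static_cast<int32_t>(real_dims[i][j]));
          }
        }
      }
      length = static_cast<uint32_t>(cur_dynamic_dims.size() * sizeof(int32_t));
      buffer = new (std::nothrow) int32_t[cur_dynamic_dims.size()];
      if (buffer == nullptr) {
        return false;
      }
      std::memcpy(buffer, cur_dynamic_dims.data(), length);
      return true;
    }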

@@ -126,14 +126,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
   ///
   /// @ingroup domi_ome
   /// @brief Get cur_dynamic_dims for all input.
-  /// @param [in] vector<vector<uint64_t>> &user_real_input_dims: dims info of all user_inputs.
+  /// @param [in] vector<vector<int64_t>> &user_real_input_dims: dims info of all user_inputs.
   /// @param [in] vector<pair<string, vector<int64_t>>> &user_input_dims: key:name. value:dynamic dims from option.
-  /// @param [out] vector<uint64_t> &cur_dynamic_dims: real dims gather, where the index of -1.
+  /// @param [out] vector<int32_t> &cur_dynamic_dims: real dims gather, where the index of -1.
   /// @return 0: SUCCESS / others: INTERNAL_ERROR
   ///
   Status GetCurDynamicDims(const vector<vector<int64_t>> &user_real_input_dims,
                            const vector<pair<string, vector<int64_t>>> &user_input_dims,
-                           vector<int64_t> &cur_dynamic_dims);
+                           vector<int32_t> &cur_dynamic_dims);
   ///
   /// @ingroup domi_ome

@@ -145,7 +145,9 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM
   } else {
     GELOGI("need to reuse follow stream and create new follow stream.");
     size_t created_stream_num = follow_stream_usage.size();
-    hccl_stream_list_ = follow_stream_usage;
+    for (const auto &stream : follow_stream_usage) {
+      hccl_stream_list_.emplace_back(stream);
+    }
     ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model, main_stream_id);
     if (ret != SUCCESS) {
       GELOGE(RT_FAILED, "Create hccl stream failed.");
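
Note: the fix above replaces wholesale assignment with an append. A minimal sketch of the difference (names hypothetical; the real list holds runtime stream handles):

    #include <vector>

    void ReuseFollowStreams(const std::vector<void *> &follow_stream_usage,
                            std::vector<void *> &hccl_stream_list) {
      // Before: hccl_stream_list = follow_stream_usage;
      // That overwrote any streams already collected on the list.
      for (void *stream : follow_stream_usage) {
        hccl_stream_list.push_back(stream);  // After: accumulate instead.
      }
    }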

@@ -2780,8 +2780,10 @@ Status GraphManager::ParseInputsDims(const std::vector<InputTensorInfo> &input_t
   if (!GetLocalOmgContext().dynamic_node_type.empty()) {
     vector<NodePtr> data_nodes;
     vector<NodePtr> getnext_nosink_nodes;
-    data_nodes = compute_graph_->TryGetExtAttr(kExtAttrDataNodes, data_nodes);
-    getnext_nosink_nodes = compute_graph_->TryGetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes);
+    data_nodes = GetLocalOmgContext().data_nodes;
+    getnext_nosink_nodes = GetLocalOmgContext().getnext_nosink_nodes;
+    GELOGD("Data nodes count is %zu, getnext nosink nodes count is %zu.", data_nodes.size(),
+           getnext_nosink_nodes.size());
     if (GetLocalOmgContext().dynamic_node_type == DATA) {
       if (getnext_nosink_nodes.empty()) {
         // just data or data+getnext_sink

@@ -26,6 +26,10 @@
 namespace ge {
 namespace {
+std::set<std::string> un_compute_attrs = {
+    {ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES},
+};
+
 std::string GetCseKey(const NodePtr &node) {
   std::stringstream ss;
   ss << node->GetType() << "-data-inputs-";
@@ -49,7 +53,7 @@ std::string GetCseKey(const NodePtr &node) {
     ss << name << "-";
   }
-  ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc());
+  ss << "attrs-" << AttrUtils::GetAttrsStrAfterRid(node->GetOpDesc(), un_compute_attrs);
   return ss.str();
 }
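
Note: the two hunks above make common-subexpression elimination ignore attributes that do not affect computation. With GetAllAttrsStr, two nodes identical except for a bookkeeping attr such as ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES produced different keys and were never merged. A minimal sketch of the filtered serialization (a hypothetical stand-in for GetAttrsStrAfterRid, over a plain string map):

    #include <map>
    #include <set>
    #include <sstream>
    #include <string>

    std::string AttrsStrWithoutRidSet(const std::map<std::string, std::string> &attrs,
                                      const std::set<std::string> &rid) {
      std::stringstream ss;
      for (const auto &kv : attrs) {
        if (rid.count(kv.first) == 0) {  // skip attrs that don't affect compute
          ss << kv.first << "=" << kv.second << ";";
        }
      }
      return ss.str();
    }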

(File diff suppressed because it is too large)

@@ -36,6 +36,7 @@ class MultiBatchClonePass : public GraphPass {
   /// @return 0: SUCCESS / others: FAILED
   ///
   Status CollectIoNodes(const ComputeGraphPtr &graph);
+  Status InitParamsOfGetNext(const NodePtr &node);
   ///
/// @ingroup ge
@@ -49,10 +50,12 @@ class MultiBatchClonePass : public GraphPass {
   /// @ingroup ge
   /// @brief Create index data node for root graph.
   /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
-  /// @param [in] NodePtr node: index data node.
+  /// @param [in] NodePtr shape_node: index data node, DATA or GETDYNAMICDIMS type.
   /// @return 0: SUCCESS / others: FAILED
   ///
-  Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node);
+  Status CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node);
+  Status CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node);
   ///
/// @ingroup ge
@@ -70,6 +73,9 @@ class MultiBatchClonePass : public GraphPass {
   /// @return 0: SUCCESS / others: FAILED
   ///
   Status CreateIndexNode(const ComputeGraphPtr &graph);
+  Status AddAttrForGetDynamicDims(const NodePtr &shape_node);
+  Status LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node);
+  Status LinkGetDynamicDimsToNetOutput(const NodePtr &output_node);
   ///
   /// @ingroup ge
@@ -78,39 +84,54 @@ class MultiBatchClonePass : public GraphPass {
   /// @return 0: SUCCESS / others: FAILED
   ///
   Status CreateInputNode(const ComputeGraphPtr &graph);
+  Status LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index);
   ///
   /// @ingroup ge
-  /// @brief Create Const node for root graph.
-  /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
+  /// @brief Set max shape to Data node in root graph.
+  /// @param [in] const NodePtr &data: data in Root/Case graph.
   /// @return 0: SUCCESS / others: FAILED
   ///
-  Status CreateConstNode(const ComputeGraphPtr &graph);
+  Status SetMaxShape(const NodePtr &data);
+  Status SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
   ///
   /// @ingroup ge
+  /// @brief Set max shape to Data/GetNext node in root graph.
+  /// @param [in] const std::vector<int64_t> &shapes: dims of shape.
+  /// @param [in] const NodePtr &data: data in Root/Case graph.
+  /// @param [in] GeShape &data_shape: dims of data node.
+  /// @param [in] size_t out_anchor_index: out anchor index of data node.
+  /// @return 0: SUCCESS / others: FAILED
+  ///
+  Status SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape,
+                        size_t out_anchor_index);
+  Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);
   ///
   /// @ingroup ge
-  /// @brief Create output node for root graph.
+  /// @brief Create Const node for root graph.
   /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
   /// @return 0: SUCCESS / others: FAILED
   ///
-  Status CreateOutputNode(const ComputeGraphPtr &graph);
+  Status CreateConstNode(const ComputeGraphPtr &graph);
+  void ChangeConstToData();
   ///
   /// @ingroup ge
-  /// @brief Set max shape to Data node in root graph.
-  /// @param [in] const NodePtr &data: data in Root/Case graph.
+  /// @brief Create output node for root graph.
+  /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
   /// @return 0: SUCCESS / others: FAILED
   ///
-  Status SetMaxShapeToData(const NodePtr &data);
+  Status CreateOutputNode(const ComputeGraphPtr &graph);
   ///
   /// @ingroup ge
   /// @brief Update Data node in Subgraph.
   /// @param [in] const NodePtr &data: data in Subgraph.
-  /// @param [in] size_t index: The batch index.
+  /// @param [in] size_t batch_index: The batch index.
   /// @return 0: SUCCESS / others: FAILED
   ///
-  Status UpdateSubgraphData(const NodePtr &data, size_t index);
+  Status UpdateSubgraphData(const NodePtr &data, size_t batch_index);
   ///
   /// @ingroup ge
@@ -122,13 +143,12 @@ class MultiBatchClonePass : public GraphPass {
   ///
   /// @ingroup ge
-  /// @brief Set max shape to Data node in root graph.
-  /// @param [in] const std::vector<int64_t> &shapes: dims of shape.
-  /// @param [in] const NodePtr &data: data in Root/Case graph.
-  /// @param [in] GeShape &data_shape: dims of data node.
+  /// @brief Create nodes for root graph.
+  /// @param [in] const ComputeGraphPtr &graph: Original graph.
   /// @return 0: SUCCESS / others: FAILED
   ///
-  Status SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape);
+  Status CreateOriGraph(const ComputeGraphPtr &graph);
+  NodePtr CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, size_t data_index);
   ///
   /// @ingroup ge
@@ -168,6 +188,10 @@ class MultiBatchClonePass : public GraphPass {
   std::map<string, vector<vector<int64_t>>> data_to_dynamic_info_;
   NodePtr case_node_;
   size_t data_count_from_getnext_ = 0;
+  bool getnext_sink_dynamic_dims_ = false;
+  NodePtr shape_node_;
+  std::set<NodePtr> out_control_nodes_;
 };
 } // namespace ge
 #endif // GE_GRAPH_PASSES_MULTI_BATCH_CLONE_PASS_H_

@@ -204,6 +204,10 @@ Status UnusedArgsCleanPass::RemoveInputTensor(const map<ComputeGraphPtr, map<uin
   GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed");
   GELOGI("Remove edge: %s %s", out_node->GetName().c_str(), func_node->GetName().c_str());
+  if (out_node->GetInDataNodes().size() == 0 && out_node->GetOutAllNodes().size() == 0) {
+    GE_CHK_GRAPH_STATUS_RET(out_node->GetOwnerComputeGraph()->RemoveNode(out_node), "Remove node failed: %s",
+                            out_node->GetName().c_str());
+  }
   return SUCCESS;
 }
 } // namespace ge
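
Note: after the edge to the func node is unlinked, the producer may be left dangling. A minimal sketch of the guard added above, using the same ge::Node accessors that appear in the hunk:

    #include "graph/node.h"  // ge::NodePtr (GE header)

    // A node with no remaining data inputs and no outputs of any kind is
    // fully isolated, so it can safely be removed from its owner graph.
    bool IsIsolated(const ge::NodePtr &node) {
      return node->GetInDataNodes().size() == 0 && node->GetOutAllNodes().size() == 0;
    }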

@@ -1692,13 +1692,11 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
 }

 Status ProcessMultiBatch(ComputeGraphPtr &graph) {
-  if (GetLocalOmgContext().dynamic_node_type.empty()) {
-    const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN");
-    if (multi_batch_with_switchn == nullptr) {
-      PassManager pass_manager;
-      GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
-      return pass_manager.Run(graph);
-    }
-  }
+  const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN");
+  if (multi_batch_with_switchn == nullptr) {
+    PassManager pass_manager;
+    GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
+    return pass_manager.Run(graph);
+  }
   if (!GetLocalOmgContext().need_multi_batch) {
     GELOGI("No need to process_multi for no_train graph.");

@@ -99,9 +99,8 @@ Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_n
   }
   GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(),
          getnext_nosink_nodes.size(), getnext_sink_nodes.size());
-  GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrDataNodes, data_nodes), GELOGW("Set data nodes attr failed.");)
-  GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes),
-                  GELOGW("Set getnext nosink nodes attr failed.");)
+  GetLocalOmgContext().data_nodes = data_nodes;
+  GetLocalOmgContext().getnext_nosink_nodes = getnext_nosink_nodes;
   return SUCCESS;
 }
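
Note: this hunk is the producer side of the GraphManager::ParseInputsDims change earlier in the diff: the classified nodes now travel through the thread-local OmgContext (see the data_nodes/getnext_nosink_nodes fields added to OmgContext below) instead of graph ext-attrs. A minimal sketch of the hand-off, assuming the GE headers that declare GetLocalOmgContext:

    #include <vector>
    #include "graph/common/local_context.h"     // GetLocalOmgContext() (GE header)
    #include "framework/omg/omg_inner_types.h"  // OmgContext (GE header)

    // Producer (DistinguishGetNextAndData): publish the classified nodes so
    // the consumer (ParseInputsDims) can read them back from the same context.
    void PublishClassifiedNodes(const std::vector<ge::NodePtr> &data_nodes,
                                const std::vector<ge::NodePtr> &getnext_nosink) {
      GetLocalOmgContext().data_nodes = data_nodes;
      GetLocalOmgContext().getnext_nosink_nodes = getnext_nosink;
    }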

@@ -26,6 +26,7 @@
 #include <vector>
 #include "framework/common/fmk_error_codes.h"
 #include "register/register_fmk_types.h"
+#include "graph/node.h"

 using domi::DOMI_TENSOR_ND;
 using domi::DOMI_TENSOR_RESERVED;
@@ -120,6 +121,8 @@ struct OmgContext {
   std::vector<std::vector<int64_t>> user_real_input_dims;
   std::vector<int64_t> cur_dynamic_dims;
   bool need_multi_batch = false;
+  std::vector<NodePtr> data_nodes;
+  std::vector<NodePtr> getnext_nosink_nodes;
 };
 } // namespace ge

@@ -1 +1 @@
-Subproject commit 44bcbb5ea25ada1a5393aa4c7f554d40b6859b18
+Subproject commit fe37bc343ea52c76d35e9e9ec83cea0151bfa900

@@ -1 +1 @@
-Subproject commit 5b93b050dd7ca5b77c3001a790031d877fa10956
+Subproject commit 336cd3107253d3fe41cfb9fec2db62b5f3d8a33b

@@ -627,6 +627,7 @@ set(PASS_TEST_FILES
     "graph/passes/net_output_pass_unittest.cc"
     "graph/passes/no_use_reshape_remove_pass_unittest.cc"
     "graph/passes/infershape_pass_unittest.cc"
+    "graph/passes/multi_batch_clone_pass_unittest.cc"
 )

 set(KERNEL_TEST_FILES

@@ -32,6 +32,18 @@ class UtestDavinciModel : public testing::Test {
   void SetUp() {}
   void TearDown() {}

  public:
+  NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
+    GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
+    auto op_desc = std::make_shared<OpDesc>(name, type);
+    for (auto i = 0; i < in_num; ++i) {
+      op_desc->AddInputDesc(test_desc);
+    }
+    for (auto i = 0; i < out_num; ++i) {
+      op_desc->AddOutputDesc(test_desc);
+    }
+    return graph->AddNode(op_desc);
+  }
 };

 TEST_F(UtestDavinciModel, init_success) {
@@ -324,5 +336,94 @@ TEST_F(UtestDavinciModel, SyncVarData_test) {
   EXPECT_NE(model.SyncVarData(), SUCCESS);
 }
+
+TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) {
+  DavinciModel model(0, nullptr);
+  model.ge_model_ = make_shared<GeModel>();
+  ComputeGraphPtr graph = make_shared<ComputeGraph>("default");
+
+  GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
+  OpDescPtr op_output = CreateOpDesc("output_ascend_mbatch_batch_1", NETOUTPUT);
+  op_output->AddInputDesc(tensor);
+  op_output->SetInputOffset({1024});
+  NodePtr node_output = graph->AddNode(op_output);
+  EXPECT_EQ(model.InitRealSizeAndShapeInfo(graph, node_output), SUCCESS);
+}
+
+TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ2) {
+  DavinciModel model(0, nullptr);
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+
+  OpDescPtr data1 = CreateOpDesc("data1", DATA);
+  GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+  data1->AddInputDesc(shape_desc);
+  data1->AddOutputDesc(shape_desc);
+  NodePtr data1_node = graph->AddNode(data1);
+
+  OpDescPtr case_node = CreateOpDesc("case1", CASE);
+  GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
+  case_node->AddInputDesc(tensor);
+  case_node->AddOutputDesc(tensor);
+  NodePtr case1_node = graph->AddNode(case_node);
+
+  OpDescPtr output = CreateOpDesc("output1", NETOUTPUT);
+  output->AddInputDesc(tensor);
+  output->SetSrcName( { "case1" } );
+  output->SetSrcIndex( { 0 } );
+  NodePtr output_node = graph->AddNode(output);
+
+  GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), case1_node->GetInDataAnchor(0));
+  GraphUtils::AddEdge(case1_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+
+  (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1;2;4;8");
+  (void)AttrUtils::SetBool(case_node, ATTR_INSERT_BY_MBATCH, true);
+
+  model.is_getnext_sink_dynamic_ = false;
+  model.is_online_infer_dynamic_ = true;
+  auto ret = model.InitRealSizeAndShapeInfo(graph, output_node);
+  // GetGearAndRealOutShapeInfo without ATTR_NAME_DYNAMIC_OUTPUT_DIMS
+  EXPECT_EQ(ret, SUCCESS);
+  vector<string> dynamic_output_dims = {"0,0,1,1,0,2,2,0,4,3,0,8"};
+  (void)AttrUtils::SetListStr(output_node->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims);
+  ret = model.InitRealSizeAndShapeInfo(graph, output_node);
+  EXPECT_EQ(ret, SUCCESS);
+}
+
+TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ3) {
+  DavinciModel model(0, nullptr);
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+
+  OpDescPtr data1 = CreateOpDesc("data1", DATA);
+  GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+  data1->AddInputDesc(shape_desc);
+  data1->AddOutputDesc(shape_desc);
+  NodePtr data1_node = graph->AddNode(data1);
+
+  OpDescPtr shape_node = CreateOpDesc("ascend_mbatch_get_dynamic_dims_node", GETDYNAMICDIMS);
+  GeTensorDesc in_tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
+  GeTensorDesc out_tensor(GeShape({4,3}), FORMAT_NCHW, DT_FLOAT);
+  shape_node->AddInputDesc(in_tensor);
+  shape_node->AddOutputDesc(out_tensor);
+  NodePtr get_dynamic_dims_node = graph->AddNode(shape_node);
+
+  OpDescPtr output = CreateOpDesc("output1", NETOUTPUT);
+  GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
+  output->AddInputDesc(tensor);
+  output->SetSrcName( { "data1", "ascend_mbatch_get_dynamic_dims_node" } );
+  output->SetSrcIndex( { 0, 1 } );
+  NodePtr output_node = graph->AddNode(output);
+
+  GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+  GraphUtils::AddEdge(get_dynamic_dims_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(1));
+
+  (void)AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, "1,3;;4,3;,3");
+  model.is_getnext_sink_dynamic_ = true;
+  model.is_online_infer_dynamic_ = false;
+  auto ret = model.InitRealSizeAndShapeInfo(graph, output_node);
+  EXPECT_EQ(ret, SUCCESS);
+  model.runtime_param_.mem_base = (uint8_t *)0x08000000;
+  model.runtime_param_.mem_size = 4;
+  ret = model.InitRealSizeAndShapeInfo(graph, output_node);
+  EXPECT_EQ(ret, SUCCESS);
+}
 } // namespace ge

@@ -0,0 +1,247 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/passes/multi_batch_clone_pass.h"
+
+#include <gtest/gtest.h>
+#include <set>
+#include <string>
+
+#include "inc/pass_manager.h"
+#include "graph/utils/tensor_utils.h"
+#include "graph/common/local_context.h"
+#include "graph/passes/multi_batch_pass.h"
+#include "graph/preprocess/multi_batch_copy_graph.h"
+#include "graph/preprocess/insert_op/util_insert_aipp_op.h"
+#include "framework/omg/omg_inner_types.h"
+#include "register/op_registry.h"
+
+namespace ge{
+class UtestMultiBatchClonePass : public testing::Test {
+ protected:
+  void SetUp() {
+    SetLocalOmgContext(domi::GetContext());
+    GetLocalOmgContext().dynamic_image_size.clear();
+    GetLocalOmgContext().dynamic_batch_size.clear();
+  }
+  void TearDown() {
+    GetLocalOmgContext().dynamic_image_size.clear();
+    GetLocalOmgContext().dynamic_batch_size.clear();
+    GetLocalOmgContext().dynamic_node_type.clear();
+  }
+
+ public:
+  NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
+    GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
+    auto op_desc = std::make_shared<OpDesc>(name, type);
+    for (auto i = 0; i < in_num; ++i) {
+      op_desc->AddInputDesc(test_desc);
+    }
+    for (auto i = 0; i < out_num; ++i) {
+      op_desc->AddOutputDesc(test_desc);
+    }
+    return graph->AddNode(op_desc);
+  }
+
+  NodePtr MakeConstNode(const ComputeGraphPtr &graph) {
+    static uint32_t index = 0;
+    GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
+    auto op_desc = std::make_shared<OpDesc>("dynamic_const_" + std::to_string(index++), "Const");
+    op_desc->AddOutputDesc(test_desc);
+    return graph->AddNode(op_desc);
+  }
+
+  void make_original_graph(const ComputeGraphPtr &graph) {
+    auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
+    {
+      auto data1 = MakeNode(graph, 1, 1, "data", "Data");
+      GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+      data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
+      data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
+      AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0);
+      GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})};
+      GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
+      auto const1 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
+      auto const2 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
+    }
+    auto bn_conv1 = MakeNode(graph, 4, 1, "bn_conv1", "BNInference");
+    {
+      GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(0));
+      auto const1 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const1->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(1));
+      auto const2 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const2->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(2));
+      auto const3= MakeConstNode(graph);
+      GraphUtils::AddEdge(const3->GetOutDataAnchor(0), bn_conv1->GetInDataAnchor(3));
+    }
+    auto scale_conv1 = MakeNode(graph, 4, 1, "scale1", "Scale");
+    {
+      GraphUtils::AddEdge(bn_conv1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(0));
+      auto const1 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const1->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(1));
+      auto const2 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const2->GetOutDataAnchor(0), scale_conv1->GetInDataAnchor(2));
+    }
+    auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput");
+    GraphUtils::AddEdge(scale_conv1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+  }
+
+  void GraphWithJustData(const ComputeGraphPtr &graph) {
+    auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
+    {
+      auto data1 = MakeNode(graph, 1, 1, "data", "Data");
+      GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+      data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
+      data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
+      AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0);
+      GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})};
+      GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
+      auto const1 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
+      auto const2 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
+    }
+    auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput");
+    GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+  }
+
+  void GraphWithGetNextNosink(const ComputeGraphPtr &graph) {
+    auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
+    {
+      auto data1 = MakeNode(graph, 1, 1, "IteratorGetNext_data", "Data");
+      GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+      data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
+      data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
+      AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0);
+      GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})};
+      GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
+      auto const1 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
+      auto const2 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
+    }
+    auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput");
+    GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+  }
+
+  // getnext has one data and has one out of shape
+  void GraphWithGetNextSink(const ComputeGraphPtr &graph) {
+    auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D");
+    {
+      auto data1 = MakeNode(graph, 1, 2, "data", "IteratorV2");
+      GeTensorDesc tensor_desc(GeShape({-1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+      GeTensorDesc shape_desc(GeShape({4,3,224,224}), FORMAT_NCHW, DT_FLOAT);
+      data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
+      data1->GetOpDesc()->UpdateOutputDesc(1, shape_desc);
+      AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 0);
+      GetLocalOmgContext().user_input_dims = {std::make_pair(data1->GetOpDesc()->GetName(), vector<int64_t>{-1,3,224,224})};
+      GraphUtils::AddEdge(data1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0));
+      auto identity = MakeNode(graph, 1, 0, "identity", "Identity");
+      GraphUtils::AddEdge(data1->GetOutDataAnchor(1), identity->GetInDataAnchor(0));
+      auto const1 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1));
+      auto const2 = MakeConstNode(graph);
+      GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2));
+    }
+    auto output_node = MakeNode(graph, 1, 0, "output1", "NetOutput");
+    GraphUtils::AddEdge(conv2d_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
+  }
+};
+
+// graph is nullptr
+TEST_F(UtestMultiBatchClonePass, graph_nullptr) {
+  PassManager pass_manager;
+  pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass);
+  ComputeGraphPtr graph;
+  EXPECT_EQ(pass_manager.Run(graph), PARAM_INVALID);
+}
+
+// graph with subgraph
+TEST_F(UtestMultiBatchClonePass, graph_with_subgraph) {
+  PassManager pass_manager;
+  pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass);
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  make_original_graph(graph);
+  EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
+
+  ComputeGraphPtr owner = std::make_shared<ComputeGraph>("test_owner");
+  auto func_node = MakeNode(owner, 3, 1, "test_if", "If");
+  graph->SetParentNode(func_node);
+  graph->SetParentGraph(owner);
+  EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
+}
+
+//graph is uncompute graph, not need to do multi batch
+TEST_F(UtestMultiBatchClonePass, uncompute_graph) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  make_original_graph(graph);
+  GetLocalOmgContext().need_multi_batch = false;
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+}
+
+//compute_graph with data from DATA
+TEST_F(UtestMultiBatchClonePass, compute_graph_with_data) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  GraphWithJustData(graph);
+  GetLocalOmgContext().need_multi_batch = true;
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+
+  GetLocalOmgContext().dynamic_node_type = DATA;
+  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  EXPECT_EQ(GetLocalOmgContext().data_nodes.size(), 1);
+}
+
+//compute_graph with data from GetNext_nosink
+TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_nosink) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  GraphWithGetNextNosink(graph);
+  GetLocalOmgContext().need_multi_batch = true;
+  GetLocalOmgContext().dynamic_node_type = GETNEXT;
+  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 1);
+}
+
+//compute_graph with data from GetNext_nosink
+TEST_F(UtestMultiBatchClonePass, compute_graph_with_getnext_sink) {
+  MultiBatchClonePass multi_batch_clone;
+  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph");
+  GraphWithGetNextSink(graph);
+  GetLocalOmgContext().need_multi_batch = true;
+  GetLocalOmgContext().dynamic_node_type = GETNEXT;
+  GetLocalOmgContext().dynamic_dims = "1;2;4;8";
+  EXPECT_EQ(multi_batch_clone.Run(graph), SUCCESS);
+  EXPECT_EQ(GetLocalOmgContext().getnext_nosink_nodes.size(), 0);
+}
+}