You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/tests/ut/cpp/optimizer/clean_test.cc

250 lines
8.8 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <string>
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "utils/log_adapter.h"
#include "pipeline/parse/parse.h"
#include "debug/draw.h"
#include "optimizer/clean.h"
namespace mindspore {
namespace opt {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
class TestClean : public UT::Common {
public:
TestClean() : getPyFun("gtest_input.optimizer.clean_test", true) {}
virtual void SetUp();
virtual void TearDown();
public:
UT::PyFuncGraphFetcher getPyFun;
FuncGraphPtr me_graph;
};
void TestClean::SetUp() {
// build the func_graph.
me_graph = std::make_shared<FuncGraph>();
me_graph->debug_info()->set_name("next");
// build the nodes
AnfNodePtr valuenode_next = NewValueNode(std::string("ms_next"));
ParameterPtr parameter = std::make_shared<Parameter>(me_graph);
AbstractBasePtr para_scalar = std::make_shared<AbstractScalar>(0);
AbstractBasePtr para_list = std::make_shared<AbstractList>(
AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
AbstractBasePtrList para_elem{para_scalar, para_list};
AbstractBasePtr para_tuple = std::make_shared<AbstractTuple>(para_elem);
parameter->set_abstract(para_tuple);
AbstractBasePtr app_float = std::make_shared<AbstractScalar>(kFloat64);
AbstractBasePtr app_int = std::make_shared<AbstractScalar>(kFloat64);
AbstractBasePtr app_list = std::make_shared<AbstractList>(
AbstractBasePtrList({std::make_shared<AbstractScalar>(kFloat64), std::make_shared<AbstractScalar>(kFloat64)}));
AbstractBasePtr app_tuple_inner = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_int, app_list});
AbstractBasePtr app_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList{app_float, app_tuple_inner});
AnfNodePtr cnode_57 = me_graph->NewCNode({valuenode_next, parameter});
cnode_57->set_abstract(app_tuple);
AnfNodePtr cnode_67 = me_graph->NewCNode({NewValueNode(prim::kPrimPartial), valuenode_next, parameter});
cnode_67->set_abstract(app_tuple);
AnfNodePtr cnode_66 = me_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), cnode_57, cnode_67});
cnode_66->set_abstract(app_float);
AnfNodePtr valuenode_return = NewValueNode(prim::kPrimReturn);
CNodePtr cnode_55 = me_graph->NewCNode({valuenode_return, cnode_66});
cnode_55->set_abstract(app_tuple);
me_graph->set_output(cnode_66);
me_graph->set_return(cnode_55);
me_graph->add_parameter(parameter);
}
// Per-test cleanup hook; this fixture has nothing to tear down.
void TestClean::TearDown() {
  // Intentionally empty.
}
// Loads the Python graph `test_erase_class_fn` and checks that it contains exactly one ClassObject
// value node. Every CNode whose callee (input 0) is a ClassObject is given an AbstractClass "Point"
// with float64 attributes x and y. After SimplifyDataStructures runs, the graph must contain exactly
// two tuple_getitem CNodes. Graphs are dumped to .dot files before and after the pass.
TEST_F(TestClean, TestEraseClassGetAttr) {
  FuncGraphPtr func_graph = getPyFun("test_erase_class_fn");
  ASSERT_TRUE(nullptr != func_graph);

  // Register the graph with a manager so all of its nodes can be enumerated.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);

  int dataclass_count = 0;
  for (const auto &node : manager->all_nodes()) {
    if (IsValueNode<parse::ClassObject>(node)) {
      ++dataclass_count;
    }
    if (!node->isa<CNode>()) {
      continue;
    }
    auto callee = node->cast<CNodePtr>()->input(0);
    if (IsValueNode<parse::ClassObject>(callee)) {
      std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kFloat64)},
                                             {"y", std::make_shared<AbstractScalar>(kFloat64)}};
      std::unordered_map<std::string, ValuePtr> methods;
      AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
      node->set_abstract(abs_ptr);
    }
  }
  // ClassObject nodes are counted only before simplification; the post-pass check below is on
  // tuple_getitem nodes.
  ASSERT_EQ(dataclass_count, 1);

  draw::Draw("opt_before_erase_class.dot", func_graph);
  SimplifyDataStructures(func_graph, manager);
  draw::Draw("opt_after_erase_class.dot", func_graph);

  int tuple_getitem_count = 0;
  for (const auto &node : manager->all_nodes()) {
    if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
      ++tuple_getitem_count;
    }
  }
  ASSERT_EQ(tuple_getitem_count, 2);
}
// Hand-builds the graph
//   (arg1, arg2) -> Return(MakeRecord(<Point class value>, arg1, arg2))
// where Point is an AbstractClass with int64 attributes x and y. It then runs SimplifyDataStructures
// over the graph and dumps it to .dot files before and after. No assertions are made on the result.
TEST_F(TestClean, TestEraseClassMakeRecord) {
  auto graph = std::make_shared<FuncGraph>();
  graph->debug_info()->set_name("test_make_record");

  // Every scalar in this test is an int64 of unknown value.
  auto any_int64 = []() { return std::make_shared<AbstractScalar>(kAnyValue, kInt64); };

  auto make_record_prim = NewValueNode(prim::kPrimMakeRecord);
  auto arg1 = std::make_shared<Parameter>(graph);
  auto arg2 = std::make_shared<Parameter>(graph);
  arg1->set_abstract(any_int64());
  arg2->set_abstract(any_int64());

  // Abstract for the Point class plus a value node holding the value built from it.
  std::vector<AbstractAttribute> point_attrs = {{"x", any_int64()}, {"y", any_int64()}};
  std::unordered_map<std::string, ValuePtr> point_methods;
  AbstractBasePtr point_abs = std::make_shared<AbstractClass>(Named("Point"), point_attrs, point_methods);
  auto point_value = NewValueNode(point_abs->BuildValue());
  point_value->set_abstract(point_abs);

  auto record = graph->NewCNode({make_record_prim, point_value, arg1, arg2});
  auto ret = graph->NewCNode({NewValueNode(prim::kPrimReturn), record});
  ret->set_abstract(point_abs);

  graph->set_output(record);
  graph->set_return(ret);
  graph->add_parameter(arg1);
  graph->add_parameter(arg2);

  auto manager = Manage(graph);
  draw::Draw("opt_erase_class_record_before.dot", graph);
  SimplifyDataStructures(graph, manager);
  draw::Draw("opt_erase_class_record_after.dot", graph);
}
// Hand-builds the graph
//   (para1) -> Return(ScalarAdd(Partial(MakeRecord, <Point class value>, para1),
//                               Partial(MakeRecord, <Point class value>)))
// where Point is an AbstractClass with int64 attributes x and y. It then runs SimplifyDataStructures
// over the graph and dumps it to .dot files before and after. No assertions are made on the result.
TEST_F(TestClean, TestEraseClassPartial) {
  auto func_graph = std::make_shared<FuncGraph>();
  func_graph->debug_info()->set_name("test_partial");
  auto cons_partial = NewValueNode(prim::kPrimPartial);
  auto para1 = std::make_shared<Parameter>(func_graph);
  para1->set_abstract(std::make_shared<AbstractScalar>(kAnyValue, kInt64));
  auto cons_make_record = NewValueNode(prim::kPrimMakeRecord);

  // Abstract for the Point class plus a value node holding the value built from it.
  std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
                                         {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  std::unordered_map<std::string, ValuePtr> methods;
  AbstractBasePtr abs_ptr = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  auto cons_class = NewValueNode(abs_ptr->BuildValue());
  cons_class->set_abstract(abs_ptr);

  // Partial application of MakeRecord with the argument bound (apply22) and without it (apply33).
  std::vector<AnfNodePtr> inputs{cons_partial, cons_make_record, cons_class, para1};
  auto apply22 = func_graph->NewCNode(inputs);
  std::vector<AnfNodePtr> inputs_nopara{cons_partial, cons_make_record, cons_class};
  auto apply33 = func_graph->NewCNode(inputs_nopara);
  auto apply11 = func_graph->NewCNode({NewValueNode(prim::kPrimScalarAdd), apply22, apply33});
  auto cons_return = NewValueNode(prim::kPrimReturn);
  auto apply00 = func_graph->NewCNode({cons_return, apply11});
  apply00->set_abstract(abs_ptr);

  // The output is the ScalarAdd node and the return node is the Return CNode that wraps it.
  // Wiring apply22/apply11 here instead would make the ScalarAdd itself the return node and leave
  // apply00 unreachable.
  func_graph->set_output(apply11);
  func_graph->set_return(apply00);
  func_graph->add_parameter(para1);

  auto manager = Manage(func_graph);
  draw::Draw("opt_erase_class_partial_before.dot", func_graph);
  SimplifyDataStructures(func_graph, manager);
  draw::Draw("opt_erase_class_partial_after.dot", func_graph);
}
// Runs EraseTuple over the fixture graph `me_graph` built in SetUp(). The number of managed nodes
// whose abstract is an AbstractTuple must go from 4 before the pass to 3 after it.
TEST_F(TestClean, TestEraseTuple) {
  ASSERT_TRUE(nullptr != me_graph);
  std::shared_ptr<FuncGraphManager> manager = Manage(me_graph);
  draw::Draw("opt_before_erase_tuple.dot", me_graph);

  // Counts the nodes currently managed by `manager` whose abstract is an AbstractTuple.
  auto count_tuple_abstracts = [&manager]() {
    int count = 0;
    for (const auto &node : manager->all_nodes()) {
      if (dyn_cast<AbstractTuple>(node->abstract()) != nullptr) {
        ++count;
      }
    }
    return count;
  };

  // Before the pass, four SetUp() nodes carry tuple abstracts: the parameter, the ms_next call,
  // the Partial node and the Return node.
  ASSERT_EQ(count_tuple_abstracts(), 4);

  // Intended to erase the tuple on the ms_next(param) call node and on the parameter.
  // NOTE(review): the assertion below pins only a net decrease of one tuple abstract; confirm which
  // nodes are affected.
  EraseTuple(me_graph, manager);
  ASSERT_EQ(count_tuple_abstracts(), 3);

  draw::Draw("opt_after_erase_tuple.dot", me_graph);
}
} // namespace opt
} // namespace mindspore