|
|
|
@ -19,6 +19,7 @@
|
|
|
|
|
#include "device/kernel_info.h"
|
|
|
|
|
#include "pre_activate/pass/convert_const_input_to_attr.h"
|
|
|
|
|
#include "debug/anf_ir_dump.h"
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#define private public
|
|
|
|
|
#define protected public
|
|
|
|
|
#include "pre_activate/ascend/ir_fission/topk_split.h"
|
|
|
|
@ -32,6 +33,21 @@ class TestHWTopKSplit : public BackendCommon {
|
|
|
|
|
TestHWTopKSplit() : get_py_fun_("gtest_input.pre_activate.topk_split_test", true) {}
|
|
|
|
|
~TestHWTopKSplit() override = default;
|
|
|
|
|
|
|
|
|
|
CNodePtr GetTopkCNodeFromKernelGraph(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto ret = func_graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ret);
|
|
|
|
|
auto make_tuple = ret->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
|
|
|
|
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(topk);
|
|
|
|
|
auto topk_cnode = topk->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(topk_cnode);
|
|
|
|
|
return topk_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
UT::PyFuncGraphFetcher get_py_fun_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -39,7 +55,8 @@ class MockSupportedChecker : public SupportedChecker {
|
|
|
|
|
public:
|
|
|
|
|
MockSupportedChecker() = default;
|
|
|
|
|
~MockSupportedChecker() override = default;
|
|
|
|
|
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
|
|
|
|
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
|
|
|
|
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}; // namespace opt
|
|
|
|
@ -66,14 +83,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
|
|
|
|
|
|
|
|
|
|
auto ret = new_graph->get_return();
|
|
|
|
|
EXPECT_NE(ret, nullptr);
|
|
|
|
|
auto make_tuple = ret->input(1);
|
|
|
|
|
EXPECT_NE(make_tuple, nullptr);
|
|
|
|
|
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
|
|
|
|
|
EXPECT_NE(tuple_getitem, nullptr);
|
|
|
|
|
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
|
|
|
|
|
auto topk_cnode = topk->cast<CNodePtr>();
|
|
|
|
|
auto topk_cnode = GetTopkCNodeFromKernelGraph(new_graph);
|
|
|
|
|
EXPECT_EQ(topk_cnode->inputs().size(), 3);
|
|
|
|
|
EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>());
|
|
|
|
|
auto value_node = topk_cnode->input(2)->cast<ValueNodePtr>();
|
|
|
|
@ -82,5 +92,39 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
|
|
|
|
EXPECT_EQ(tensor->shape().size(), 1);
|
|
|
|
|
EXPECT_EQ(tensor->shape()[0], 4);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWTopKSplit, test_topk_no_split) {
|
|
|
|
|
/*
|
|
|
|
|
* def before(input):
|
|
|
|
|
* topk = TopKSplit(input)
|
|
|
|
|
* output = tuple_getitem(topk, 0)
|
|
|
|
|
* return output
|
|
|
|
|
*/
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before");
|
|
|
|
|
std::vector<int> shp{4, 4};
|
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
|
AbstractBasePtrList args_spec_list{x_abstract};
|
|
|
|
|
auto kernel_graph = GetKernelGraph(g, args_spec_list);
|
|
|
|
|
|
|
|
|
|
CNodePtr topk_cnode = GetTopkCNodeFromKernelGraph(kernel_graph);
|
|
|
|
|
EXPECT_EQ(topk_cnode->inputs().size(), 3);
|
|
|
|
|
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames);
|
|
|
|
|
EXPECT_EQ(input_names_vec.size(), 2);
|
|
|
|
|
std::unordered_set<size_t> attr_index{1};
|
|
|
|
|
ConstInputToAttr(topk_cnode, attr_index);
|
|
|
|
|
EXPECT_EQ(topk_cnode->inputs().size(), 2);
|
|
|
|
|
input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames);
|
|
|
|
|
EXPECT_EQ(input_names_vec.size(), 1);
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>());
|
|
|
|
|
auto topk_split = std::make_shared<opt::TopKSplit>();
|
|
|
|
|
topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>();
|
|
|
|
|
pm->AddPass(topk_split);
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
|
|
|
|
|
EXPECT_EQ(topk_cnode, GetTopkCNodeFromKernelGraph(new_graph));
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|