|
|
|
@ -133,6 +133,9 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
|
|
|
|
MS_ASSERT(fullconnect_cnode->inputs().size() == 3);
|
|
|
|
|
auto left_slice_node = fullconnect_cnode->input(1);
|
|
|
|
|
auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
|
|
|
|
|
if (GetCNodeType(left_slice_cnode) != schema::PrimitiveType_Slice) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto left_matmul_input = left_slice_cnode->input(1);
|
|
|
|
|
auto right_reshape_node = fullconnect_cnode->input(2);
|
|
|
|
|
|
|
|
|
|