|
|
|
@ -23,33 +23,33 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
const BaseRef RemoveReshapePair::DefinePattern() const {
|
|
|
|
|
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
|
|
|
|
|
VectorRef reshape({prim_reshape, input_varptr_});
|
|
|
|
|
|
|
|
|
|
return VectorRef({prim::kPrimReshape, reshape});
|
|
|
|
|
VarPtr X = std::make_shared<Var>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(X);
|
|
|
|
|
return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const EquivPtr &equiv) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
auto manager = func_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(reshape_op_1);
|
|
|
|
|
// If reshape operator used by more than one other operators, reshape operator cant not be deleted directly
|
|
|
|
|
auto users = manager->node_users()[reshape_op_1];
|
|
|
|
|
if (users.size() > 1) {
|
|
|
|
|
if (IsUsedByOthers(func_graph, reshape_op_1)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(reshape_op_2);
|
|
|
|
|
users = manager->node_users()[reshape_op_2];
|
|
|
|
|
if (users.size() > 1) {
|
|
|
|
|
if (IsUsedByOthers(func_graph, reshape_op_2)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto input_node = reshape_op_2->input(1);
|
|
|
|
|
return input_node;
|
|
|
|
|
auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0);
|
|
|
|
|
auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0);
|
|
|
|
|
if (input_shape == output_shape) {
|
|
|
|
|
auto input_node = reshape_op_2->input(1);
|
|
|
|
|
return input_node;
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|