|
|
|
@ -13,37 +13,41 @@
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "tools/optimizer/graph/identity_remove_pass.h"
|
|
|
|
|
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "mindspore/lite/include/errorcode.h"
|
|
|
|
|
#include "src/ops/primitive_c.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::opt {
|
|
|
|
|
int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t InputDoubleNum = 2;
|
|
|
|
|
constexpr size_t InputTripleNum = 3;
|
|
|
|
|
constexpr auto kNameLoad = "Load";
|
|
|
|
|
constexpr auto kNameUpdateState = "UpdateState";
|
|
|
|
|
} // namespace
|
|
|
|
|
int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
|
|
|
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
|
|
|
|
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
|
|
|
|
return lite::RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
auto type = opt::GetCNodeType(anf_node);
|
|
|
|
|
if (type != schema::PrimitiveType_Identity) {
|
|
|
|
|
MS_LOG(DEBUG) << "anf node is not a identity node.";
|
|
|
|
|
return lite::RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
auto identity_cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
if (identity_cnode->inputs().size() != lite::kDoubleNum) {
|
|
|
|
|
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
|
|
|
|
|
remove_cnode_.insert(anf_node);
|
|
|
|
|
return lite::RET_NO_CHANGE;
|
|
|
|
|
} else {
|
|
|
|
|
bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1));
|
|
|
|
|
if (!replace_succ) {
|
|
|
|
|
MS_LOG(ERROR) << "replace identity failed.";
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
if (type == schema::PrimitiveType_Identity) {
|
|
|
|
|
if (cnode->size() != InputDoubleNum) {
|
|
|
|
|
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
|
|
|
|
|
remove_cnode_.insert(anf_node);
|
|
|
|
|
return lite::RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bool replace_succ = manager->Replace(anf_node, cnode->input(1));
|
|
|
|
|
if (!replace_succ) {
|
|
|
|
|
MS_LOG(ERROR) << "replace redundant op failed.";
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
|
|
|
|
int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
|
|
|
|
if (!utils::isa<CNodePtr>(anf_node)) {
|
|
|
|
|
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
|
|
|
|
return lite::RET_NO_CHANGE;
|
|
|
|
@ -53,7 +57,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
|
|
|
|
|
return lite::RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
if (cnode->inputs().size() != 3) {
|
|
|
|
|
if (cnode->inputs().size() != InputTripleNum) {
|
|
|
|
|
MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
@ -81,7 +85,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const
|
|
|
|
|
return lite::RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_ASSERT(func_graph != nullptr);
|
|
|
|
|
auto manager = func_graph->manager();
|
|
|
|
|
MS_ASSERT(manager != nullptr);
|
|
|
|
@ -93,10 +97,22 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
}
|
|
|
|
|
auto type = opt::GetCNodeType(node);
|
|
|
|
|
if (type == schema::PrimitiveType_Identity) {
|
|
|
|
|
status = ReplaceIdentity(node, manager);
|
|
|
|
|
} else if (type == schema::PrimitiveType_TupleGetItem) {
|
|
|
|
|
status = ReplaceOp(node, manager);
|
|
|
|
|
}
|
|
|
|
|
if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameLoad))) {
|
|
|
|
|
status = ReplaceOp(node, manager);
|
|
|
|
|
}
|
|
|
|
|
if (CheckPrimitiveType(node, std::make_shared<Primitive>(kNameUpdateState))) {
|
|
|
|
|
status = ReplaceOp(node, manager);
|
|
|
|
|
}
|
|
|
|
|
if (type == schema::PrimitiveType_Depend ||
|
|
|
|
|
type == schema::PrimitiveType_ControlDepend) { // ControlDepend delete next version.
|
|
|
|
|
status = ReplaceOp(node, manager);
|
|
|
|
|
}
|
|
|
|
|
if (type == schema::PrimitiveType_TupleGetItem) {
|
|
|
|
|
status = ReplaceTupleGetItem(node, manager);
|
|
|
|
|
} else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
|
|
|
|
|
}
|
|
|
|
|
if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
|
|
|
|
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
|
|
|
|
|
if (sub_func_graph == nullptr) {
|
|
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|