reformat need to delete

pull/1105/head
wxl 4 years ago
parent 019e39e4fc
commit 2971522f1b

@ -15,22 +15,32 @@
*/
#include "graph/passes/reshape_remove_pass.h"
#include <map>
#include <string>
#include "framework/common/util.h"
#include "framework/common/types.h"
#include "graph/passes/pass_utils.h"
#include "graph/utils/node_utils.h"
namespace ge {
namespace {
const int kReshapeDataIndex = 0;
const int kReshapeType = 0;
const int kReformatType = 1;
std::map<const char *, int> kOpTypeHash = {
{RESHAPE, kReshapeType},
{REFORMAT, kReformatType}
};
}
Status ReshapeRemovePass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
if (node->GetType() != RESHAPE && node->GetType() != REFORMAT) {
return SUCCESS;
}
switch(kOpTypeHash.find(node->GetType())) {
case kReshapeType:
bool is_shape_unknown = false;
if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) {
if (is_shape_unknown) {
@ -39,6 +49,12 @@ Status ReshapeRemovePass::Run(NodePtr &node) {
return SUCCESS;
}
}
break;
case kReformatType:
break;
default:
return SUCCESS;
}
GELOGI("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str());
return IsolateAndDeleteNode(node, {kReshapeDataIndex});

Loading…
Cancel
Save