|
|
|
@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
auto nodeName = node->name;
|
|
|
|
|
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
|
|
|
|
|
if (node->inputIndex.at(inputIndexIdx) == graphInIdx) {
|
|
|
|
|
auto nodeName = (*iter)->name;
|
|
|
|
|
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
|
|
|
|
|
if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) {
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
|
|
|
|
|
// insert dtype cast node between input tensor and input node
|
|
|
|
@ -108,11 +107,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
auto &graphOutIdxes = graph->outputIndex;
|
|
|
|
|
for (auto graphOutIdx : graphOutIdxes) {
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
auto nodeName = node->name;
|
|
|
|
|
auto nodeName = (*iter)->name;
|
|
|
|
|
MS_ASSERT(node != nullptr);
|
|
|
|
|
for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) {
|
|
|
|
|
if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) {
|
|
|
|
|
for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) {
|
|
|
|
|
if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
|
|
|
|
|
// insert transNode
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
|
|
|
|
@ -135,7 +133,6 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -143,8 +140,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
if (GetCNodeTType(**iter) == PrimitiveType_Shape) {
|
|
|
|
|
needInsertPost = false;
|
|
|
|
|
}
|
|
|
|
|
auto nodeName = node->name;
|
|
|
|
|
if (node->inputIndex.size() < kMinInputNum) {
|
|
|
|
|
auto nodeName = (*iter)->name;
|
|
|
|
|
if ((*iter)->inputIndex.size() < kMinInputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|