remove inputInferenceType Param

pull/6314/head
cjh9368 4 years ago committed by cjh9368
parent 88dfcda3e6
commit fb217f3c81

@@ -34,7 +34,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath
   std::ofstream output(outputPath + ".ms", std::ofstream::binary);
   if (!output.is_open()) {
-    MS_LOG(ERROR) << "ofstream open failed";
+    MS_LOG(ERROR) << "Output file path is error";
     return RET_ERROR;
   }
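
As an aside, a release-grade version of this check usually logs the failing path itself. A minimal standalone sketch (plain C++, no MindSpore dependencies; SaveBuffer and the message text are illustrative, not from the commit):

    #include <fstream>
    #include <iostream>
    #include <string>

    // Writes buf to outputPath + ".ms"; returns false on failure.
    bool SaveBuffer(const std::string &outputPath, const std::string &buf) {
      std::ofstream output(outputPath + ".ms", std::ofstream::binary);
      if (!output.is_open()) {
        // Logging the path makes a "file path is error" report actionable.
        std::cerr << "Open output file failed: " << outputPath << ".ms" << std::endl;
        return false;
      }
      output.write(buf.data(), static_cast<std::streamsize>(buf.size()));
      return output.good();
    }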

@@ -30,9 +30,8 @@ Flags::Flags() {
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
   AddFlag(&Flags::inferenceTypeIn, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT");
+          "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT");
   AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
-  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
   AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
@@ -41,8 +40,10 @@ Flags::Flags() {
           "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
   AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
-  AddFlag(&Flags::trainModelIn, "trainModel", "whether the model is going to be trained on device."
-          " true | false", "false");
+  AddFlag(&Flags::trainModelIn, "trainModel",
+          "whether the model is going to be trained on device."
+          " true | false",
+          "false");
 }

 int Flags::Init(int argc, const char **argv) {
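
With inputInferenceType gone, a caller exercises only the remaining flags. A hypothetical driver, assuming the converter's Flags header is on the include path (the header path, "--name=value" syntax, and the concrete flag values are assumptions; the flag names come from the AddFlag calls above):

    #include "tools/converter/converter_flags.h"  // assumed header location

    int main() {
      // Simulated command line for Flags::Init; values are illustrative.
      const char *argv[] = {
          "converter_lite",
          "--fmk=TFLITE",               // fmk is mandatory (checked in Init)
          "--quantType=AwareTraining",
          "--inferenceType=SAME",       // new value: keep the model's own dtype
          "--stdDev=128",
          "--mean=-0.5",
      };
      mindspore::lite::converter::Flags flags;
      return flags.Init(static_cast<int>(sizeof(argv) / sizeof(argv[0])), argv);
    }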
@@ -80,22 +81,15 @@ int Flags::Init(int argc, const char **argv) {
     std::cerr << "INPUT MISSING: fmk is necessary";
     return RET_INPUT_PARAM_LACK;
   }
-  if (this->inputInferenceTypeIn == "FLOAT") {
-    this->inputInferenceType = TypeId::kNumberTypeFloat;
-  } else if (this->inputInferenceTypeIn == "INT8") {
-    this->inputInferenceType = TypeId::kNumberTypeInt8;
-  } else {
-    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
-      this->inputInferenceTypeIn.c_str();
-    return RET_INPUT_PARAM_INVALID;
-  }
   if (this->inferenceTypeIn == "FLOAT") {
     this->inferenceType = TypeId::kNumberTypeFloat;
   } else if (this->inferenceTypeIn == "INT8") {
     this->inferenceType = TypeId::kNumberTypeInt8;
+  } else if (this->inferenceTypeIn == "SAME") {
+    this->inferenceType = TypeId::kTypeUnknown;
   } else {
-    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME",
       this->inferenceTypeIn.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
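
The SAME branch is the functional core of the commit: one flag value now stands for "leave the data type alone". A self-contained sketch of the mapping (MockTypeId stands in for MindSpore's TypeId so the snippet compiles on its own):

    #include <iostream>
    #include <string>

    enum class MockTypeId { kTypeUnknown, kNumberTypeFloat, kNumberTypeInt8 };

    // Mirrors the branch added above; returns false on an unsupported value.
    bool ParseInferenceType(const std::string &in, MockTypeId *out) {
      if (in == "FLOAT") {
        *out = MockTypeId::kNumberTypeFloat;
      } else if (in == "INT8") {
        *out = MockTypeId::kNumberTypeInt8;
      } else if (in == "SAME") {
        // SAME is encoded as kTypeUnknown, which downstream passes presumably
        // read as "insert no data-type transform".
        *out = MockTypeId::kTypeUnknown;
      } else {
        std::cerr << "inferenceType is invalid: " << in
                  << ", supported: FLOAT | INT8 | SAME" << std::endl;
        return false;
      }
      return true;
    }

Incidentally, the original cerr lines mix a printf-style %s with the comma operator, so the offending value is never actually printed; the sketch streams it directly.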
@@ -130,7 +124,6 @@ int Flags::Init(int argc, const char **argv) {
     return RET_INPUT_PARAM_INVALID;
   }
   if (this->trainModelIn == "true") {
     this->trainModel = true;
   } else if (this->trainModelIn == "false") {

@@ -56,10 +56,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
   std::string quantTypeIn;
   QuantType quantType;
   std::string inferenceTypeIn;
+  TypeId inferenceType = TypeId::kNumberTypeFloat;
   // used for parse aware trainning
-  std::string inputInferenceTypeIn;
-  TypeId inputInferenceType = TypeId::kNumberTypeFloat;
-  TypeId inferenceType = TypeId::kNumberTypeFloat;
   std::string stdDev;
   std::string mean;
   // used for post-trainning-weight

@@ -51,7 +51,7 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
     case QuantType::QuantType_AwareTraining: {
       MS_LOG(INFO) << "create AwareTrainingQuantizer!";
       fbQuantizer =
-        std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
+        std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inferenceType, flags->stdDev, flags->mean);
       break;
     }
     default:
@@ -194,11 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   if (ctx.quantType == QuantType_AwareTraining) {
     Optimizer quantNodeOptimizer;
     auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
-    if (dTypeTransPass == nullptr) {
-      MS_LOG(ERROR) << "new dTypeTransPass failed";
-      return RET_MEMORY_FAILED;
-    }
-    dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
+    dTypeTransPass->SetInputDataDType(ctx.inferenceType);
     dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
     quantNodeOptimizer.AddPass(dTypeTransPass);
     quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
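
Worth flagging: this hunk also deletes the nullptr check after new (std::nothrow), so a failed allocation would now be dereferenced. A hedged sketch of the wiring with the guard kept (the Mock types are stand-ins; only SetInputDataDType, SetOutputDataDType, and the single inferenceType source come from the diff):

    #include <new>

    enum class MockTypeId { kTypeUnknown, kNumberTypeFloat, kNumberTypeInt8 };

    class MockDTypeTransPass {
     public:
      void SetInputDataDType(MockTypeId t) { input_ = t; }
      void SetOutputDataDType(MockTypeId t) { output_ = t; }

     private:
      MockTypeId input_ = MockTypeId::kNumberTypeFloat;
      MockTypeId output_ = MockTypeId::kNumberTypeFloat;
    };

    int ConfigureDTypePass(MockTypeId inferenceType) {
      auto *pass = new (std::nothrow) MockDTypeTransPass();
      if (pass == nullptr) {  // guard kept here even though the hunk removes it
        return -1;
      }
      pass->SetInputDataDType(inferenceType);   // one value now drives both ends
      pass->SetOutputDataDType(inferenceType);
      delete pass;  // the real optimizer takes ownership via AddPass instead
      return 0;
    }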

@@ -71,7 +71,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor
   return RET_OK;
 }

-AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues,
+AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const string &stdValues,
                                const string &meanValues)
     : FbQuantizer(graph) {
   MS_ASSERT(graph != nullptr);
@@ -80,7 +80,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInf
   sz = 0;
   const float mean = std::stof(meanValues, &sz);
   std::unique_ptr<InputArray> inArr = nullptr;
-  if (inputInferType == "FLOAT") {
+  if (inferType == kNumberTypeFloat) {
     inArr.reset(new (std::nothrow) InputArray(mean, stdValue));
   } else {
     inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
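
Comparing a TypeId beats matching the string "FLOAT", but note that the else branch now also catches SAME (kTypeUnknown), which lands in the int8 InputArray; whether that is intended is not visible in the diff alone. A compact sketch of the dispatch (MockInputArray and the helper are illustrative):

    #include <memory>
    #include <new>

    enum class MockTypeId { kTypeUnknown, kNumberTypeFloat, kNumberTypeInt8 };

    struct MockInputArray {
      MockInputArray(float mean, float stdValue,
                     MockTypeId dtype = MockTypeId::kNumberTypeFloat)
          : mean_(mean), std_(stdValue), dtype_(dtype) {}
      float mean_;
      float std_;
      MockTypeId dtype_;
    };

    std::unique_ptr<MockInputArray> MakeInputArray(MockTypeId inferType,
                                                   float mean, float stdValue) {
      if (inferType == MockTypeId::kNumberTypeFloat) {
        return std::unique_ptr<MockInputArray>(
            new (std::nothrow) MockInputArray(mean, stdValue));
      }
      // INT8 and SAME (kTypeUnknown) both reach here, mirroring the else above.
      return std::unique_ptr<MockInputArray>(new (std::nothrow) MockInputArray(
          mean, stdValue, MockTypeId::kNumberTypeInt8));
    }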

@@ -48,7 +48,7 @@ struct InputArray {
 class AwareQuantizer : public FbQuantizer {
  public:
-  AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues,
+  AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const std::string &stdValues,
                  const std::string &meanValues);

   ~AwareQuantizer() { delete (mInputArray); }

@@ -116,11 +116,11 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
     return status;
   }
   if (inputParamDone != node.inputIndex.size()) {
-    MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name;
+    MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name;
     return RET_ERROR;
   }
   if (outputParamDone != node.outputIndex.size()) {
-    MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name;
+    MS_LOG(WARNING) << "Can not determine outputTensor quantParam, node " << node.name;
     return RET_ERROR;
   }
   return RET_OK;
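
Downgrading to WARNING while still returning RET_ERROR suggests the caller treats a missing quant param on one node as recoverable rather than fatal. A hedged sketch of such a caller (the loop and all names are invented for illustration):

    #include <iostream>
    #include <string>
    #include <vector>

    constexpr int RET_OK = 0;
    constexpr int RET_ERROR = 1;

    // Stand-in for a per-node quant-param calculation that may fail.
    int CalcNodeQuantParam(const std::string &node) {
      return node.empty() ? RET_ERROR : RET_OK;
    }

    int CalcAll(const std::vector<std::string> &nodes) {
      int failed = 0;
      for (const auto &node : nodes) {
        if (CalcNodeQuantParam(node) != RET_OK) {
          // A warning, not a hard stop: continue with the remaining nodes.
          std::cerr << "quantParam undetermined for node " << node << std::endl;
          ++failed;
        }
      }
      return failed == 0 ? RET_OK : RET_ERROR;
    }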
@@ -138,7 +138,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
   MS_ASSERT(outTensor != nullptr);
   auto outputQuantParam = GetTensorQuantParam(outTensor);
   MS_ASSERT(outputQuantParam != nullptr);
-  if (!outputQuantParam->inited) {
+  if (outputQuantParam == nullptr || !outputQuantParam->inited) {
     MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name;
     return RET_ERROR;
   }
@@ -204,8 +204,7 @@ class CalcConcat : public QuantParamCalcer {
       auto &inTensor = graph->allTensors.at(i);
       MS_ASSERT(inTensor != nullptr);
       auto inQuantParam = GetTensorQuantParam(inTensor);
-      MS_ASSERT(inQuantParam != nullptr);
-      if (!inQuantParam->inited) {
+      if (inQuantParam == nullptr || !inQuantParam->inited) {
         return RET_ERROR;
       }
       if (numBits == -1) {
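
Both quant-param hunks converge on the same pattern: MS_ASSERT, like assert, is typically compiled out in release builds, so a possibly-null pointer needs a real runtime guard. A minimal sketch (QuantParamStub and the helper are illustrative):

    #include <cassert>

    struct QuantParamStub {
      bool inited = false;
    };

    bool HasUsableQuantParam(const QuantParamStub *p) {
      // assert(p != nullptr);           // debug-only: vanishes under -DNDEBUG
      if (p == nullptr || !p->inited) {  // release-safe guard, as in the hunks
        return false;
      }
      return true;
    }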
