fix bug for converter_flags

pull/5428/head
cjh9368 5 years ago
parent 2ef3216755
commit da1c32a7bf

@@ -29,12 +29,12 @@ Flags::Flags() {
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceType, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT");
-  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", "");
-  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
+          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining", "");
+  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
-  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127");
+  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
   AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
   AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
@@ -77,14 +77,24 @@ int Flags::Init(int argc, const char **argv) {
   }
   if (this->inputInferenceTypeIn == "FLOAT") {
     this->inputInferenceType = TypeId::kNumberTypeFloat;
-  } else if (this->inputInferenceTypeIn == "UINT8") {
-    this->inputInferenceType = TypeId::kNumberTypeUInt8;
   } else if (this->inputInferenceTypeIn == "INT8") {
     this->inputInferenceType = TypeId::kNumberTypeInt8;
   } else {
-    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
+    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
+      this->inputInferenceTypeIn.c_str();
     return 1;
   }
+  if (this->inferenceTypeIn == "FLOAT") {
+    this->inferenceType = TypeId::kNumberTypeFloat;
+  } else if (this->inferenceTypeIn == "INT8") {
+    this->inferenceType = TypeId::kNumberTypeInt8;
+  } else {
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
+      this->inferenceTypeIn.c_str();
+    return 1;
+  }
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
   } else if (this->fmkIn == "MS") {

@@ -63,10 +63,10 @@ class Flags : public virtual mindspore::lite::FlagParser {
   // used for quantization
   std::string quantTypeIn;
   QuantType quantType;
-  std::string inferenceType;
+  std::string inferenceTypeIn;
+  TypeId inferenceType = TypeId::kNumberTypeFloat;
   // used for parse aware trainning
   std::string inputInferenceTypeIn;
-  // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
   TypeId inputInferenceType = TypeId::kNumberTypeFloat;
   std::string stdDev;
   std::string mean;

@@ -194,6 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
       return RET_ERROR;
     }
     dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
+    dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
     quantNodeOptimizer.AddPass(dTypeTransPass);
     quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
     quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());

@@ -101,7 +101,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (inputDataDType == TypeId::kNumberTypeInt8) {
+  if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
   MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);
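
Editor's note: the guard above now keys off the output data type: when the caller asks for INT8 outputs, the quantized graph's outputs are already int8 and no cast is needed. A sketch of that decision in isolation, assuming the rest of DoModelOutputDTypeTrans (which is outside this diff) inserts Int8-to-Float casts at graph outputs:

    #include <cstdio>

    enum class TypeId { kNumberTypeFloat, kNumberTypeInt8 };  // stand-in for the real enum

    // INT8 outputs stay as produced by the quantized graph; only FLOAT
    // outputs require inserting a cast node after the output ops.
    bool NeedsOutputCast(TypeId outputDataDType) {
      return outputDataDType != TypeId::kNumberTypeInt8;
    }

    int main() {
      std::printf("INT8: %d, FLOAT: %d\n",
                  NeedsOutputCast(TypeId::kNumberTypeInt8),    // 0: early return, no cast
                  NeedsOutputCast(TypeId::kNumberTypeFloat));  // 1: cast inserted
    }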
@@ -231,5 +231,8 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
 }
 void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
+void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; }
 }  // namespace lite
 }  // namespace mindspore

@@ -38,6 +38,8 @@ class DTypeTransPass : public GraphPass {
   void SetInputDataDType(TypeId dataType);
+  void SetOutputDataDType(TypeId dataType);
+
  private:
   STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
@@ -51,6 +53,7 @@ class DTypeTransPass : public GraphPass {
  private:
   size_t id;
   TypeId inputDataDType = TypeId::kNumberTypeFloat;
+  TypeId outputDataDType = TypeId::kNumberTypeFloat;
   OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
     std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);

@@ -88,7 +88,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
   if (inputInferType == "FLOAT") {
     inArr.reset(new (std::nothrow) InputArray(mean, stdValue));
   } else {
-    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeUInt8));
+    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
   }
   mInputArray = inArr.get();
   mInputArray->InitQuantParam();

@@ -37,8 +37,8 @@ struct InputArray {
   InputArray(float mean, float stdDev,
              TypeId dataType = TypeId::kNumberTypeFloat) {
     this->dataType = dataType;
-    constexpr float qmin = 0;
-    constexpr float qmax = 255;
+    constexpr float qmin = -128;
+    constexpr float qmax = 127;
     mMin = (qmin - mean) / stdDev;
     mMax = (qmax - mean) / stdDev;
   }
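
Editor's note: with the new signed limits and the new defaults from converter_flags.cc above (mean = -0.5, stdDev = 128), the float range this maps to is almost exactly [-1, 1]. A small self-contained check using only values shown in the hunks above:

    #include <cstdio>

    int main() {
      // New signed quantization limits from the hunk above.
      constexpr float qmin = -128;
      constexpr float qmax = 127;
      // New flag defaults from converter_flags.cc above.
      constexpr float mean = -0.5f;
      constexpr float stdDev = 128.0f;
      float mMin = (qmin - mean) / stdDev;  // (-128 + 0.5) / 128 = -0.99609375
      float mMax = (qmax - mean) / stdDev;  // ( 127 + 0.5) / 128 =  0.99609375
      std::printf("mMin = %f, mMax = %f\n", mMin, mMax);
      return 0;
    }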

@@ -246,8 +246,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
     return RET_OK;
   }
-  int quantMin = narrowRange ? 1 : 0;
-  int quantMax = (1 << (unsigned int) numBits) - 1;
+  int quantMin = narrowRange ? 1 : 0 - 128;
+  int quantMax = (1 << (unsigned int) numBits) - 1 - 128;
   auto quantMinFloat = static_cast<double>(quantMin);
   auto quantMaxFloat = static_cast<double>(quantMax);
   double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
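
Editor's note: with numBits = 8 and narrowRange = false, the new bounds are quantMin = -128 and quantMax = 127, i.e. the signed int8 range instead of the old uint8 [0, 255]. A hedged sketch of the scale/zero-point arithmetic this feeds into; the scale line is taken from the hunk, while the zero-point formula is the standard affine one and is an assumption, since the rest of CalQuantizationParams is outside this diff:

    #include <cmath>
    #include <cstdio>

    int main() {
      const unsigned int numBits = 8;
      const bool narrowRange = false;
      int quantMin = narrowRange ? 1 : 0 - 128;     // -128 after this commit
      int quantMax = (1 << numBits) - 1 - 128;      //  127 after this commit
      // Example float range from the InputArray sketch above.
      double mMin = -0.99609375, mMax = 0.99609375;
      double scale = (mMax - mMin) / (quantMax - quantMin);  // as in the hunk: 0.0078125 here
      // Assumption: standard affine zero-point; not shown in this diff.
      double zeroPoint = std::round(quantMin - mMin / scale);
      std::printf("scale = %g, zeroPoint = %g\n", scale, zeroPoint);
      return 0;
    }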
