|
|
|
@ -30,9 +30,8 @@ Flags::Flags() {
|
|
|
|
|
AddFlag(&Flags::weightFile, "weightFile",
|
|
|
|
|
"Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
|
|
|
|
|
AddFlag(&Flags::inferenceTypeIn, "inferenceType",
|
|
|
|
|
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT");
|
|
|
|
|
"Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT");
|
|
|
|
|
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
|
|
|
|
|
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
|
|
|
|
|
AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
|
|
|
|
|
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
|
|
|
|
|
AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
|
|
|
|
@ -41,8 +40,10 @@ Flags::Flags() {
|
|
|
|
|
"16");
|
|
|
|
|
AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
|
|
|
|
|
AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
|
|
|
|
|
AddFlag(&Flags::trainModelIn, "trainModel", "whether the model is going to be trained on device."
|
|
|
|
|
" true | false", "false");
|
|
|
|
|
AddFlag(&Flags::trainModelIn, "trainModel",
|
|
|
|
|
"whether the model is going to be trained on device."
|
|
|
|
|
" true | false",
|
|
|
|
|
"false");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Flags::Init(int argc, const char **argv) {
|
|
|
|
@ -80,22 +81,15 @@ int Flags::Init(int argc, const char **argv) {
|
|
|
|
|
std::cerr << "INPUT MISSING: fmk is necessary";
|
|
|
|
|
return RET_INPUT_PARAM_LACK;
|
|
|
|
|
}
|
|
|
|
|
if (this->inputInferenceTypeIn == "FLOAT") {
|
|
|
|
|
this->inputInferenceType = TypeId::kNumberTypeFloat;
|
|
|
|
|
} else if (this->inputInferenceTypeIn == "INT8") {
|
|
|
|
|
this->inputInferenceType = TypeId::kNumberTypeInt8;
|
|
|
|
|
} else {
|
|
|
|
|
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
|
|
|
|
|
this->inputInferenceTypeIn.c_str();
|
|
|
|
|
return RET_INPUT_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (this->inferenceTypeIn == "FLOAT") {
|
|
|
|
|
this->inferenceType = TypeId::kNumberTypeFloat;
|
|
|
|
|
} else if (this->inferenceTypeIn == "INT8") {
|
|
|
|
|
this->inferenceType = TypeId::kNumberTypeInt8;
|
|
|
|
|
} else if (this->inferenceTypeIn == "SAME") {
|
|
|
|
|
this->inferenceType = TypeId::kTypeUnknown;
|
|
|
|
|
} else {
|
|
|
|
|
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
|
|
|
|
|
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME",
|
|
|
|
|
this->inferenceTypeIn.c_str();
|
|
|
|
|
return RET_INPUT_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
@ -130,7 +124,6 @@ int Flags::Init(int argc, const char **argv) {
|
|
|
|
|
return RET_INPUT_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (this->trainModelIn == "true") {
|
|
|
|
|
this->trainModel = true;
|
|
|
|
|
} else if (this->trainModelIn == "false") {
|
|
|
|
|