|
|
|
|
@ -66,6 +66,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
|
|
|
|
|
AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
|
|
|
|
|
AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
|
|
|
|
|
AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
|
|
|
|
|
AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~NetTrainFlags() override = default;
|
|
|
|
|
@ -82,6 +83,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
|
|
|
|
|
DataType in_data_type_;
|
|
|
|
|
std::string in_data_type_in_ = "bin";
|
|
|
|
|
int cpu_bind_mode_ = 1;
|
|
|
|
|
bool enable_fp16_ = false;
|
|
|
|
|
// MarkPerformance
|
|
|
|
|
int num_threads_ = 1;
|
|
|
|
|
int warm_up_loop_count_ = 0;
|
|
|
|
|
|