!13302 Fixed a bug in benchmark_train

From: @louisncu
Reviewed-by: @zhang_xue_tong,@HilbertDavid
Signed-off-by: @zhang_xue_tong
pull/13302/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d831aba239

@ -21,7 +21,8 @@ if(PLATFORM_ARM64)
if(ENABLE_FP16)
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
if(SUPPORT_TRAIN)
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
set(FP16_KERNEL_SRC ${FP16_KERNEL_SRC} ${FP16_KERNEL_TRAIN_SRC})
endif()
add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC})
add_dependencies(cpu_fp16_kernel_mid fbs_src)

@ -385,7 +385,7 @@ int NetTrain::RunNetTrain() {
} else {
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
}
context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
layer_checksum_ = flags_->layer_checksum_;
context->thread_num_ = flags_->num_threads_;
session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get());
@ -545,6 +545,7 @@ int NetTrain::Init() {
MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
if (this->flags_->epochs_ < 0) {
MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0";

@ -66,6 +66,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
}
~NetTrainFlags() override = default;
@ -82,6 +83,7 @@ class MS_API NetTrainFlags : public virtual FlagParser {
DataType in_data_type_;
std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 1;
bool enable_fp16_ = false;
// MarkPerformance
int num_threads_ = 1;
int warm_up_loop_count_ = 0;

Loading…
Cancel
Save