|
|
|
@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() {
|
|
|
|
|
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
|
|
|
|
|
context.thread_num_ = 2;
|
|
|
|
|
|
|
|
|
|
model_ = mindspore::lite::Model::Import(ms_file_);
|
|
|
|
|
if (model_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "import model failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
model_ = mindspore::lite::Model::Import(ms_file_.c_str());
|
|
|
|
|
MS_ASSERT(nullptr != model_);
|
|
|
|
|
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);
|
|
|
|
|
|
|
|
|
|
MS_ASSERT(nullptr != session_);
|
|
|
|
@ -169,7 +166,7 @@ int NetRunner::TrainLoop() {
|
|
|
|
|
|
|
|
|
|
mindspore::lite::LossMonitor lm(100);
|
|
|
|
|
mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
|
|
|
|
|
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
|
|
|
|
|
mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
|
|
|
|
|
Rescaler rescale(255.0);
|
|
|
|
|
|
|
|
|
|
loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
|
|
|
|
@ -187,7 +184,7 @@ int NetRunner::Main() {
|
|
|
|
|
|
|
|
|
|
if (epochs_ > 0) {
|
|
|
|
|
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
|
|
|
|
|
Model::Export(model_, trained_fn);
|
|
|
|
|
mindspore::lite::Model::Export(model_, trained_fn.c_str());
|
|
|
|
|
}
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|