diff --git a/mindspore/lite/examples/train_lenet/Makefile b/mindspore/lite/examples/train_lenet/Makefile index 202cc3669a..a120fac317 100644 --- a/mindspore/lite/examples/train_lenet/Makefile +++ b/mindspore/lite/examples/train_lenet/Makefile @@ -1,6 +1,6 @@ BASE_DIR=$(realpath ../../../../) APP:=bin/net_runner -LMSLIB:=-lmindspore-lite +LMSLIB:=-lmindspore-lite-train LMDLIB:=-lminddata-lite MSDIR:=$(realpath package-$(TARGET)/lib) ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc index fa726639b4..1e96338175 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.cc +++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc @@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() { context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; context.thread_num_ = 2; - model_ = mindspore::lite::Model::Import(ms_file_); - if (model_ == nullptr) { - MS_LOG(ERROR) << "import model failed"; - return nullptr; - } + model_ = mindspore::lite::Model::Import(ms_file_.c_str()); + MS_ASSERT(nullptr != model_); session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); MS_ASSERT(nullptr != session_); @@ -169,7 +166,7 @@ int NetRunner::TrainLoop() { mindspore::lite::LossMonitor lm(100); mindspore::lite::ClassificationTrainAccuracyMonitor am(1); - mindspore::lite::CkptSaver cs(1000, std::string("lenet")); + mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_); Rescaler rescale(255.0); loop_->Train(epochs_, train_ds_.get(), std::vector{&rescale, &lm, &cs, &am, &step_lr_sched}); @@ -187,7 +184,7 @@ int NetRunner::Main() { if (epochs_ > 0) { auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; - Model::Export(model_, trained_fn); + mindspore::lite::Model::Export(model_, trained_fn.c_str()); } return 0; } diff --git a/mindspore/lite/include/train/ckpt_saver.h b/mindspore/lite/include/train/ckpt_saver.h index 5582167b20..c54c8b56ed 100644 --- a/mindspore/lite/include/train/ckpt_saver.h +++ b/mindspore/lite/include/train/ckpt_saver.h @@ -29,14 +29,14 @@ namespace lite { class CkptSaver : public session::TrainLoopCallBack { public: - CkptSaver(int save_every_n, const std::string &filename_prefix) - : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {} + CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model) + : save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {} int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms"; remove(cpkt_fn.c_str()); - cb_data.session_->SaveToFile(cpkt_fn); + Model::Export(model_, cpkt_fn.c_str()); } return session::RET_CONTINUE; } @@ -44,6 +44,7 @@ class CkptSaver : public session::TrainLoopCallBack { private: int save_every_n_; std::string filename_prefix_; + mindspore::lite::Model *model_ = nullptr; }; } // namespace lite