!14751 fix train_lenet demo

From: @HilbertDavid
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
pull/14751/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 13e41d3103

@ -1,6 +1,6 @@
BASE_DIR=$(realpath ../../../../) BASE_DIR=$(realpath ../../../../)
APP:=bin/net_runner APP:=bin/net_runner
LMSLIB:=-lmindspore-lite LMSLIB:=-lmindspore-lite-train
LMDLIB:=-lminddata-lite LMDLIB:=-lminddata-lite
MSDIR:=$(realpath package-$(TARGET)/lib) MSDIR:=$(realpath package-$(TARGET)/lib)
ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")

@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2; context.thread_num_ = 2;
model_ = mindspore::lite::Model::Import(ms_file_); model_ = mindspore::lite::Model::Import(ms_file_.c_str());
if (model_ == nullptr) { MS_ASSERT(nullptr != model_);
MS_LOG(ERROR) << "import model failed";
return nullptr;
}
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);
MS_ASSERT(nullptr != session_); MS_ASSERT(nullptr != session_);
@ -169,7 +166,7 @@ int NetRunner::TrainLoop() {
mindspore::lite::LossMonitor lm(100); mindspore::lite::LossMonitor lm(100);
mindspore::lite::ClassificationTrainAccuracyMonitor am(1); mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
mindspore::lite::CkptSaver cs(1000, std::string("lenet")); mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
Rescaler rescale(255.0); Rescaler rescale(255.0);
loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched}); loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
@ -187,7 +184,7 @@ int NetRunner::Main() {
if (epochs_ > 0) { if (epochs_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
Model::Export(model_, trained_fn); mindspore::lite::Model::Export(model_, trained_fn.c_str());
} }
return 0; return 0;
} }

@ -29,14 +29,14 @@ namespace lite {
class CkptSaver : public session::TrainLoopCallBack { class CkptSaver : public session::TrainLoopCallBack {
public: public:
CkptSaver(int save_every_n, const std::string &filename_prefix) CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix) {} : save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms"; auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str()); remove(cpkt_fn.c_str());
cb_data.session_->SaveToFile(cpkt_fn); Model::Export(model_, cpkt_fn.c_str());
} }
return session::RET_CONTINUE; return session::RET_CONTINUE;
} }
@ -44,6 +44,7 @@ class CkptSaver : public session::TrainLoopCallBack {
private: private:
int save_every_n_; int save_every_n_;
std::string filename_prefix_; std::string filename_prefix_;
mindspore::lite::Model *model_ = nullptr;
}; };
} // namespace lite } // namespace lite

Loading…
Cancel
Save