fix train_lenet demo

fix code-style
pull/14751/head
lz 4 years ago
parent 5d96d0f7e9
commit 88106f3f59

@ -1,6 +1,6 @@
BASE_DIR=$(realpath ../../../../)
APP:=bin/net_runner
LMSLIB:=-lmindspore-lite
LMSLIB:=-lmindspore-lite-train
LMDLIB:=-lminddata-lite
MSDIR:=$(realpath package-$(TARGET)/lib)
ifneq ("$(wildcard $(MSDIR)/libhiai.so)","")

@ -99,11 +99,8 @@ void NetRunner::InitAndFigureInputs() {
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
model_ = mindspore::lite::Model::Import(ms_file_);
if (model_ == nullptr) {
MS_LOG(ERROR) << "import model failed";
return nullptr;
}
model_ = mindspore::lite::Model::Import(ms_file_.c_str());
MS_ASSERT(nullptr != model_);
session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true);
MS_ASSERT(nullptr != session_);
@ -169,7 +166,7 @@ int NetRunner::TrainLoop() {
mindspore::lite::LossMonitor lm(100);
mindspore::lite::ClassificationTrainAccuracyMonitor am(1);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
mindspore::lite::CkptSaver cs(1000, std::string("lenet"), model_);
Rescaler rescale(255.0);
loop_->Train(epochs_, train_ds_.get(), std::vector<TrainLoopCallBack *>{&rescale, &lm, &cs, &am, &step_lr_sched});
@ -187,7 +184,7 @@ int NetRunner::Main() {
if (epochs_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
Model::Export(model_, trained_fn);
mindspore::lite::Model::Export(model_, trained_fn.c_str());
}
return 0;
}

@ -29,14 +29,14 @@ namespace lite {
class CkptSaver : public session::TrainLoopCallBack {
public:
CkptSaver(int save_every_n, const std::string &filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}
CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
cb_data.session_->SaveToFile(cpkt_fn);
Model::Export(model_, cpkt_fn.c_str());
}
return session::RET_CONTINUE;
}
@ -44,6 +44,7 @@ class CkptSaver : public session::TrainLoopCallBack {
private:
int save_every_n_;
std::string filename_prefix_;
mindspore::lite::Model *model_ = nullptr;
};
} // namespace lite

Loading…
Cancel
Save