From 0517ffbff3b070195b913cf79894f262aef4e119 Mon Sep 17 00:00:00 2001 From: lz Date: Thu, 1 Apr 2021 18:55:11 +0800 Subject: [PATCH] delete trainModel class formating add include fix arm32 delete useness code use char * in external APIs formatting change virtual export to static --- .../examples/train_lenet/src/net_runner.cc | 11 +- .../examples/train_lenet/src/net_runner.h | 2 + mindspore/lite/include/model.h | 14 ++- mindspore/lite/include/train/train_session.h | 30 +----- mindspore/lite/src/CMakeLists.txt | 1 - mindspore/lite/src/lite_model.cc | 101 ++++++++++++++++++ mindspore/lite/src/train/train_model.cc | 91 ---------------- mindspore/lite/src/train/train_model.h | 57 ---------- mindspore/lite/src/train/train_session.cc | 84 +-------------- mindspore/lite/src/train/train_session.h | 8 +- mindspore/lite/src/train/transfer_session.cc | 2 +- mindspore/lite/src/train/transfer_session.h | 1 - mindspore/lite/test/CMakeLists.txt | 1 - .../kernel/arm/fp32_grad/network_test.cc | 18 +++- .../lite/tools/benchmark_train/net_train.cc | 20 +++- 15 files changed, 158 insertions(+), 283 deletions(-) delete mode 100644 mindspore/lite/src/train/train_model.cc delete mode 100644 mindspore/lite/src/train/train_model.h diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc index a44551aeaf..fa726639b4 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.cc +++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc @@ -99,7 +99,13 @@ void NetRunner::InitAndFigureInputs() { context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; context.thread_num_ = 2; - session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context); + model_ = mindspore::lite::Model::Import(ms_file_); + if (model_ == nullptr) { + MS_LOG(ERROR) << "import model failed"; + return nullptr; + } + session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); + MS_ASSERT(nullptr != session_); loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session_); @@ -154,7 +160,6 @@ int NetRunner::InitDB() { std::cout << "No relevant data was found in " << data_dir_ << std::endl; MS_ASSERT(train_ds_->GetDatasetSize() != 0); } - return 0; } @@ -182,7 +187,7 @@ int NetRunner::Main() { if (epochs_ > 0) { auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; - session_->SaveToFile(trained_fn); + Model::Export(model_, trained_fn); } return 0; } diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.h b/mindspore/lite/examples/train_lenet/src/net_runner.h index 824663047f..eb9845255c 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.h +++ b/mindspore/lite/examples/train_lenet/src/net_runner.h @@ -27,6 +27,7 @@ #include "include/train/accuracy_metrics.h" #include "include/ms_tensor.h" #include "include/datasets.h" +#include "include/model.h" using mindspore::dataset::Dataset; using mindspore::lite::AccuracyMetrics; @@ -36,6 +37,7 @@ class NetRunner { int Main(); bool ReadArgs(int argc, char *argv[]); ~NetRunner(); + mindspore::lite::Model *model_ = nullptr; private: void Usage(); diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h index 22d2ffa30b..48e1e767ce 100644 --- a/mindspore/lite/include/model.h +++ b/mindspore/lite/include/model.h @@ -45,13 +45,17 @@ struct MS_API Model { SubGraphPtrVector sub_graphs_; /// \brief Static method to create a Model pointer. - /// - /// \param[in] model_buf Define the buffer read from a model file. - /// \param[in] size Define bytes number of model buffer. - /// - /// \return Pointer of MindSpore Lite Model. static Model *Import(const char *model_buf, size_t size); + /// \brief Static method to create a Model pointer. + static Model *Import(const char *filename); + + /// \brief method to export model to file. + static int Export(Model *model, const char *filename); + + /// \brief method to export model to buffer. + static int Export(Model *model, char *buf, size_t *size); + /// \brief Free meta graph temporary buffer virtual void Free() = 0; diff --git a/mindspore/lite/include/train/train_session.h b/mindspore/lite/include/train/train_session.h index 97a80b37d5..2c80b65357 100644 --- a/mindspore/lite/include/train/train_session.h +++ b/mindspore/lite/include/train/train_session.h @@ -32,23 +32,12 @@ class TrainSession : public session::LiteSession { /// \brief Static method to create a TrainSession object /// - /// \param[in] model_buf A buffer that was read from a MS model file - /// \param[in] size Length of the buffer + /// \param[in] model A buffer that was read from a MS model file /// \param[in] context Defines the context of the session to be created /// \param[in] train_mode training mode to initialize Session with /// /// \return Pointer of MindSpore Lite TrainSession - static TrainSession *CreateSession(const char *model_buf, size_t size, lite::Context *context, - bool train_mode = false); - - /// \brief Static method to create a TrainSession object - /// - /// \param[in] filename Filename to read flatbuffer from - /// \param[in] context Defines the context of the session to be created - /// \param[in] train_mode training mode to initialize Session with - /// - /// \return Pointer of MindSpore Lite TrainSession - static TrainSession *CreateSession(const std::string &filename, lite::Context *context, bool train_mode = false); + static TrainSession *CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode = false); /// \brief Static method to create a transfer lernning support TrainSession object /// @@ -75,21 +64,6 @@ class TrainSession : public session::LiteSession { static TrainSession *CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head, lite::Context *context, bool train_mode = false); - /// \brief Export the trained model into a buffer - /// - /// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated - /// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer - /// - /// \return pointer to the export buffer - virtual void *ExportToBuf(char *buf, size_t *len) const = 0; - - /// \brief Save the trained model into a flatbuffer file - /// - /// \param[in] filename Filename to save flatbuffer to - /// - /// \return 0 on success or -1 in case of error - virtual int SaveToFile(const std::string &filename) const = 0; - /// \brief Set model to train mode /// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h virtual int Train() = 0; diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 0a2d0f0143..954ce132b4 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -111,7 +111,6 @@ if(SUPPORT_TRAIN) ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc diff --git a/mindspore/lite/src/lite_model.cc b/mindspore/lite/src/lite_model.cc index ea4195af8f..87e656126d 100644 --- a/mindspore/lite/src/lite_model.cc +++ b/mindspore/lite/src/lite_model.cc @@ -15,9 +15,13 @@ */ #include "src/lite_model.h" +#include +#include +#include #include #include #include +#include #include "src/common/prim_util.h" #ifdef ENABLE_V0 #include "src/ops/compat/compat_register.h" @@ -343,5 +347,102 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { return model; } +std::unique_ptr ReadFileToBuf(const std::string &filename, size_t *size) { + std::ifstream ifs(filename); + if (!ifs.good()) { + MS_LOG(ERROR) << "File: " << filename << " does not exist"; + return std::unique_ptr(nullptr); + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "File: " << filename << " open failed"; + return std::unique_ptr(nullptr); + } + + ifs.seekg(0, std::ios::end); + auto tellg_ret = ifs.tellg(); + if (tellg_ret <= 0) { + MS_LOG(ERROR) << "Could not read file " << filename; + return std::unique_ptr(nullptr); + } + size_t fsize = static_cast(tellg_ret); + + std::unique_ptr buf(new (std::nothrow) char[fsize]); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << filename; + ifs.close(); + return std::unique_ptr(nullptr); + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf.get(), fsize); + if (!ifs) { + MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename; + ifs.close(); + return std::unique_ptr(nullptr); + } + ifs.close(); + if (size != nullptr) { + *size = fsize; + } + return buf; +} + Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } + +Model *Model::Import(const char *filename) { + size_t size = -1; + auto buf = ReadFileToBuf(filename, &size); + if (buf == nullptr) { + return nullptr; + } + return ImportFromBuffer(buf.get(), size, false); +} + +int Model::Export(Model *model, char *buffer, size_t *len) { + if (len == nullptr) { + MS_LOG(ERROR) << "len is nullptr"; + return RET_ERROR; + } + auto *liteModel = reinterpret_cast(model); + + if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) { + MS_LOG(ERROR) << "model buffer is invalid"; + return RET_ERROR; + } + if (*len < liteModel->buf_size_ && buffer != nullptr) { + MS_LOG(ERROR) << "Buffer is too small, Export Failed"; + return RET_ERROR; + } + if (buffer == nullptr) { + buffer = reinterpret_cast(malloc(liteModel->buf_size_)); + if (buffer == nullptr) { + MS_LOG(ERROR) << "allocated model buf fail!"; + return RET_ERROR; + } + } + memcpy(buffer, liteModel->buf, liteModel->buf_size_); + *len = liteModel->buf_size_; + return RET_OK; +} + +int Model::Export(Model *model, const char *filename) { + auto *liteModel = reinterpret_cast(model); + if (liteModel->buf_size_ == 0 || liteModel->buf == nullptr) { + MS_LOG(ERROR) << "model buf is invalid"; + return RET_ERROR; + } + + std::ofstream ofs(filename); + if (!ofs.good() || !ofs.is_open()) { + MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing"; + return RET_ERROR; + } + + ofs.seekp(0, std::ios::beg); + ofs.write(liteModel->buf, liteModel->buf_size_); + ofs.close(); + return chmod(filename, S_IRUSR); +} + } // namespace mindspore::lite diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc deleted file mode 100644 index bab09e3206..0000000000 --- a/mindspore/lite/src/train/train_model.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/train/train_model.h" -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "src/common/graph_util.h" - -namespace mindspore::lite { - -TrainModel *TrainModel::Import(const char *model_buf, size_t size) { - if (model_buf == nullptr) { - MS_LOG(ERROR) << "The model buf is nullptr"; - return nullptr; - } - TrainModel *model = new (std::nothrow) TrainModel(); - if (model == nullptr) { - MS_LOG(ERROR) << "new model fail!"; - return nullptr; - } - model->buf = reinterpret_cast(malloc(size)); - if (model->buf == nullptr) { - delete model; - MS_LOG(ERROR) << "malloc inner model buf fail!"; - return nullptr; - } - memcpy(model->buf, model_buf, size); - model->buf_size_ = size; - auto status = model->ConstructModel(); - if (status != RET_OK) { - MS_LOG(ERROR) << "construct model failed."; - delete model; - return nullptr; - } - return model; -} - -void TrainModel::Free() {} - -char *TrainModel::ExportBuf(char *buffer, size_t *len) const { - if (len == nullptr) { - MS_LOG(ERROR) << "len is nullptr"; - return nullptr; - } - if (buf_size_ == 0 || buf == nullptr) { - MS_LOG(ERROR) << "Model::Export is only available for Train Session"; - return nullptr; - } - if (*len < buf_size_ && buffer != nullptr) { - MS_LOG(ERROR) << "Buffer is too small, Export Failed"; - return nullptr; - } - if (buffer == nullptr) { - buffer = reinterpret_cast(malloc(buf_size_)); - if (buffer == nullptr) { - MS_LOG(ERROR) << "allocated model buf fail!"; - return nullptr; - } - } - - memcpy(buffer, buf, buf_size_); - *len = buf_size_; - return buffer; -} - -char *TrainModel::GetBuffer(size_t *len) const { - if (len == nullptr) { - MS_LOG(ERROR) << "len is nullptr"; - return nullptr; - } - if (buf_size_ == 0 || buf == nullptr) { - MS_LOG(ERROR) << "Model::Export is only available for Train Session"; - return nullptr; - } - - *len = buf_size_; - return buf; -} -} // namespace mindspore::lite diff --git a/mindspore/lite/src/train/train_model.h b/mindspore/lite/src/train/train_model.h deleted file mode 100644 index b2c591eef6..0000000000 --- a/mindspore/lite/src/train/train_model.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ -#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ -#include -#include "src/lite_model.h" - -namespace mindspore { -namespace lite { -/// \brief TrainModel Defines a class that allows to import and export a mindsport trainable model -struct TrainModel : public lite::LiteModel { - /// \brief Static method to create a TrainModel object - /// - /// \param[in] model_buf A buffer that was read from a MS model file - /// \param[in] size Length of the buffer - // - /// \return Pointer to MindSpore Lite TrainModel - static TrainModel *Import(const char *model_buf, size_t size); - - /// \brief Free meta graph related data - void Free() override; - - /// \brief Class destructor, free all memory - virtual ~TrainModel() = default; - - /// \brief Export Model into a buffer - /// - /// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated - /// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer - /// - /// \return Pointer to buffer with exported model - char *ExportBuf(char *buf, size_t *len) const; - - /// \brief Get Model buffer - /// - /// \param[in,out] len Return size of the buffer - /// - /// \return Pointer to model buffer - char *GetBuffer(size_t *len) const; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 8dd5fc563c..a180c25a45 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -25,6 +25,7 @@ #include "include/errorcode.h" #include "src/common/utils.h" #include "src/tensor.h" +#include "src/lite_model.h" #include "src/train/loss_kernel.h" #include "src/train/optimizer_kernel.h" #include "src/sub_graph_kernel.h" @@ -39,47 +40,6 @@ namespace mindspore { namespace lite { -std::unique_ptr ReadFileToBuf(const std::string &filename, size_t *size) { - std::ifstream ifs(filename); - if (!ifs.good()) { - MS_LOG(ERROR) << "File: " << filename << " does not exist"; - return std::unique_ptr(nullptr); - } - - if (!ifs.is_open()) { - MS_LOG(ERROR) << "File: " << filename << " open failed"; - return std::unique_ptr(nullptr); - } - - ifs.seekg(0, std::ios::end); - auto tellg_ret = ifs.tellg(); - if (tellg_ret <= 0) { - MS_LOG(ERROR) << "Could not read file " << filename; - return std::unique_ptr(nullptr); - } - size_t fsize = static_cast(tellg_ret); - - std::unique_ptr buf(new (std::nothrow) char[fsize]); - if (buf == nullptr) { - MS_LOG(ERROR) << "malloc buf failed, file: " << filename; - ifs.close(); - return std::unique_ptr(nullptr); - } - - ifs.seekg(0, std::ios::beg); - ifs.read(buf.get(), fsize); - if (!ifs) { - MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename; - ifs.close(); - return std::unique_ptr(nullptr); - } - ifs.close(); - if (size != nullptr) { - *size = fsize; - } - return buf; -} - static size_t TSFindTensor(const std::vector &where, const lite::Tensor *searchParameter) { for (size_t i = 0; i < where.size(); i++) { if (where[i] == searchParameter) { @@ -139,7 +99,7 @@ void TrainSession::AllocWorkSpace() { int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; } -int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { +int TrainSession::CompileTrainGraph(mindspore::lite::Model *model) { model_ = model; auto restore = ReplaceOps(); @@ -171,8 +131,6 @@ TrainSession::~TrainSession() { } } -void *TrainSession::ExportToBuf(char *buf, size_t *len) const { return model_->ExportBuf(buf, len); } - int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) { this->outputs_.clear(); @@ -214,25 +172,6 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a return RET_OK; } -int TrainSession::SaveToFile(const std::string &filename) const { - size_t fb_size = 0; - const auto *buf = reinterpret_cast(model_->GetBuffer(&fb_size)); - if (buf == nullptr) { - MS_LOG(ERROR) << "Could not Export Trained model"; - return lite::RET_NULL_PTR; - } - std::ofstream ofs(filename); - if ((true != ofs.good()) || (true != ofs.is_open())) { - MS_LOG(ERROR) << "Could not open file \"" << filename << "\" for writing"; - return RET_ERROR; - } - - ofs.seekp(0, std::ios::beg); - ofs.write(buf, fb_size); - ofs.close(); - return chmod(filename.c_str(), S_IRUSR); -} - int TrainSession::Train() { // shift kernels to train mode train_mode_ = true; @@ -522,14 +461,8 @@ int TrainSession::SetLossName(std::string loss_name) { } } // namespace lite -session::TrainSession *session::TrainSession::CreateSession(const char *model_buf, size_t size, lite::Context *context, +session::TrainSession *session::TrainSession::CreateSession(mindspore::lite::Model *model, lite::Context *context, bool train_mode) { - auto model = mindspore::lite::TrainModel::Import(model_buf, size); - if (model == nullptr) { - MS_LOG(ERROR) << "create model for train session failed"; - return nullptr; - } - auto session = new (std::nothrow) lite::TrainSession(); if (session == nullptr) { delete model; @@ -564,15 +497,4 @@ session::TrainSession *session::TrainSession::CreateSession(const char *model_bu return session; } - -session::TrainSession *session::TrainSession::CreateSession(const std::string &filename, lite::Context *context, - bool train_mode) { - size_t size = -1; - auto buf = lite::ReadFileToBuf(filename, &size); - if (buf == nullptr) { - return nullptr; - } - return session::TrainSession::CreateSession(buf.get(), size, context, train_mode); -} - } // namespace mindspore diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index 69e73e8417..cee89e1ff2 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -21,7 +21,6 @@ #include #include #include "include/train/train_session.h" -#include "src/train/train_model.h" #include "src/lite_session.h" /* @@ -52,10 +51,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; int CompileGraph(lite::Model *model) override; - virtual int CompileTrainGraph(lite::TrainModel *model); - - void *ExportToBuf(char *buf, size_t *len) const override; - int SaveToFile(const std::string &filename) const override; + virtual int CompileTrainGraph(lite::Model *model); int Train() override; int Eval() override; @@ -108,7 +104,7 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: virtual void CompileTrainOutputs(); virtual void CompileEvalOutputs(); - TrainModel *model_ = nullptr; + Model *model_ = nullptr; std::unordered_map> orig_output_node_map_; std::unordered_map orig_output_tensor_map_; std::vector orig_output_tensor_names_; diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc index 7be49ca0fe..81dd5027db 100644 --- a/mindspore/lite/src/train/transfer_session.cc +++ b/mindspore/lite/src/train/transfer_session.cc @@ -190,7 +190,7 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char * return nullptr; } - auto model = lite::TrainModel::Import(model_buf_head, size_head); + auto model = lite::Model::Import(model_buf_head, size_head); if (model == nullptr) { MS_LOG(ERROR) << "create model for head train session failed"; delete session; diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h index 3ad8219a8e..85b4f2414c 100644 --- a/mindspore/lite/src/train/transfer_session.h +++ b/mindspore/lite/src/train/transfer_session.h @@ -20,7 +20,6 @@ #include #include #include -#include "src/train/train_model.h" #include "src/lite_session.h" #include "src/train/train_session.h" diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 66634ae183..8c15c4d12e 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -281,7 +281,6 @@ if(SUPPORT_TRAIN) ${LITE_DIR}/src/train/train_populate_parameter_v0.cc ${LITE_DIR}/src/train/train_session.cc ${LITE_DIR}/src/train/transfer_session.cc - ${LITE_DIR}/src/train/train_model.cc ${LITE_DIR}/src/lite_session.cc ) else() diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc index 11071725f6..c51315d1e4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc @@ -359,10 +359,14 @@ TEST_F(NetworkTest, tuning_layer) { meta_graph.reset(); content = nullptr; + + auto *model = mindspore::lite::Model::Import(content, size); + ASSERT_NE(nullptr, model); + lite::Context context; context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context.thread_num_ = 1; - auto session = session::TrainSession::CreateSession(content, size, &context); + auto session = session::TrainSession::CreateSession(model, &context); ASSERT_NE(nullptr, session); session->Train(); session->Train(); // Just double check that calling Train twice does not cause a problem @@ -513,7 +517,10 @@ TEST_F(NetworkTest, efficient_net) { context->thread_num_ = 1; std::string net = "./test_data/nets/effnetb0_fwd_nofuse.ms"; - auto session = session::TrainSession::CreateSession(net, context, false); + + auto *model = mindspore::lite::Model::Import(net.c_str()); + ASSERT_NE(model, nullptr); + auto session = session::TrainSession::CreateSession(model, context, false); ASSERT_NE(session, nullptr); std::string in = "./test_data/nets/effNet_input_x_1_3_224_224.bin"; @@ -530,7 +537,6 @@ TEST_F(NetworkTest, mobileface_net) { std::string net = "./test_data/nets/mobilefacenet0924.ms"; ReadFile(net.c_str(), &net_size, &buf); - // auto model = lite::TrainModel::Import(buf, net_size); auto model = lite::Model::Import(buf, net_size); delete[] buf; auto context = new lite::Context; @@ -538,7 +544,6 @@ TEST_F(NetworkTest, mobileface_net) { context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context->thread_num_ = 1; - // auto session = session::TrainSession::CreateSession(context); auto session = session::LiteSession::CreateSession(context); ASSERT_NE(session, nullptr); auto ret = session->CompileGraph(model); @@ -560,7 +565,10 @@ TEST_F(NetworkTest, setname) { lite::Context context; context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context.thread_num_ = 1; - auto session = mindspore::session::TrainSession::CreateSession(net, &context); + + auto *model = mindspore::lite::Model::Import(net.c_str()); + ASSERT_NE(model, nullptr); + auto session = mindspore::session::TrainSession::CreateSession(model, &context); ASSERT_NE(session, nullptr); auto tensors_map = session->GetOutputs(); diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index 7ad4c8ec3e..d44d735760 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -25,6 +25,7 @@ #include "include/context.h" #include "src/runtime/runtime_api.h" #include "include/version.h" +#include "include/model.h" namespace mindspore { namespace lite { @@ -326,7 +327,14 @@ int NetTrain::RunExportedNet() { } context->thread_num_ = flags_->num_threads_; - session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get()); + + auto *model = mindspore::lite::Model::Import(flags_->export_file_.c_str()); + if (model == nullptr) { + MS_LOG(ERROR) << "create model for train session failed"; + return RET_ERROR; + } + + session_ = session::TrainSession::CreateSession(model, context.get()); if (session_ == nullptr) { MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str(); std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; @@ -388,7 +396,13 @@ int NetTrain::RunNetTrain() { context->device_list_[0].device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; layer_checksum_ = flags_->layer_checksum_; context->thread_num_ = flags_->num_threads_; - session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); + + auto *model = mindspore::lite::Model::Import(flags_->model_file_.c_str()); + if (model == nullptr) { + MS_LOG(ERROR) << "create model for train session failed"; + return RET_ERROR; + } + session_ = session::TrainSession::CreateSession(model, context.get()); if (session_ == nullptr) { MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str(); std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl; @@ -432,7 +446,7 @@ int NetTrain::RunNetTrain() { } } if (!flags_->export_file_.empty()) { - auto ret = session_->SaveToFile(flags_->export_file_); + auto ret = Model::Export(model, flags_->export_file_.c_str()); if (ret != RET_OK) { MS_LOG(ERROR) << "SaveToFile error"; std::cout << "Run SaveToFile error";