parent 8008843562
commit 33d7741904
@@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
#include "src/dataset.h"

using GraphPoint = std::pair<int, float>;

class AccuracyMonitor : public mindspore::session::TrainLoopCallBack {
 public:
  explicit AccuracyMonitor(DataSet *dataset, int check_every_n, int max_steps = -1)
      : ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}

  int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;

  const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }

 private:
  DataSet *ds_;
  std::vector<GraphPoint> accuracies_;
  int check_every_n_;
  int max_steps_;
};

#endif  // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
@@ -0,0 +1,103 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <math.h>
#include <getopt.h>
#include <cstring>
#include <iostream>
#include <fstream>
#include <utility>
#include "src/net_runner.h"
#include "include/context.h"
#include "src/utils.h"
#include "src/data_loader.h"
#include "src/accuracy_monitor.h"

static unsigned int seed = time(NULL);  // seed for rand_r-based random batch sampling

std::vector<int> FillInputDataUtil(const mindspore::session::TrainLoopCallBackData &cb_data,
                                   const std::vector<DataLabelTuple> &dataset, bool serially) {
  static unsigned int idx = 1;
  int total_size = dataset.size();
  std::vector<int> labels_vec;

  auto inputs = cb_data.session_->GetInputs();
  char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData());
  auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
  int batch_size = inputs.at(0)->shape()[0];
  int num_of_classes = inputs.at(1)->shape()[1];
  int data_size = inputs.at(0)->Size() / batch_size;
  MS_ASSERT(total_size > 0);
  MS_ASSERT(input_data != nullptr);
  std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);  // reset the one-hot label tensor
  for (int i = 0; i < batch_size; i++) {
    if (serially) {
      idx = (idx + 1) % total_size;  // walk the dataset in order
    } else {
      idx = rand_r(&seed) % total_size;  // pick a random sample
    }
    int label = 0;
    char *data = nullptr;
    std::tie(data, label) = dataset[idx];
    std::copy(data, data + data_size, input_data + i * data_size);
    labels[i * num_of_classes + label] = 1.0;  // Model expects labels in onehot representation
    labels_vec.push_back(label);
  }
  return labels_vec;
}

void DataLoader::StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) {
  FillInputDataUtil(cb_data, ds_->train_data(), false);
}

int AccuracyMonitor::EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) {
  if ((cb_data.epoch_ + 1) % check_every_n_ != 0) return mindspore::session::RET_CONTINUE;

  float accuracy = 0.0;
  auto inputs = cb_data.session_->GetInputs();
  int batch_size = inputs.at(0)->shape()[0];
  int num_of_classes = ds_->num_of_classes();
  int tests = ds_->test_data().size() / batch_size;
  if (max_steps_ != -1 && tests > max_steps_) tests = max_steps_;
  cb_data.session_->Eval();
  for (int i = 0; i < tests; i++) {
    auto labels = FillInputDataUtil(cb_data, ds_->test_data(), false);
    cb_data.session_->RunGraph();
    auto outputs = cb_data.session_->GetPredictions();
    for (auto it = outputs.begin(); it != outputs.end(); ++it) {
      if (it->second->ElementsNum() == batch_size * num_of_classes) {
        auto scores = reinterpret_cast<float *>(it->second->MutableData());
        for (int b = 0; b < batch_size; b++) {
          int max_idx = 0;
          float max_score = scores[num_of_classes * b];
          for (int c = 1; c < num_of_classes; c++) {
            if (scores[num_of_classes * b + c] > max_score) {
              max_score = scores[num_of_classes * b + c];
              max_idx = c;
            }
          }
          if (labels[b] == max_idx) accuracy += 1.0;  // count correct argmax predictions
        }
        break;
      }
    }
  }
  accuracy /= static_cast<float>(batch_size * tests);
  accuracies_.push_back(std::make_pair(cb_data.epoch_, accuracy));
  std::cout << cb_data.epoch_ + 1 << ":\tAccuracy is " << accuracy << std::endl;
  cb_data.session_->Train();
  return mindspore::session::RET_CONTINUE;
}
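For orientation, here is a minimal sketch of how the two callbacks defined above are typically wired into a training loop. The model file name, batch size, epoch count, and the helper function are illustrative assumptions rather than part of this commit; the DataSet is assumed to be already populated, and a default-constructed lite::Context from include/context.h is assumed to be sufficient.

#include <iostream>
#include "include/context.h"
#include "include/train/train_loop.h"
#include "src/accuracy_monitor.h"
#include "src/data_loader.h"
#include "src/dataset.h"

// Hypothetical helper: trains for a few epochs and checks accuracy after each one.
int TrainWithAccuracyCheck(DataSet *ds) {
  mindspore::lite::Context context;  // default context (assumed adequate here)
  auto *loop = mindspore::session::TrainLoop::CreateTrainLoop("lenet_tod.ms", &context, 32);
  if (loop == nullptr) return -1;

  DataLoader loader(ds);           // refills the input tensors before every step
  AccuracyMonitor acc(ds, 1, 10);  // evaluate every epoch, on at most 10 test batches
  int ret = loop->Train(5, {&loader, &acc});

  for (const auto &p : acc.GetAccuracyPoints()) {
    std::cout << "epoch " << p.first << " accuracy " << p.second << std::endl;
  }
  delete loop;
  return ret;
}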
@@ -0,0 +1,34 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
#include "src/dataset.h"

class DataLoader : public mindspore::session::TrainLoopCallBack {
 public:
  explicit DataLoader(DataSet *dataset) : ds_(dataset) {}
  void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override;

 private:
  DataSet *ds_;
};

#endif  // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
@@ -0,0 +1,51 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
#include <stdio.h>
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"

using GraphPoint = std::pair<int, float>;

namespace mindspore {
namespace lite {

class CkptSaver : public session::TrainLoopCallBack {
 public:
  CkptSaver(int save_every_n, const std::string &filename_prefix)
      : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}

  int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
    if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
      auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
      remove(cpkt_fn.c_str());
      cb_data.session_->SaveToFile(cpkt_fn);
    }
    return session::RET_CONTINUE;
  }

 private:
  int save_every_n_;
  std::string filename_prefix_;
};

}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
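As the inline EpochEnd above shows, the saver re-serializes the session to "<prefix>_trained_<epoch+1>.ms" on a fixed epoch interval. A typical construction looks like the sketch below; the prefix and interval are illustrative, and the include path is inferred from the header guard.

#include "include/train/ckpt_saver.h"

// Once passed to TrainLoop::Train(), writes e.g. "lenet_trained_10.ms", "lenet_trained_20.ms", ...
mindspore::lite::CkptSaver ckpt_saver(10, std::string("lenet"));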
@@ -0,0 +1,48 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <climits>
#include <unordered_map>
#include "include/train/train_loop.h"

using GraphPoint = std::pair<int, float>;

namespace mindspore {
namespace lite {

class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
 public:
  explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {}

  virtual ~ClassificationTrainAccuracyMonitor() = default;
  void Begin(const session::TrainLoopCallBackData &cb_data) override;
  void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
  int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
  void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
  const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }

 private:
  std::vector<GraphPoint> accuracies_;
  int print_every_n_ = 0;
};

}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <climits>
#include <unordered_map>
#include "include/train/train_loop_callback.h"

using GraphPoint = std::pair<int, float>;

namespace mindspore {
namespace lite {

class LossMonitor : public session::TrainLoopCallBack {
 public:
  explicit LossMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {}
  virtual ~LossMonitor() = default;
  void Begin(const session::TrainLoopCallBackData &cb_data) override;
  void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
  int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
  void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
  const std::vector<GraphPoint> &GetLossPoints() const { return losses_; }

 private:
  std::vector<GraphPoint> losses_;
  int print_every_n_;
};

}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
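ClassificationTrainAccuracyMonitor and LossMonitor follow the same pattern: construct them with a print interval, pass them in the callback vector to TrainLoop::Train(), and read the collected GraphPoints afterwards. A sketch is below; the include paths are inferred from the header guards, and it is assumed the monitors record one (epoch, value) point per epoch, since their implementations are not part of this diff.

#include <iostream>
#include "include/train/classification_train_accuracy_monitor.h"
#include "include/train/loss_monitor.h"

// Dump the curves collected by the two monitors after training has finished.
void DumpTrainingCurves(const mindspore::lite::LossMonitor &loss_monitor,
                        const mindspore::lite::ClassificationTrainAccuracyMonitor &acc_monitor) {
  for (const GraphPoint &p : loss_monitor.GetLossPoints()) {
    std::cout << "loss," << p.first << "," << p.second << std::endl;
  }
  for (const GraphPoint &p : acc_monitor.GetAccuracyPoints()) {
    std::cout << "accuracy," << p.first << "," << p.second << std::endl;
  }
}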
@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <unordered_map>
#include "include/train/train_loop_callback.h"

namespace mindspore {
namespace lite {

constexpr int DONT_UPDATE_LR = 0;
constexpr int UPDATE_LR = 1;

using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;

/// \brief Multiply the LR by a factor of gamma every epoch
int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);

/// \brief Multiply the LR by a factor of gamma every step_size
int StepLRLambda(float *lr, int epoch, void *step_size);
struct StepLRLambda {
  StepLRLambda(int step, float g) : step_size(step), gamma(g) {}

  int step_size;  // period of LR decay
  float gamma;    // LR decay factor
};

class LRScheduler : public session::TrainLoopCallBack {
 public:
  explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1);
  virtual ~LRScheduler() = default;
  int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;

 private:
  LR_Lambda lambda_func_;
  void *lr_data_ = nullptr;
  int step_ = 1;
};

}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
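The scheduler can be driven either by the provided lambdas (note that the StepLRLambda function and struct share a name, so the struct needs the elaborated `struct` specifier) or by any custom LR_Lambda. In the sketch below the gamma, step, and floor values are illustrative, the include path is inferred from the header guard, and it is assumed the scheduler applies the new rate whenever the lambda returns UPDATE_LR.

#include "include/train/lr_scheduler.h"

// Decay the learning rate by 0.1x every two epochs using the built-in helper.
struct mindspore::lite::StepLRLambda step_lr(2, 0.1f);
mindspore::lite::LRScheduler step_sched(mindspore::lite::StepLRLambda, static_cast<void *>(&step_lr), 1);

// Or a custom rule: halve the rate every epoch until it reaches a floor.
mindspore::lite::LR_Lambda halve = [](float *lr, int epoch, void *) -> int {
  if (*lr <= 1.e-4f) return mindspore::lite::DONT_UPDATE_LR;
  *lr *= 0.5f;
  return mindspore::lite::UPDATE_LR;
};
mindspore::lite::LRScheduler custom_sched(halve, nullptr, 1);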
@@ -0,0 +1,69 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
#include "include/train_session.h"

namespace mindspore {
namespace session {

class TrainLoop {
 public:
  /// \brief Static method to create a TrainLoop object
  ///
  /// \param[in] model_filename Filename of the model flatbuffer to read
  /// \param[in] context Defines the context of the session to be created
  ///
  /// \return Pointer to the MindSpore Lite TrainLoop
  static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1);

  /// \brief Class destructor
  virtual ~TrainLoop() = default;

  /// \brief Resets the epoch counter
  ///
  /// \return 0 on success or -1 in case of error
  virtual int Reset() = 0;  // resets the epoch counter to 0.

  /// \brief Accessor to the TrainSession
  ///
  /// \return Pointer to the TrainSession
  virtual session::TrainSession *train_session() = 0;

  /// \brief Sets the session's kernel callbacks
  ///
  /// \param[in] before Define a call_back_function to be called before running each node.
  /// \param[in] after Define a call_back_function called after running each node.
  ///
  /// \return 0 on success or -1 in case of error
  virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0;

  /// \brief Performs the training loop
  ///
  /// \param[in] epochs The number of epochs to run
  /// \param[in] cbs A vector of TrainLoopCallBack pointers
  ///
  /// \return 0 on success or -1 in case of error
  virtual int Train(int epochs, std::vector<TrainLoopCallBack *> cbs) = 0;
};
}  // namespace session
}  // namespace mindspore
#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
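Taken together, the API documented above is used roughly as in the sketch below. The file names, callback intervals, and epoch count are illustrative assumptions, and the monitor/saver include paths are inferred from their header guards.

#include <string>
#include <vector>
#include "include/context.h"
#include "include/train/train_loop.h"
#include "include/train/ckpt_saver.h"
#include "include/train/loss_monitor.h"

// Hypothetical driver showing the create / train / reset cycle.
int RunTraining(const std::string &model_file) {
  mindspore::lite::Context context;  // default context (assumed adequate here)
  auto *loop = mindspore::session::TrainLoop::CreateTrainLoop(model_file, &context);
  if (loop == nullptr || loop->train_session() == nullptr) return -1;

  mindspore::lite::LossMonitor loss_monitor(100);
  mindspore::lite::CkptSaver ckpt_saver(5, "model");
  std::vector<mindspore::session::TrainLoopCallBack *> cbs = {&loss_monitor, &ckpt_saver};

  int ret = loop->Train(20, cbs);  // 0 on success, -1 on error, per the contract above
  loop->Reset();                   // epoch counter back to 0 before any further run
  delete loop;
  return ret;
}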
@@ -0,0 +1,57 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>

namespace mindspore {
namespace session {

class TrainSession;
class TrainLoop;

struct TrainLoopCallBackData {
  TrainLoopCallBackData(bool train_mode, int epoch, TrainSession *session, TrainLoop *loop)
      : train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {}

  bool train_mode_;        /**< training mode of TrainSession object */
  unsigned int epoch_;     /**< the current training epoch (starts at 0) */
  unsigned int step_ = 0;  /**< the current step within the epoch */
  TrainSession *session_;  /**< pointer to the TrainSession */
  TrainLoop *loop_;
};

constexpr int RET_CONTINUE = 0;
constexpr int RET_STOP_TRAINING = 1;
constexpr int RET_EXIT = 2;

class TrainLoopCallBack {
 public:
  virtual ~TrainLoopCallBack() = default;
  virtual void Begin(const TrainLoopCallBackData &cb_data) {}
  virtual void End(const TrainLoopCallBackData &cb_data) {}
  virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}
  virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }
  virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}
  virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
};

}  // namespace session
}  // namespace mindspore
#endif  // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
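The callback contract above is all that is needed for custom behaviour. For example, a hypothetical early-stopping callback (not part of this commit) only has to override EpochEnd and return one of the RET_* codes:

#include "include/train/train_loop_callback.h"

// Hypothetical callback: asks the loop to stop once a fixed number of epochs has completed.
class StopAfterNEpochs : public mindspore::session::TrainLoopCallBack {
 public:
  explicit StopAfterNEpochs(unsigned int limit) : limit_(limit) {}
  int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override {
    return (cb_data.epoch_ + 1 >= limit_) ? mindspore::session::RET_STOP_TRAINING
                                          : mindspore::session::RET_CONTINUE;
  }

 private:
  unsigned int limit_;
};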
@@ -0,0 +1,178 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mindspore.lite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.mindspore.lite.config.MSConfig;

public class TrainSession {
    static {
        System.loadLibrary("mindspore-lite-jni");
    }

    private long sessionPtr;

    public TrainSession() {
        this.sessionPtr = 0;
    }

    public boolean init(String modelFilename, MSConfig config) {
        this.sessionPtr = createSession(modelFilename, config.getMSConfigPtr());
        return this.sessionPtr != 0;
    }

    public long getSessionPtr() {
        return sessionPtr;
    }

    public void bindThread(boolean if_bind) {
        this.bindThread(this.sessionPtr, if_bind);
    }

    public boolean runGraph() {
        return this.runGraph(this.sessionPtr);
    }

    public List<MSTensor> getInputs() {
        List<Long> ret = this.getInputs(this.sessionPtr);
        ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
        for (Long ms_tensor_addr : ret) {
            MSTensor msTensor = new MSTensor(ms_tensor_addr);
            tensors.add(msTensor);
        }
        return tensors;
    }

    public MSTensor getInputsByTensorName(String tensorName) {
        Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
        if (tensor_addr == null) {
            return null;
        }
        MSTensor msTensor = new MSTensor(tensor_addr);
        return msTensor;
    }

    public List<MSTensor> getOutputsByNodeName(String nodeName) {
        List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
        ArrayList<MSTensor> tensors = new ArrayList<>();
        for (Long msTensorAddr : ret) {
            MSTensor msTensor = new MSTensor(msTensorAddr);
            tensors.add(msTensor);
        }
        return tensors;
    }

    public Map<String, MSTensor> getOutputMapByTensor() {
        Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr);
        Map<String, MSTensor> tensorMap = new HashMap<>();
        Set<Map.Entry<String, Long>> entrySet = ret.entrySet();
        for (Map.Entry<String, Long> entry : entrySet) {
            String name = entry.getKey();
            Long msTensorAddr = entry.getValue();
            tensorMap.put(name, new MSTensor(msTensorAddr));
        }
        return tensorMap;
    }

    public List<String> getOutputTensorNames() {
        return getOutputTensorNames(this.sessionPtr);
    }

    public MSTensor getOutputByTensorName(String tensorName) {
        Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
        if (tensor_addr == null) {
            return null;
        }
        return new MSTensor(tensor_addr);
    }

    public void free() {
        this.free(this.sessionPtr);
        this.sessionPtr = 0;
    }

    public boolean resize(List<MSTensor> inputs, int[][] dims) {
        long[] inputs_array = new long[inputs.size()];
        for (int i = 0; i < inputs.size(); i++) {
            inputs_array[i] = inputs.get(i).getMSTensorPtr();
        }
        return this.resize(this.sessionPtr, inputs_array, dims);
    }

    public boolean saveToFile(String modelFilename) {
        return this.saveToFile(this.sessionPtr, modelFilename);
    }

    public boolean train() {
        return this.train(this.sessionPtr);
    }

    public boolean eval() {
        return this.eval(this.sessionPtr);
    }

    public boolean isTrain() {
        return this.isTrain(this.sessionPtr);
    }

    public boolean isEval() {
        return this.isEval(this.sessionPtr);
    }

    public boolean setLearningRate(float learning_rate) {
        return this.setLearningRate(this.sessionPtr, learning_rate);
    }

    private native long createSession(String modelFilename, long msConfigPtr);

    private native void bindThread(long sessionPtr, boolean if_bind);

    private native boolean runGraph(long sessionPtr);

    private native List<Long> getInputs(long sessionPtr);

    private native long getInputsByTensorName(long sessionPtr, String tensorName);

    private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);

    private native Map<String, Long> getOutputMapByTensor(long sessionPtr);

    private native List<String> getOutputTensorNames(long sessionPtr);

    private native long getOutputByTensorName(long sessionPtr, String tensorName);

    private native void free(long sessionPtr);

    private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);

    private native boolean saveToFile(long sessionPtr, String modelFilename);

    private native boolean train(long sessionPtr);

    private native boolean eval(long sessionPtr);

    private native boolean isTrain(long sessionPtr);

    private native boolean isEval(long sessionPtr);

    private native boolean setLearningRate(long sessionPtr, float learning_rate);
}