!12182 [MSLITE]feature, micro support train
From: @yangjie159 Reviewed-by: @wangchengyuan,@HilbertDavid Signed-off-by: @wangchengyuanpull/12182/MERGE
commit
9d2e07ae24
@ -0,0 +1,180 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/generator/component/train_component.h"
|
||||
#include <string>
|
||||
#include "coder/utils/type_cast.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
|
||||
void CodeTrainParams(std::ofstream &ofs) {
|
||||
ofs << "struct TrainParameter {\n"
|
||||
" float beta1_;\n"
|
||||
" float beta2_;\n"
|
||||
" float epsilon_;\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"enum EarlyStopType {\n"
|
||||
" Diff = 0,\n"
|
||||
" WeigthDiff = 1,\n"
|
||||
" Abs = 2,\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"struct EarlyStop {\n"
|
||||
" enum EarlyStopType type;\n"
|
||||
" float tolerate;\n"
|
||||
"};\n\n";
|
||||
}
|
||||
|
||||
void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name) {
|
||||
ofs << "/**\n"
|
||||
" *\n"
|
||||
" * @param size, return the number of features\n"
|
||||
" * @return, the address of features\n"
|
||||
" */\n"
|
||||
<< "FeatureParam *" << module_name << "_GetFeatures(int *size);\n\n";
|
||||
ofs << "/**\n"
|
||||
" *\n"
|
||||
" * @param features, the address of features\n"
|
||||
" * @param size, the number of features\n"
|
||||
" * @return, status\n"
|
||||
" */\n"
|
||||
<< "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size);\n\n";
|
||||
}
|
||||
|
||||
void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
|
||||
const std::unique_ptr<CoderContext> &ctx) {
|
||||
size_t features_num = 0;
|
||||
ofs << "static FeatureParam feature_params[] = {\n";
|
||||
for (const auto &item : ctx->saved_weights()) {
|
||||
std::string addr = item.first;
|
||||
Tensor *tensor = item.second;
|
||||
if (tensor->tensor_name().empty()) {
|
||||
MS_LOG(ERROR) << "exist empty feature";
|
||||
continue;
|
||||
}
|
||||
ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", "
|
||||
<< EnumMicroTensorDataType(tensor->data_type()) << "}, \n";
|
||||
features_num++;
|
||||
}
|
||||
ofs << "};\n";
|
||||
|
||||
ofs << "FeatureParam *" << module_name << "_GetFeatures(int *size) {\n"
|
||||
<< " *size = " << features_num << ";\n"
|
||||
<< " return feature_params;\n"
|
||||
"}\n\n";
|
||||
|
||||
ofs << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size) {\n"
|
||||
<< " for (int i = 0; i < size; ++i) {\n"
|
||||
" FeatureParam *src = features + i;\n"
|
||||
" FeatureParam dst;\n"
|
||||
" // find the dst feature\n"
|
||||
" bool is_find = false;\n"
|
||||
<< " for (int j = 0; j < " << features_num << "; ++j) {\n"
|
||||
<< " if (strcmp(src->name, feature_params[j].name) == 0) {\n"
|
||||
" dst = feature_params[j];\n"
|
||||
" is_find = true;\n"
|
||||
" break;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" if (!is_find) {\n"
|
||||
" MICRO_ERROR(\"invalid feature param: %s\", src->name);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n"
|
||||
" if (src->elenums != dst.elenums) {\n"
|
||||
" MICRO_ERROR(\"feature %s elenums is mismatch, src: %lu, dst: %lu\", src->name, src->elenums, "
|
||||
"dst.elenums);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n"
|
||||
" memcpy(dst.data, src->data, src->elenums * sizeof(float));\n"
|
||||
" }\n"
|
||||
" MICRO_INFO(\"update features map success\");\n"
|
||||
" return RET_OK;\n"
|
||||
"}\n\n";
|
||||
}
|
||||
|
||||
void CodeTrainState(std::ofstream &ofs, const std::string &module_name) {
|
||||
ofs << "/**\n"
|
||||
" * Train Function\n"
|
||||
" * @param epoch, the train epoch\n"
|
||||
" * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n"
|
||||
" * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n"
|
||||
" * parameters to improve the accuracy\n"
|
||||
" * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n"
|
||||
" * @return status\n"
|
||||
" */\n"
|
||||
<< "int " << module_name
|
||||
<< "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, "
|
||||
"const struct EarlyStop *early_stop);\n\n";
|
||||
}
|
||||
|
||||
void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx) {
|
||||
std::vector<Tensor *> inputs = ctx->graph_inputs();
|
||||
size_t inputs_num = inputs.size();
|
||||
auto inputs_tostring = [&]() {
|
||||
std::string result;
|
||||
result += "{";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
result += ctx->input_name() + std::to_string(i) + ", ";
|
||||
}
|
||||
result += "}";
|
||||
return result;
|
||||
};
|
||||
auto wrap = [](int i) { return "[" + std::to_string(i) + "]"; };
|
||||
auto offset_inputs = [&]() {
|
||||
std::string src = "origin_inputs";
|
||||
std::string dst = "input_ptr";
|
||||
std::string result;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
result += dst + wrap(i) += " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n";
|
||||
}
|
||||
return result;
|
||||
};
|
||||
auto varify_inputs = [&]() {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL";
|
||||
i < inputs.size() - 1 ? result += " || " : result += "";
|
||||
}
|
||||
return result;
|
||||
};
|
||||
ofs << "int " << module_name
|
||||
<< "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter "
|
||||
"*parameter, const struct EarlyStop *early_stop) {\n"
|
||||
" if (iterations <= 0 || epoch <= 0) {\n"
|
||||
" MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n"
|
||||
" MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n"
|
||||
<< " const void *origin_input[] = " << inputs_tostring() << ";\n";
|
||||
ofs << " if (" << varify_inputs() << ") {\n"
|
||||
<< " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n";
|
||||
ofs << " for (int i = 0; i < epoch; ++i) {\n"
|
||||
<< " const void *input_ptr[" << inputs_num << "];\n"
|
||||
<< " float loss = 0;\n"
|
||||
<< " for (int j = 0; j < iterations; ++j) {\n"
|
||||
<< " " << offset_inputs() << "\n"
|
||||
<< " " << module_name << "_SetInputs(input_ptr, " << inputs_num << ");\n"
|
||||
<< " " << module_name << "_Inference();\n"
|
||||
<< " loss = " << module_name << "_ComputeLossAndGradient();\n"
|
||||
<< " }\n"
|
||||
" }\n"
|
||||
" return RET_OK;\n"
|
||||
"};\n\n";
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <fstream>
|
||||
#include "src/tensor.h"
|
||||
#include "coder/context.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
void CodeTrainParams(std::ofstream &ofs);
|
||||
|
||||
void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name);
|
||||
void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
|
||||
const std::unique_ptr<CoderContext> &ctx);
|
||||
|
||||
void CodeTrainState(std::ofstream &ofs, const std::string &module_name);
|
||||
void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx);
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
|
@ -0,0 +1,108 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/generator/train/train_generator.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "coder/generator/component/common_component.h"
|
||||
#include "coder/generator/component/benchmark_component.h"
|
||||
#include "coder/generator/component/train_component.h"
|
||||
#include "coder/generator/component/const_blocks/license.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
|
||||
ofs << "float " << config_->module_name() << "_ComputeLossAndGradient() {\n";
|
||||
ofs << " float loss = 0;\n";
|
||||
for (const auto &block : ctx_->train_blocks()) {
|
||||
ofs << " {\n" << block << " }\n";
|
||||
}
|
||||
ofs << " return loss;\n";
|
||||
ofs << "}\n";
|
||||
}
|
||||
|
||||
int TrainGenerator::CodeNetHFile() {
|
||||
std::string net_include_file = net_inc_file_path_ + net_inc_hfile_;
|
||||
std::ofstream ofs(net_include_file);
|
||||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
MS_LOG(INFO) << "write " << net_include_file;
|
||||
ofs << g_hwLicense;
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
ofs << "#include \"src/runtime/thread_pool.h\"\n";
|
||||
}
|
||||
ofs << "#include \"microtensor.h\"\n\n";
|
||||
CodeTrainParams(ofs);
|
||||
CodeInputAndOutputState(ofs, config_->module_name());
|
||||
if (is_get_quant_args_) {
|
||||
CodeGraphQuantArgsState(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->is_weight_file()) {
|
||||
CodeInitWeightState(ofs, config_->module_name());
|
||||
}
|
||||
CodeManageResourceState(ofs, config_->module_name());
|
||||
CodeInferenceState(ofs, config_->module_name());
|
||||
CodeFeaturesState(ofs, config_->module_name());
|
||||
CodeTrainState(ofs, config_->module_name());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TrainGenerator::CodeNetCFile() {
|
||||
std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
|
||||
std::ofstream ofs(net_impl_file);
|
||||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
MS_LOG(INFO) << "write " << net_impl_file;
|
||||
CodeSourceFileInclude(ofs, net_weight_hfile_, net_inc_hfile_);
|
||||
CodeInputAndOutputImplement(ofs, config_->module_name(), ctx_);
|
||||
CodeInitResourceImplement(ofs, config_->module_name(), ctx_);
|
||||
CodeFreeResourceImplement(ofs, config_->module_name(), ctx_);
|
||||
CodeFeaturesImplement(ofs, config_->module_name(), ctx_);
|
||||
if (is_get_quant_args_) {
|
||||
CodeGraphQuantArgsImplement(ofs, config_->module_name(), ctx_);
|
||||
}
|
||||
CodeNetRunFunc(ofs);
|
||||
CodeGradientFunc(ofs);
|
||||
CodeTrainImplement(ofs, config_->module_name(), ctx_);
|
||||
ofs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TrainGenerator::CodeBenchmarkFile() {
|
||||
std::string net_main_impl_file = net_main_file_path_ + net_main_cfile_;
|
||||
std::ofstream ofs(net_main_impl_file);
|
||||
MS_LOG(INFO) << "write " << net_main_impl_file;
|
||||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
std::vector<Tensor *> inputs = ctx_->graph_inputs();
|
||||
size_t inputs_num = inputs.size();
|
||||
|
||||
CodeBenchmarkHeader(ofs, net_inc_hfile_);
|
||||
CodeBenchmarkUsage(ofs);
|
||||
CodeBenchmarkWarmup(ofs, config_->module_name());
|
||||
|
||||
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
|
||||
CodeBenchmarkSetBuffer(ofs, config_->module_name());
|
||||
if (config_->is_weight_file()) {
|
||||
CodeBenchmarkInitWeight(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
CodeBenchmarkConfigThread(ofs);
|
||||
}
|
||||
CodeBenchmarkInference(ofs, config_->module_name());
|
||||
CodeBenchmarkPrintOutputs(ofs, config_->module_name());
|
||||
|
||||
CodeBenchmarkFreeResourse(ofs, config_->module_name(), inputs_num);
|
||||
ofs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "micro/coder/generator/generator.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class TrainGenerator : public Generator {
|
||||
public:
|
||||
explicit TrainGenerator(std::unique_ptr<CoderContext> ctx) : Generator(std::move(ctx)) {}
|
||||
~TrainGenerator() override = default;
|
||||
|
||||
private:
|
||||
int CodeNetHFile() override;
|
||||
int CodeNetCFile() override;
|
||||
|
||||
int CodeBenchmarkFile() override;
|
||||
|
||||
void CodeGradientFunc(std::ofstream &ofs) const;
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
|
@ -0,0 +1,95 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/train.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
|
||||
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
|
||||
std::set<OperatorCoder *> subgraph;
|
||||
std::queue<OperatorCoder *> to_visit;
|
||||
to_visit.push(edge);
|
||||
while (!to_visit.empty()) {
|
||||
size_t size = to_visit.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
OperatorCoder *curr = to_visit.front();
|
||||
to_visit.pop();
|
||||
if (subgraph.find(curr) != subgraph.end()) {
|
||||
continue;
|
||||
}
|
||||
subgraph.insert(curr);
|
||||
for (const auto &op : curr->input_ops()) {
|
||||
to_visit.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto item = subgraph.find(edge);
|
||||
if (item == subgraph.end()) {
|
||||
MS_LOG(ERROR) << "failed to find the edge in the subgraph";
|
||||
return subgraph;
|
||||
}
|
||||
// erase edge operator coder from subgraph
|
||||
subgraph.erase(item);
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
|
||||
const std::set<schema::PrimitiveType> loss_types = {schema::PrimitiveType_SoftmaxCrossEntropy,
|
||||
schema::PrimitiveType_SparseSoftmaxCrossEntropy,
|
||||
schema::PrimitiveType_BinaryCrossEntropy,
|
||||
schema::PrimitiveType_SmoothL1Loss,
|
||||
schema::PrimitiveType_SmoothL1LossGrad,
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
|
||||
OperatorCoder *loss_op = nullptr;
|
||||
for (const auto &opcoder : op_coders) {
|
||||
auto primitive_type = static_cast<schema::PrimitiveType>(opcoder->primitive()->Type());
|
||||
auto item = loss_types.find(primitive_type);
|
||||
if (item != loss_types.end()) {
|
||||
loss_op = opcoder.get();
|
||||
break;
|
||||
}
|
||||
}
|
||||
MS_CHECK_PTR(loss_op);
|
||||
size_t op_num = op_coders.size();
|
||||
std::vector<std::string> code_blocks = context->code_blocks();
|
||||
if (op_num != code_blocks.size()) {
|
||||
MS_LOG(INFO) << "the number of code blocks and op coders is not equal";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
|
||||
std::vector<std::string> inferences_blocks;
|
||||
std::vector<std::string> train_blocks;
|
||||
for (size_t i = 0; i < op_num; ++i) {
|
||||
auto &opcoder = op_coders.at(i);
|
||||
std::string block = code_blocks.at(i);
|
||||
if (inference_ops.find(opcoder.get()) != inference_ops.end()) {
|
||||
inferences_blocks.push_back(block);
|
||||
}
|
||||
train_blocks.push_back(block);
|
||||
}
|
||||
context->set_inference_blocks(inferences_blocks);
|
||||
context->set_train_blocks(train_blocks);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite::micro
|
@ -0,0 +1,33 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "coder/context.h"
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class Train {
|
||||
public:
|
||||
static int TransformGraphForTrain(CoderContext *context,
|
||||
const std::vector<std::unique_ptr<OperatorCoder>> &op_coders);
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
|
Loading…
Reference in new issue