|
|
|
@ -37,7 +37,6 @@ class CoderFlags : public virtual FlagParser {
|
|
|
|
|
CoderFlags() {
|
|
|
|
|
AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
|
|
|
|
|
AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
|
|
|
|
|
AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
|
|
|
|
|
AddFlag(&CoderFlags::target_, "target", "generated code target, x86| ARM32M| ARM32A| ARM64", "x86");
|
|
|
|
|
AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Inference | Train", "Inference");
|
|
|
|
|
AddFlag(&CoderFlags::support_parallel_, "supportParallel", "whether support parallel launch, true | false", false);
|
|
|
|
@ -48,7 +47,6 @@ class CoderFlags : public virtual FlagParser {
|
|
|
|
|
|
|
|
|
|
std::string model_path_;
|
|
|
|
|
bool support_parallel_{false};
|
|
|
|
|
std::string code_module_name_;
|
|
|
|
|
std::string code_path_;
|
|
|
|
|
std::string code_mode_;
|
|
|
|
|
bool debug_mode_{false};
|
|
|
|
@ -84,6 +82,27 @@ int Coder::Run(const std::string &model_path) {
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Configurator::ParseProjDir(std::string model_path) {
|
|
|
|
|
// split model_path to get model file name
|
|
|
|
|
proj_dir_ = model_path;
|
|
|
|
|
size_t found = proj_dir_.find_last_of("/\\");
|
|
|
|
|
if (found != std::string::npos) {
|
|
|
|
|
proj_dir_ = proj_dir_.substr(found + 1);
|
|
|
|
|
}
|
|
|
|
|
found = proj_dir_.find(".ms");
|
|
|
|
|
if (found != std::string::npos) {
|
|
|
|
|
proj_dir_ = proj_dir_.substr(0, found);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "model file's name must be end with \".ms\".";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (proj_dir_.size() == 0) {
|
|
|
|
|
proj_dir_ = "net";
|
|
|
|
|
MS_LOG(WARNING) << "parse model's name failed, use \"net\" instead.";
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Coder::Init(const CoderFlags &flags) const {
|
|
|
|
|
static const std::map<std::string, Target> kTargetMap = {
|
|
|
|
|
{"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
|
|
|
|
@ -91,6 +110,17 @@ int Coder::Init(const CoderFlags &flags) const {
|
|
|
|
|
Configurator *config = Configurator::GetInstance();
|
|
|
|
|
|
|
|
|
|
std::vector<std::function<bool()>> parsers;
|
|
|
|
|
parsers.emplace_back([&flags, config]() -> bool {
|
|
|
|
|
if (!FileExists(flags.model_path_)) {
|
|
|
|
|
MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (config->ParseProjDir(flags.model_path_) != RET_OK) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
parsers.emplace_back([&flags, config]() -> bool {
|
|
|
|
|
auto target_item = kTargetMap.find(flags.target_);
|
|
|
|
|
MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
|
|
|
|
@ -119,20 +149,6 @@ int Coder::Init(const CoderFlags &flags) const {
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
parsers.emplace_back([&flags, config]() -> bool {
|
|
|
|
|
if (!FileExists(flags.model_path_)) {
|
|
|
|
|
MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (flags.code_module_name_.empty() || isdigit(flags.code_module_name_.at(0))) {
|
|
|
|
|
MS_LOG(ERROR) << "code_gen code module name " << flags.code_module_name_
|
|
|
|
|
<< " not valid: it must be given and the first char could not be number";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
config->set_module_name(flags.code_module_name_);
|
|
|
|
|
return true;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
parsers.emplace_back([&flags, config]() -> bool {
|
|
|
|
|
const std::string slash = std::string(kSlash);
|
|
|
|
|
if (!flags.code_path_.empty() && !DirExists(flags.code_path_)) {
|
|
|
|
@ -141,18 +157,18 @@ int Coder::Init(const CoderFlags &flags) const {
|
|
|
|
|
}
|
|
|
|
|
config->set_code_path(flags.code_path_);
|
|
|
|
|
if (flags.code_path_.empty()) {
|
|
|
|
|
std::string path = ".." + slash + config->module_name();
|
|
|
|
|
std::string path = ".." + slash + config->proj_dir();
|
|
|
|
|
config->set_code_path(path);
|
|
|
|
|
} else {
|
|
|
|
|
if (flags.code_path_.substr(flags.code_path_.size() - 1, 1) != slash) {
|
|
|
|
|
std::string path = flags.code_path_ + slash + config->module_name();
|
|
|
|
|
std::string path = flags.code_path_ + slash + config->proj_dir();
|
|
|
|
|
config->set_code_path(path);
|
|
|
|
|
} else {
|
|
|
|
|
std::string path = flags.code_path_ + config->module_name();
|
|
|
|
|
std::string path = flags.code_path_ + config->proj_dir();
|
|
|
|
|
config->set_code_path(path);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return InitProjDirs(flags.code_path_, config->module_name()) != RET_ERROR;
|
|
|
|
|
return InitProjDirs(flags.code_path_, config->proj_dir()) != RET_ERROR;
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
if (!std::all_of(parsers.begin(), parsers.end(), [](auto &parser) -> bool { return parser(); })) {
|
|
|
|
@ -162,17 +178,15 @@ int Coder::Init(const CoderFlags &flags) const {
|
|
|
|
|
}
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
config->set_module_name(kModelName);
|
|
|
|
|
|
|
|
|
|
auto print_parameter = [](auto name, auto value) {
|
|
|
|
|
MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
print_parameter("modelPath", flags.model_path_);
|
|
|
|
|
print_parameter("projectName", config->proj_dir());
|
|
|
|
|
print_parameter("target", config->target());
|
|
|
|
|
print_parameter("codePath", config->code_path());
|
|
|
|
|
print_parameter("codeMode", config->code_mode());
|
|
|
|
|
print_parameter("codeModuleName", config->module_name());
|
|
|
|
|
print_parameter("debugMode", config->debug_mode());
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|