/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef PREDICT_COMMON_FLAG_PARSER_H_ #define PREDICT_COMMON_FLAG_PARSER_H_ #include #include #include #include #include "common/utils.h" #include "common/option.h" namespace mindspore { namespace predict { struct FlagInfo; struct Nothing {}; class FlagParser { public: FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", false); } virtual ~FlagParser() = default; // only support read flags from command line virtual Option ParseFlags(int argc, const char *const *argv, bool supportUnknown = false, bool supportDuplicate = false); std::string Usage(const Option &usgMsg = Option(None())) const; template void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); template void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); template void AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo); // Option-type fields template void AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo); bool help; protected: std::string binName; Option usageMsg; private: struct FlagInfo { std::string flagName; bool isRequired; bool isBoolean; std::string helpInfo; bool isParsed; std::function(FlagParser *, const std::string &)> parse; }; inline void AddFlag(const FlagInfo &flag); // construct a temporary flag template void ConstructFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); // construct a temporary flag template void ConstructFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); Option InnerParseFlags(std::multimap> *values); bool GetRealFlagName(const std::string &oriFlagName, std::string *flagName); std::map flags; }; // convert to std::string template Option ConvertToString(T Flags::*t, const FlagParser &baseFlag) { const Flags *flag = dynamic_cast(&baseFlag); if (flag != nullptr) { return std::to_string(flag->*t); } return Option(None()); } // construct for a Option-type flag template void FlagParser::ConstructFlag(Option Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) { if (flag == nullptr) { MS_LOGE("FlagInfo is nullptr"); return; } flag->flagName = flagName; flag->helpInfo = helpInfo; flag->isBoolean = typeid(T) == typeid(bool); flag->isParsed = false; } // construct a temporary flag template void FlagParser::ConstructFlag(T Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) { if (flag == nullptr) { MS_LOGE("FlagInfo is nullptr"); return; } if (t1 == nullptr) { MS_LOGE("t1 is nullptr"); return; } flag->flagName = flagName; flag->helpInfo = helpInfo; flag->isBoolean = typeid(T) == typeid(bool); flag->isParsed = false; } inline void FlagParser::AddFlag(const FlagInfo &flagItem) { flags[flagItem.flagName] = flagItem; } template void FlagParser::AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo) { if (t == nullptr) { MS_LOGE("t1 is nullptr"); return; } Flags *flag = dynamic_cast(this); if (flag == nullptr) { MS_LOGI("dynamic_cast failed"); return; } FlagInfo flagItem; // flagItem is as a output parameter ConstructFlag(t, flagName, helpInfo, &flagItem); flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option { Flags *flag = dynamic_cast(base); if (base != nullptr) { Option ret = Option(GenericParseValue(value)); if (ret.IsNone()) { return Option(None()); } else { flag->*t = ret.Get(); } } return Option(Nothing()); }; flagItem.isRequired = true; flagItem.helpInfo += !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; flagItem.helpInfo += ")"; // add this flag to a std::map AddFlag(flagItem); } template void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { if (t1 == nullptr) { MS_LOGE("t1 is nullptr"); return; } FlagInfo flagItem; // flagItem is as a output parameter ConstructFlag(t1, flagName, helpInfo, flagItem); flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { if (base != nullptr) { Option ret = Option(GenericParseValue(value)); if (ret.IsNone()) { return Option(None()); } else { *t1 = ret.Get(); } } return Option(Nothing()); }; flagItem.isRequired = false; *t1 = t2; flagItem.helpInfo += !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; flagItem.helpInfo += ToString(t2).Get(); flagItem.helpInfo += ")"; // add this flag to a std::map AddFlag(flagItem); } template void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { if (t1 == nullptr) { MS_LOGE("t1 is nullptr"); return; } Flags *flag = dynamic_cast(this); if (flag == nullptr) { MS_LOGI("dynamic_cast failed"); return; } FlagInfo flagItem; // flagItem is as a output parameter ConstructFlag(t1, flagName, helpInfo, &flagItem); flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { Flags *flag = dynamic_cast(base); if (base != nullptr) { Option ret = Option(GenericParseValue(value)); if (ret.IsNone()) { return Option(None()); } else { flag->*t1 = ret.Get(); } } return Option(Nothing()); }; flagItem.isRequired = false; flag->*t1 = t2; flagItem.helpInfo += !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; flagItem.helpInfo += ToString(t2).Get(); flagItem.helpInfo += ")"; // add this flag to a std::map AddFlag(flagItem); } // option-type add flag template void FlagParser::AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo) { if (t == nullptr) { MS_LOGE("t is nullptr"); return; } Flags *flag = dynamic_cast(this); if (flag == nullptr) { MS_LOGE("dynamic_cast failed"); return; } FlagInfo flagItem; // flagItem is as a output parameter ConstructFlag(t, flagName, helpInfo, &flagItem); flagItem.isRequired = false; flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option { Flags *flag = dynamic_cast(base); if (base != nullptr) { Option ret = Option(GenericParseValue(value)); if (ret.IsNone()) { return Option(None()); } else { flag->*t = Option(Some(ret.Get())); } } return Option(Nothing()); }; // add this flag to a std::map AddFlag(flagItem); } } // namespace predict } // namespace mindspore #endif // PREDICT_COMMON_FLAG_PARSER_H_