optimize rdr mainly to reduce the redundant code

1. move common function to base recorder
2. add UpdataRdrEnable() for recorder manager, and update once
3. add delimiter virable for base recorder, set it as '.'
4. modify recorders to use the suffix parameter of GetFileRealPath()
5. improve the text in ms_context

Can not add RdrEnablePlatform because linux is supported only.
If os is linux, (!is_rdr_supported && rdr_enable) will be false always,
else, (is_rdr_supported && !rdr_enable) will be false always.
pull/12567/head
luopengting 4 years ago
parent 692d158f5c
commit 7914363d25

@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
return;
}
has_rdr_setting_ = true;
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
if (rdr_enable.has_value()) {
ParseRdrEnable(**rdr_enable);

@ -27,11 +27,14 @@ class EnvConfigParser {
public:
static EnvConfigParser &GetInstance() {
static EnvConfigParser instance;
instance.Parse();
return instance;
}
void Parse();
std::string config_path() const { return config_file_; }
bool has_rdr_setting() const { return has_rdr_setting_; }
bool rdr_enabled() const { return rdr_enabled_; }
std::string rdr_path() const { return rdr_path_; }
@ -42,7 +45,9 @@ class EnvConfigParser {
std::mutex lock_;
std::string config_file_{""};
bool already_parsed_{false};
bool rdr_enabled_{false};
bool has_rdr_setting_{false};
std::string rdr_path_{"./rdr/"};
std::string GetIfstreamString(const std::ifstream &ifstream);

@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
if (filename_.empty()) {
filename_ = module_ + "_" + tag_;
filename_ = module_ + delimiter_ + tag_;
if (!suffix.empty()) {
filename_ += "_" + suffix;
filename_ += delimiter_ + suffix;
}
filename_ += "_" + timestamp_;
filename_ += delimiter_ + timestamp_;
} else if (!suffix.empty()) {
filename_ += "_" + suffix;
filename_ += delimiter_ + suffix;
}
std::string file_path = directory_ + filename_;
auto realpath = Common::GetRealPath(file_path);

@ -65,6 +65,7 @@ class BaseRecorder {
void SetDirectory(const std::string &directory);
void SetFilename(const std::string &filename);
void SetModule(const std::string &module) { module_ = module; }
virtual void Export() {}
protected:
@ -73,6 +74,7 @@ class BaseRecorder {
std::string directory_;
std::string filename_;
std::string timestamp_; // year,month,day,hour,minute,second
std::string delimiter_{"."};
};
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
} // namespace mindspore

@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
} // namespace
void GraphExecOrderRecorder::Export() {
auto realpath = GetFileRealPath();
auto realpath = GetFileRealPath(std::to_string(graph_id_));
if (!realpath.has_value()) {
return;
}
std::string real_file_path = realpath.value() + std::to_string(graph_id_);
std::string real_file_path = realpath.value() + ".txt";
DumpGraphExeOrder(real_file_path, exec_order_);
}
} // namespace mindspore

@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
GraphExecOrderRecorder(const std::string &module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id)
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
void SetModule(const std::string &module) { module_ = module; }
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
virtual void Export();

@ -65,6 +65,7 @@ void GraphRecorder::Export() {
}
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
auto tmp_realpath = GetFileRealPath(suffix);
if (!tmp_realpath.has_value()) {
return;
}

@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
const std::string &file_type)
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
~GraphRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void SetDumpFlag(bool full_name) { full_name_ = full_name; }

@ -21,7 +21,35 @@
#include "mindspore/core/ir/func_graph.h"
namespace mindspore {
void RecorderManager::UpdateRdrEnable() {
static bool updated = false;
if (updated) {
return;
}
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
rdr_enable_ = config_parser.rdr_enabled();
if (config_parser.has_rdr_setting()) {
#ifdef __linux__
if (!rdr_enable_) {
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
<< "' if you want to use RDR.";
}
#else
if (rdr_enable_) {
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
}
rdr_enable_ = false;
#endif
}
updated = true;
}
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
if (!rdr_enable_) {
return false;
}
if (recorder == nullptr) {
MS_LOG(ERROR) << "register recorder module with nullptr.";
return false;
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
}
void RecorderManager::TriggerAll() {
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance();
config_parser_ptr.Parse();
if (!config_parser_ptr.rdr_enabled()) {
MS_LOG(INFO) << "RDR is not enable.";
if (!rdr_enable_) {
return;
}

@ -31,9 +31,12 @@ class RecorderManager {
public:
static RecorderManager &Instance() {
static RecorderManager manager;
manager.UpdateRdrEnable();
return manager;
}
void UpdateRdrEnable();
bool RdrEnable() const { return rdr_enable_; }
bool RecordObject(const BaseRecorderPtr &recorder);
void TriggerAll();
void ClearAll();
@ -42,6 +45,8 @@ class RecorderManager {
RecorderManager() {}
~RecorderManager() {}
bool rdr_enable_{false};
mutable std::mutex mtx_;
// module, BaserRecorderPtrList
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;

@ -65,6 +65,9 @@ namespace RDR {
#ifdef ENABLE_D
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
TaskDebugInfoRecorderPtr task_debug_info_recorder =
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
}
#endif // ENABLE_D
#ifdef __linux__
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
graph_recorder->SetDumpFlag(full_name);
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
GraphExecOrderRecorderPtr graph_exec_order_recorder =
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
string_recorder->SetFilename(filename);
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
StreamExecOrderRecorderPtr stream_exec_order_recorder =
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
#else
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
static bool already_printed = false;
std::string submodule_name = std::string(GetSubModuleName(module));
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
return false;
}
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
void TriggerAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
void ClearAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
#endif // __linux__
} // namespace RDR
} // namespace mindspore

@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
exec_order_.push_back(std::move(exec_node_ptr));
}
}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export();
private:

@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
StringRecorder() : BaseRecorder() {}
StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
const std::string &filename)
: BaseRecorder(module, tag), data_(data), filename_(filename) {}
: BaseRecorder(module, tag), data_(data) {
SetFilename(filename);
}
~StringRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export();
private:
std::string data_;
std::string filename_;
};
using StringRecorderPtr = std::shared_ptr<StringRecorder>;
} // namespace mindspore

@ -258,8 +258,8 @@ class _Context:
def set_env_config_path(self, env_config_path):
"""Check and set env_config_path."""
if not self._context_handle.enable_dump_ir():
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR "
"and recompile source to enable it.")
raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
"with '-D on' and recompile source.")
env_config_path = os.path.realpath(env_config_path)
if not os.path.isfile(env_config_path):
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)

Loading…
Cancel
Save