optimize rdr mainly to reduce the redundant code

1. move common function to base recorder
2. add UpdataRdrEnable() for recorder manager, and update once
3. add delimiter virable for base recorder, set it as '.'
4. modify recorders to use the suffix parameter of GetFileRealPath()
5. improve the text in ms_context

Can not add RdrEnablePlatform because linux is supported only.
If os is linux, (!is_rdr_supported && rdr_enable) will be false always,
else, (is_rdr_supported && !rdr_enable) will be false always.
pull/12567/head
luopengting 4 years ago
parent 692d158f5c
commit 7914363d25

@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
return; return;
} }
has_rdr_setting_ = true;
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable); auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
if (rdr_enable.has_value()) { if (rdr_enable.has_value()) {
ParseRdrEnable(**rdr_enable); ParseRdrEnable(**rdr_enable);

@ -27,11 +27,14 @@ class EnvConfigParser {
public: public:
static EnvConfigParser &GetInstance() { static EnvConfigParser &GetInstance() {
static EnvConfigParser instance; static EnvConfigParser instance;
instance.Parse();
return instance; return instance;
} }
void Parse(); void Parse();
std::string config_path() const { return config_file_; }
bool has_rdr_setting() const { return has_rdr_setting_; }
bool rdr_enabled() const { return rdr_enabled_; } bool rdr_enabled() const { return rdr_enabled_; }
std::string rdr_path() const { return rdr_path_; } std::string rdr_path() const { return rdr_path_; }
@ -42,7 +45,9 @@ class EnvConfigParser {
std::mutex lock_; std::mutex lock_;
std::string config_file_{""}; std::string config_file_{""};
bool already_parsed_{false}; bool already_parsed_{false};
bool rdr_enabled_{false}; bool rdr_enabled_{false};
bool has_rdr_setting_{false};
std::string rdr_path_{"./rdr/"}; std::string rdr_path_{"./rdr/"};
std::string GetIfstreamString(const std::ifstream &ifstream); std::string GetIfstreamString(const std::ifstream &ifstream);

@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) { std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
if (filename_.empty()) { if (filename_.empty()) {
filename_ = module_ + "_" + tag_; filename_ = module_ + delimiter_ + tag_;
if (!suffix.empty()) { if (!suffix.empty()) {
filename_ += "_" + suffix; filename_ += delimiter_ + suffix;
} }
filename_ += "_" + timestamp_; filename_ += delimiter_ + timestamp_;
} else if (!suffix.empty()) { } else if (!suffix.empty()) {
filename_ += "_" + suffix; filename_ += delimiter_ + suffix;
} }
std::string file_path = directory_ + filename_; std::string file_path = directory_ + filename_;
auto realpath = Common::GetRealPath(file_path); auto realpath = Common::GetRealPath(file_path);

@ -65,6 +65,7 @@ class BaseRecorder {
void SetDirectory(const std::string &directory); void SetDirectory(const std::string &directory);
void SetFilename(const std::string &filename); void SetFilename(const std::string &filename);
void SetModule(const std::string &module) { module_ = module; }
virtual void Export() {} virtual void Export() {}
protected: protected:
@ -73,6 +74,7 @@ class BaseRecorder {
std::string directory_; std::string directory_;
std::string filename_; std::string filename_;
std::string timestamp_; // year,month,day,hour,minute,second std::string timestamp_; // year,month,day,hour,minute,second
std::string delimiter_{"."};
}; };
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>; using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
} // namespace mindspore } // namespace mindspore

@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
} // namespace } // namespace
void GraphExecOrderRecorder::Export() { void GraphExecOrderRecorder::Export() {
auto realpath = GetFileRealPath(); auto realpath = GetFileRealPath(std::to_string(graph_id_));
if (!realpath.has_value()) { if (!realpath.has_value()) {
return; return;
} }
std::string real_file_path = realpath.value() + std::to_string(graph_id_); std::string real_file_path = realpath.value() + ".txt";
DumpGraphExeOrder(real_file_path, exec_order_); DumpGraphExeOrder(real_file_path, exec_order_);
} }
} // namespace mindspore } // namespace mindspore

@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
GraphExecOrderRecorder(const std::string &module, const std::string &tag, GraphExecOrderRecorder(const std::string &module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) const std::vector<CNodePtr> &final_exec_order, int graph_id)
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {} : BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
void SetModule(const std::string &module) { module_ = module; }
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; } void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
virtual void Export(); virtual void Export();

@ -65,6 +65,7 @@ void GraphRecorder::Export() {
} }
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : ""; std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
auto tmp_realpath = GetFileRealPath(suffix); auto tmp_realpath = GetFileRealPath(suffix);
if (!tmp_realpath.has_value()) { if (!tmp_realpath.has_value()) {
return; return;
} }

@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
const std::string &file_type) const std::string &file_type)
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {} : BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
~GraphRecorder() {} ~GraphRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; } void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void SetDumpFlag(bool full_name) { full_name_ = full_name; } void SetDumpFlag(bool full_name) { full_name_ = full_name; }

@ -21,7 +21,35 @@
#include "mindspore/core/ir/func_graph.h" #include "mindspore/core/ir/func_graph.h"
namespace mindspore { namespace mindspore {
void RecorderManager::UpdateRdrEnable() {
static bool updated = false;
if (updated) {
return;
}
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
rdr_enable_ = config_parser.rdr_enabled();
if (config_parser.has_rdr_setting()) {
#ifdef __linux__
if (!rdr_enable_) {
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
<< "' if you want to use RDR.";
}
#else
if (rdr_enable_) {
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
}
rdr_enable_ = false;
#endif
}
updated = true;
}
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) { bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
if (!rdr_enable_) {
return false;
}
if (recorder == nullptr) { if (recorder == nullptr) {
MS_LOG(ERROR) << "register recorder module with nullptr."; MS_LOG(ERROR) << "register recorder module with nullptr.";
return false; return false;
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
} }
void RecorderManager::TriggerAll() { void RecorderManager::TriggerAll() {
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance(); if (!rdr_enable_) {
config_parser_ptr.Parse();
if (!config_parser_ptr.rdr_enabled()) {
MS_LOG(INFO) << "RDR is not enable.";
return; return;
} }

@ -31,9 +31,12 @@ class RecorderManager {
public: public:
static RecorderManager &Instance() { static RecorderManager &Instance() {
static RecorderManager manager; static RecorderManager manager;
manager.UpdateRdrEnable();
return manager; return manager;
} }
void UpdateRdrEnable();
bool RdrEnable() const { return rdr_enable_; }
bool RecordObject(const BaseRecorderPtr &recorder); bool RecordObject(const BaseRecorderPtr &recorder);
void TriggerAll(); void TriggerAll();
void ClearAll(); void ClearAll();
@ -42,6 +45,8 @@ class RecorderManager {
RecorderManager() {} RecorderManager() {}
~RecorderManager() {} ~RecorderManager() {}
bool rdr_enable_{false};
mutable std::mutex mtx_; mutable std::mutex mtx_;
// module, BaserRecorderPtrList // module, BaserRecorderPtrList
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_; std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;

@ -65,6 +65,9 @@ namespace RDR {
#ifdef ENABLE_D #ifdef ENABLE_D
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag, bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) { const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
TaskDebugInfoRecorderPtr task_debug_info_recorder = TaskDebugInfoRecorderPtr task_debug_info_recorder =
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id); std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
} }
#endif // ENABLE_D #endif // ENABLE_D
#ifdef __linux__
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name, bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) { const std::string &file_type) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type); GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
graph_recorder->SetDumpFlag(full_name); graph_recorder->SetDumpFlag(full_name);
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag, bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) { const std::vector<CNodePtr> &final_exec_order, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
GraphExecOrderRecorderPtr graph_exec_order_recorder = GraphExecOrderRecorderPtr graph_exec_order_recorder =
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id); std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
} }
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) { bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename); StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
string_recorder->SetFilename(filename); string_recorder->SetFilename(filename);
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id, bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) { const std::vector<CNodePtr> &exec_order) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
StreamExecOrderRecorderPtr stream_exec_order_recorder = StreamExecOrderRecorderPtr stream_exec_order_recorder =
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order); std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); } void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
#else
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
static bool already_printed = false;
std::string submodule_name = std::string(GetSubModuleName(module));
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
return false;
}
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
void TriggerAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
void ClearAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
#endif // __linux__
} // namespace RDR } // namespace RDR
} // namespace mindspore } // namespace mindspore

@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
exec_order_.push_back(std::move(exec_node_ptr)); exec_order_.push_back(std::move(exec_node_ptr));
} }
} }
void SetModule(const std::string &module) { module_ = module; }
virtual void Export(); virtual void Export();
private: private:

@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
StringRecorder() : BaseRecorder() {} StringRecorder() : BaseRecorder() {}
StringRecorder(const std::string &module, const std::string &tag, const std::string &data, StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
const std::string &filename) const std::string &filename)
: BaseRecorder(module, tag), data_(data), filename_(filename) {} : BaseRecorder(module, tag), data_(data) {
SetFilename(filename);
}
~StringRecorder() {} ~StringRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export(); virtual void Export();
private: private:
std::string data_; std::string data_;
std::string filename_;
}; };
using StringRecorderPtr = std::shared_ptr<StringRecorder>; using StringRecorderPtr = std::shared_ptr<StringRecorder>;
} // namespace mindspore } // namespace mindspore

@ -258,8 +258,8 @@ class _Context:
def set_env_config_path(self, env_config_path): def set_env_config_path(self, env_config_path):
"""Check and set env_config_path.""" """Check and set env_config_path."""
if not self._context_handle.enable_dump_ir(): if not self._context_handle.enable_dump_ir():
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR " raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
"and recompile source to enable it.") "with '-D on' and recompile source.")
env_config_path = os.path.realpath(env_config_path) env_config_path = os.path.realpath(env_config_path)
if not os.path.isfile(env_config_path): if not os.path.isfile(env_config_path):
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path) raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)

Loading…
Cancel
Save