!12567 optimize rdr mainly to reduce the redundant code

From: @luopengting
Reviewed-by: 
Signed-off-by:
pull/12567/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e5aedcca47

@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
return; return;
} }
has_rdr_setting_ = true;
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable); auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
if (rdr_enable.has_value()) { if (rdr_enable.has_value()) {
ParseRdrEnable(**rdr_enable); ParseRdrEnable(**rdr_enable);

@ -27,11 +27,14 @@ class EnvConfigParser {
public: public:
static EnvConfigParser &GetInstance() { static EnvConfigParser &GetInstance() {
static EnvConfigParser instance; static EnvConfigParser instance;
instance.Parse();
return instance; return instance;
} }
void Parse(); void Parse();
std::string config_path() const { return config_file_; }
bool has_rdr_setting() const { return has_rdr_setting_; }
bool rdr_enabled() const { return rdr_enabled_; } bool rdr_enabled() const { return rdr_enabled_; }
std::string rdr_path() const { return rdr_path_; } std::string rdr_path() const { return rdr_path_; }
@ -42,7 +45,9 @@ class EnvConfigParser {
std::mutex lock_; std::mutex lock_;
std::string config_file_{""}; std::string config_file_{""};
bool already_parsed_{false}; bool already_parsed_{false};
bool rdr_enabled_{false}; bool rdr_enabled_{false};
bool has_rdr_setting_{false};
std::string rdr_path_{"./rdr/"}; std::string rdr_path_{"./rdr/"};
std::string GetIfstreamString(const std::ifstream &ifstream); std::string GetIfstreamString(const std::ifstream &ifstream);

@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) { std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
if (filename_.empty()) { if (filename_.empty()) {
filename_ = module_ + "_" + tag_; filename_ = module_ + delimiter_ + tag_;
if (!suffix.empty()) { if (!suffix.empty()) {
filename_ += "_" + suffix; filename_ += delimiter_ + suffix;
} }
filename_ += "_" + timestamp_; filename_ += delimiter_ + timestamp_;
} else if (!suffix.empty()) { } else if (!suffix.empty()) {
filename_ += "_" + suffix; filename_ += delimiter_ + suffix;
} }
std::string file_path = directory_ + filename_; std::string file_path = directory_ + filename_;
auto realpath = Common::GetRealPath(file_path); auto realpath = Common::GetRealPath(file_path);

@ -65,6 +65,7 @@ class BaseRecorder {
void SetDirectory(const std::string &directory); void SetDirectory(const std::string &directory);
void SetFilename(const std::string &filename); void SetFilename(const std::string &filename);
void SetModule(const std::string &module) { module_ = module; }
virtual void Export() {} virtual void Export() {}
protected: protected:
@ -73,6 +74,7 @@ class BaseRecorder {
std::string directory_; std::string directory_;
std::string filename_; std::string filename_;
std::string timestamp_; // year,month,day,hour,minute,second std::string timestamp_; // year,month,day,hour,minute,second
std::string delimiter_{"."};
}; };
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>; using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
} // namespace mindspore } // namespace mindspore

@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
} // namespace } // namespace
void GraphExecOrderRecorder::Export() { void GraphExecOrderRecorder::Export() {
auto realpath = GetFileRealPath(); auto realpath = GetFileRealPath(std::to_string(graph_id_));
if (!realpath.has_value()) { if (!realpath.has_value()) {
return; return;
} }
std::string real_file_path = realpath.value() + std::to_string(graph_id_); std::string real_file_path = realpath.value() + ".txt";
DumpGraphExeOrder(real_file_path, exec_order_); DumpGraphExeOrder(real_file_path, exec_order_);
} }
} // namespace mindspore } // namespace mindspore

@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
GraphExecOrderRecorder(const std::string &module, const std::string &tag, GraphExecOrderRecorder(const std::string &module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) const std::vector<CNodePtr> &final_exec_order, int graph_id)
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {} : BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
void SetModule(const std::string &module) { module_ = module; }
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; } void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
virtual void Export(); virtual void Export();

@ -65,6 +65,7 @@ void GraphRecorder::Export() {
} }
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : ""; std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
auto tmp_realpath = GetFileRealPath(suffix); auto tmp_realpath = GetFileRealPath(suffix);
if (!tmp_realpath.has_value()) { if (!tmp_realpath.has_value()) {
return; return;
} }

@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
const std::string &file_type) const std::string &file_type)
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {} : BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
~GraphRecorder() {} ~GraphRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; } void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void SetDumpFlag(bool full_name) { full_name_ = full_name; } void SetDumpFlag(bool full_name) { full_name_ = full_name; }

@ -21,7 +21,35 @@
#include "mindspore/core/ir/func_graph.h" #include "mindspore/core/ir/func_graph.h"
namespace mindspore { namespace mindspore {
void RecorderManager::UpdateRdrEnable() {
static bool updated = false;
if (updated) {
return;
}
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
rdr_enable_ = config_parser.rdr_enabled();
if (config_parser.has_rdr_setting()) {
#ifdef __linux__
if (!rdr_enable_) {
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
<< "' if you want to use RDR.";
}
#else
if (rdr_enable_) {
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
}
rdr_enable_ = false;
#endif
}
updated = true;
}
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) { bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
if (!rdr_enable_) {
return false;
}
if (recorder == nullptr) { if (recorder == nullptr) {
MS_LOG(ERROR) << "register recorder module with nullptr."; MS_LOG(ERROR) << "register recorder module with nullptr.";
return false; return false;
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
} }
void RecorderManager::TriggerAll() { void RecorderManager::TriggerAll() {
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance(); if (!rdr_enable_) {
config_parser_ptr.Parse();
if (!config_parser_ptr.rdr_enabled()) {
MS_LOG(INFO) << "RDR is not enable.";
return; return;
} }

@ -31,9 +31,12 @@ class RecorderManager {
public: public:
static RecorderManager &Instance() { static RecorderManager &Instance() {
static RecorderManager manager; static RecorderManager manager;
manager.UpdateRdrEnable();
return manager; return manager;
} }
void UpdateRdrEnable();
bool RdrEnable() const { return rdr_enable_; }
bool RecordObject(const BaseRecorderPtr &recorder); bool RecordObject(const BaseRecorderPtr &recorder);
void TriggerAll(); void TriggerAll();
void ClearAll(); void ClearAll();
@ -42,6 +45,8 @@ class RecorderManager {
RecorderManager() {} RecorderManager() {}
~RecorderManager() {} ~RecorderManager() {}
bool rdr_enable_{false};
mutable std::mutex mtx_; mutable std::mutex mtx_;
// module, BaserRecorderPtrList // module, BaserRecorderPtrList
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_; std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;

@ -65,6 +65,9 @@ namespace RDR {
#ifdef ENABLE_D #ifdef ENABLE_D
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag, bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) { const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
TaskDebugInfoRecorderPtr task_debug_info_recorder = TaskDebugInfoRecorderPtr task_debug_info_recorder =
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id); std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
} }
#endif // ENABLE_D #endif // ENABLE_D
#ifdef __linux__
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name, bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) { const std::string &file_type) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type); GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
graph_recorder->SetDumpFlag(full_name); graph_recorder->SetDumpFlag(full_name);
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag, bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) { const std::vector<CNodePtr> &final_exec_order, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
GraphExecOrderRecorderPtr graph_exec_order_recorder = GraphExecOrderRecorderPtr graph_exec_order_recorder =
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id); std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
} }
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) { bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename); StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
string_recorder->SetFilename(filename); string_recorder->SetFilename(filename);
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id, bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) { const std::vector<CNodePtr> &exec_order) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module)); std::string submodule_name = std::string(GetSubModuleName(module));
StreamExecOrderRecorderPtr stream_exec_order_recorder = StreamExecOrderRecorderPtr stream_exec_order_recorder =
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order); std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); } void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
#else
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
static bool already_printed = false;
std::string submodule_name = std::string(GetSubModuleName(module));
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
return false;
}
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
void TriggerAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
void ClearAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
#endif // __linux__
} // namespace RDR } // namespace RDR
} // namespace mindspore } // namespace mindspore

@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
exec_order_.push_back(std::move(exec_node_ptr)); exec_order_.push_back(std::move(exec_node_ptr));
} }
} }
void SetModule(const std::string &module) { module_ = module; }
virtual void Export(); virtual void Export();
private: private:

@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
StringRecorder() : BaseRecorder() {} StringRecorder() : BaseRecorder() {}
StringRecorder(const std::string &module, const std::string &tag, const std::string &data, StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
const std::string &filename) const std::string &filename)
: BaseRecorder(module, tag), data_(data), filename_(filename) {} : BaseRecorder(module, tag), data_(data) {
SetFilename(filename);
}
~StringRecorder() {} ~StringRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export(); virtual void Export();
private: private:
std::string data_; std::string data_;
std::string filename_;
}; };
using StringRecorderPtr = std::shared_ptr<StringRecorder>; using StringRecorderPtr = std::shared_ptr<StringRecorder>;
} // namespace mindspore } // namespace mindspore

@ -258,8 +258,8 @@ class _Context:
def set_env_config_path(self, env_config_path): def set_env_config_path(self, env_config_path):
"""Check and set env_config_path.""" """Check and set env_config_path."""
if not self._context_handle.enable_dump_ir(): if not self._context_handle.enable_dump_ir():
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR " raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
"and recompile source to enable it.") "with '-D on' and recompile source.")
env_config_path = os.path.realpath(env_config_path) env_config_path = os.path.realpath(env_config_path)
if not os.path.isfile(env_config_path): if not os.path.isfile(env_config_path):
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path) raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)

Loading…
Cancel
Save