add dump level

pull/11277/head
chenfei 4 years ago
parent 1221eac84a
commit 874f350eb9

@ -19,8 +19,8 @@
#endif
#include <fstream>
#include <iomanip>
#include <map>
#include <memory>
#include <unordered_map>
#include "ir/primitive.h"
#include "ir/func_graph.h"
#include "runtime/device/kernel_info.h"
@ -543,8 +543,21 @@ std::string AddGlobalId(const std::string &filename) {
return s.str();
}
void GetEnvDumpIrLineLevel(LocDumpMode *dump_location) {
static std::unordered_map<std::string, enum LocDumpMode> dump_level_map = {
{std::to_string(kOff), kOff}, {std::to_string(kTopStack), kTopStack}, {std::to_string(kWholeStack), kWholeStack}};
static auto dump_level_in_env = common::GetEnv("ENV_DUMP_IR_LINE_LEVEL");
auto it = dump_level_map.find(dump_level_in_env);
if (it == dump_level_map.end()) {
return;
}
// Use the env setting instead parameter setting.
*dump_location = it->second;
}
#ifdef ENABLE_DUMP_IR
void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_full_name, LocDumpMode dump_location) {
GetEnvDumpIrLineLevel(&dump_location);
if (graph == nullptr) {
return;
}

@ -206,6 +206,16 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
loop = loop || change;
// record the status of each transform
static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
auto fg_name = optimizer->name() + "_" + std::to_string(optimizer->CurPass_.counter) + "_" +
optimizer->CurPass_.name + "_" + list_[i]->name_;
DumpIR(fg_name + ".ir", func_graph);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
func_graph->DumpFuncGraph(fg_name);
ExportIR(fg_name + ".dat", "", func_graph);
}
}
if (optimizer->is_on_debug_) {
status[list_[i]->name_ + std::to_string(i)].push_back(change);
space = std::max(list_[i]->name_.size(), space);

@ -181,10 +181,11 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}
};
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
if (is_on_debug_ && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1");
if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
auto fg_name =
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
DumpIR(fg_name + ".ir", func_graph);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
func_graph->DumpFuncGraph(fg_name);

@ -22,6 +22,7 @@
#include <string>
#include <memory>
#include <unordered_map>
#include <sstream>
#include <algorithm>
#include "pipeline/jit/parse/resolve.h"
#include "frontend/operator/ops.h"
@ -1215,6 +1216,21 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
return body_block;
}
int64_t GetForTransToWhileLoop() {
static const auto loop_str = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
// int64 support 63bits positive num mostly.
if (loop_str.size() > 63 || loop_str.empty()) {
return MAX_FOR_LOOP_COUNT;
}
if (std::any_of(loop_str.begin(), loop_str.end(), [](char c) { return c < '0' || c > '9'; })) {
return MAX_FOR_LOOP_COUNT;
}
int64_t loop_count;
std::stringstream ss;
ss << loop_str;
ss >> loop_count;
return loop_count;
}
// A for loop will generate 3 functions :the test, the body, and the continuation
// for x in xs:
// body
@ -1231,8 +1247,8 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
CNodePtr bool_node =
block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});
CNodePtr bool_node = block->func_graph()->NewCNode(
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
FunctionBlockPtr true_block = nullptr;

Loading…
Cancel
Save