add context graph_kernel_flags

used the flag "opt_level" to control GraphKernel,
0 means disabled while non-zero value means enabled.
the default value is controlled by context "enable_graph_kernel",
but if it's also set in "graph_kernel_flags", then the flag will prevail.

supported the whitelist and blacklist operators for GraphKernelExpander.
"enable_expand_ops", "enable_expand_ops_only", "disable_expand_ops".
pull/13720/head
dayschan 4 years ago
parent b1c86b6a22
commit 11ee3b1624

@@ -21,6 +21,7 @@
#include <utility>
#include <vector>
#include "utils/context/graph_kernel_flags.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
@@ -67,6 +68,29 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimAssignAdd,
#endif
};
auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
auto &flags = context::GraphKernelFlags::GetInstance();
auto &enable_ops_only = flags.enable_expand_ops_only;
if (!enable_ops_only.empty()) {
expand_ops.clear();
std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
new_prim);
} else {
auto &enable_ops = flags.enable_expand_ops;
auto &disable_ops = flags.disable_expand_ops;
if (!enable_ops.empty()) {
std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
}
if (!disable_ops.empty()) {
for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
expand_ops.erase(iter++);
} else {
++iter;
}
}
}
}
return expand_ops;
}
} // namespace
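
For reference, a minimal sketch of the list-combination rule implemented by GetExpandOps() above, using plain strings instead of PrimitivePtr (op names in the example below are placeholders):

#include <string>
#include <unordered_set>
#include <vector>

// "only" replaces the default list and makes the other two lists irrelevant;
// otherwise "enable" extends the defaults and "disable" prunes the result.
std::unordered_set<std::string> CombineExpandOps(std::unordered_set<std::string> defaults,
                                                 const std::vector<std::string> &only,
                                                 const std::vector<std::string> &enable,
                                                 const std::vector<std::string> &disable) {
  if (!only.empty()) {
    return std::unordered_set<std::string>(only.begin(), only.end());
  }
  defaults.insert(enable.begin(), enable.end());
  for (const auto &name : disable) {
    defaults.erase(name);
  }
  return defaults;
}

For example, with defaults {Square, GeLU}, "--enable_expand_ops=TensorAdd --disable_expand_ops=GeLU" yields {Square, TensorAdd}, while "--enable_expand_ops_only=GeLU" yields {GeLU} alone.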

@@ -17,6 +17,7 @@
#include "backend/optimizer/mem_reuse/mem_reuse.h"
#include <algorithm>
#include <memory>
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
#include "backend/optimizer/common/helper.h"
@@ -462,9 +463,7 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
#endif
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
enable_visit_kernel_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
enable_visit_kernel_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
}
uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {

@@ -46,6 +46,7 @@
#include "runtime/device/ascend/ascend_stream_assign.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/config_manager.h"
@@ -846,9 +847,7 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_
}
void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
return;
}
opt::GraphKernelOptimize(kernel_graph);

@@ -69,6 +69,7 @@
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"
#include "utils/context/graph_kernel_flags.h"
#include "utils/utils.h"
#if ENABLE_CPU && ENABLE_GPU
#include "ps/util.h"
@@ -127,8 +128,6 @@ void GPUSession::StartKernelRT() const {
void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
@@ -136,7 +135,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
@@ -181,9 +180,7 @@ void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kerne
}
void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
return;
}
opt::GraphKernelOptimize(kernel_graph);

@@ -40,6 +40,7 @@
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "utils/context/graph_kernel_flags.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
@@ -354,9 +355,7 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
}

@@ -97,6 +97,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("tune_mode", MsCtxParam::MS_CTX_TUNE_MODE)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
.value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")

@@ -24,6 +24,7 @@
#include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#include "utils/context/graph_kernel_flags.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/gpu_common.h"
#include "utils/ms_utils.h"
@@ -66,9 +67,7 @@ bool GPUKernelRuntime::SyncStream() {
}
bool GPUKernelRuntime::Init() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
enable_relation_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
if (device_init_ == true) {
GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();

@@ -0,0 +1,196 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/context/graph_kernel_flags.h"
#include <map>
#include <string>
#include <cstring>
#include <vector>
#include <utility>
#include "nlohmann/json.hpp"
#include "utils/ms_context.h"
namespace mindspore {
namespace context {
namespace {
// Split string to tokens
std::vector<std::string> GetTokens(const std::string &str, const std::string &delim) {
std::vector<std::string> tokens;
std::vector<char> c_str(str.begin(), str.end());
c_str.push_back('\0');
char *saveptr;
char *pch = strtok_r(&c_str[0], delim.c_str(), &saveptr);
while (pch != NULL) {
tokens.emplace_back(pch);
pch = strtok_r(NULL, delim.c_str(), &saveptr);
}
return tokens;
}
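
A quick same-translation-unit check of GetTokens (the assertions are illustrative; strtok_r skips empty fields):

#include <cassert>

void GetTokensExample() {
  assert(GetTokens("--a=1  --b", " ") == (std::vector<std::string>{"--a=1", "--b"}));  // repeated delimiters collapse
  assert(GetTokens("x,,y", ",") == (std::vector<std::string>{"x", "y"}));              // no empty tokens
  assert(GetTokens("", " ").empty());                                                  // empty input -> no tokens
}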
// Parse flag string to key-value pair.
// Flag format: "--key=value", bool flag's value can be implicit, the "--key" means "--key=true"
std::pair<std::string, std::string> ParseFlag(const std::string &flag) {
auto i = flag.find("--");
// check the string starts with "--".
if (i != 0 || flag.size() == 2) {
return std::pair<std::string, std::string>();
}
i += 2;
auto j = flag.find('=', i + 1); // the key should not be empty, "--=" is invalid
if (j == std::string::npos) {
// no value, treated as bool flag.
return std::make_pair(flag.substr(i), "");
} else if (j + 1 != flag.size() && flag.find('=', j + 1) == std::string::npos) {
// normal "--key=value" format
return std::make_pair(flag.substr(i, j - i), flag.substr(j + 1));
}
// string with two "=" is invalid.
return std::pair<std::string, std::string>();
}
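
Representative inputs and the expected ParseFlag results (a sketch that assumes the same translation unit):

#include <cassert>

void ParseFlagExample() {
  assert(ParseFlag("--opt_level=2") == std::make_pair(std::string("opt_level"), std::string("2")));
  assert(ParseFlag("--dump_as_text") == std::make_pair(std::string("dump_as_text"), std::string("")));  // implicit bool
  assert(ParseFlag("opt_level=2").first.empty());  // missing "--": invalid
  assert(ParseFlag("--a=b=c").first.empty());      // two '=': invalid
}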
std::map<std::string, std::string> ParseFlags(const std::string &flags) {
std::map<std::string, std::string> flag_map;
auto tokens = GetTokens(flags, " ");
for (const auto &token : tokens) {
auto flag = ParseFlag(token);
if (flag.first != "") {
if (!flag_map.insert(flag).second) {
MS_LOG(WARNING) << "Repeated GraphKernel flag: " << flag.first;
}
} else {
MS_LOG(WARNING) << "Invalid GraphKernel flag: " << token;
}
}
return flag_map;
}
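
And end to end, a whole flag string parsed by ParseFlags (values are examples):

void ParseFlagsExample() {
  auto flag_map = ParseFlags("--opt_level=2 --dump_as_text --enable_expand_ops=Square,GeLU bad_token");
  // flag_map == { {"dump_as_text", ""}, {"enable_expand_ops", "Square,GeLU"}, {"opt_level", "2"} }
  // "bad_token" lacks the "--" prefix and is reported as an invalid GraphKernel flag.
}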
class FlagRegister {
public:
explicit FlagRegister(std::map<std::string, std::string> *flag_map) : flag_map_(*flag_map) {}
~FlagRegister() = default;
template <typename T>
void AddFlag(std::string flag_name, T *flag_var) {
auto iter = flag_map_.find(flag_name);
if (iter != flag_map_.end()) {
T var;
bool ret = ParseValue(iter->second, &var);
if (ret) {
*flag_var = std::move(var);
} else {
if (iter->second.empty()) {
MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first;
} else {
MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first << "=" << iter->second;
}
}
flag_map_.erase(iter);
}
}
private:
bool ParseValue(const std::string &s, std::vector<std::string> *result) {
*result = GetTokens(s, ",");
return !result->empty();
}
bool ParseValue(const std::string &s, bool *result) {
*result = (s.empty() || s == "true" || s == "on" || s == "1");
return *result || s == "false" || s == "off" || s == "0";
}
template <typename T>
bool ParseValue(const std::string &s, T *result) {
if (s.empty()) {
return false;
}
std::istringstream iss(s);
iss >> (*result);
return iss.eof();
}
template <typename T>
bool ParseValue(const std::string &s, std::vector<T> *result) {
result->clear();
auto tokens = GetTokens(s, ",");
if (tokens.empty()) {
return false;
}
for (const auto &tok : tokens) {
T temp;
if (!ParseValue(tok, &temp)) {
result->clear();
return false;
}
result->emplace_back(temp);
}
return true;
}
std::map<std::string, std::string> &flag_map_;
};
} // namespace
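
A sketch of how FlagRegister is used by RegisterFlags() below; the local variables stand in for the GraphKernelFlags members, and the flag values in the comments are examples:

void FlagRegisterExample(std::map<std::string, std::string> *flag_map) {
  unsigned int opt_level = 0;
  bool dump_as_text = false;
  std::vector<std::string> enable_expand_ops;
  FlagRegister reg(flag_map);
  // Each AddFlag call parses one entry into the typed variable and erases it from
  // the map; whatever is left afterwards is reported as unknown by Refresh().
  reg.AddFlag("opt_level", &opt_level);                  // "--opt_level=2"                   -> 2
  reg.AddFlag("dump_as_text", &dump_as_text);            // "--dump_as_text"                  -> true
  reg.AddFlag("enable_expand_ops", &enable_expand_ops);  // "--enable_expand_ops=Square,GeLU" -> {"Square", "GeLU"}
}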
void GraphKernelFlags::Refresh() {
auto flag_map = ParseFlags(flags_cache_);
RegisterFlags(&flag_map);
for (auto &item : flag_map) {
MS_LOG(WARNING) << "Unknown GraphKernel flag: " << item.first;
}
}
void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
FlagRegister reg(flag_map);
reg.AddFlag("dump_as_text", &dump_as_text);
reg.AddFlag("opt_level", &opt_level);
reg.AddFlag("auto_tune", &auto_tune);
reg.AddFlag("cluster_limit", &cluster_limit);
reg.AddFlag("enable_expand_ops", &enable_expand_ops);
reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
reg.AddFlag("disable_expand_ops", &disable_expand_ops);
reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
reg.AddFlag("enable_pass_only", &enable_pass_only);
reg.AddFlag("disable_pass", &disable_pass);
}
std::string GraphKernelFlags::DumpAllFlags() const {
nlohmann::json json;
json["dump_as_text"] = dump_as_text;
json["opt_level"] = opt_level;
json["auto_tune"] = auto_tune;
json["cluster_limit"] = cluster_limit;
json["enable_expand_ops"] = enable_expand_ops;
json["enable_expand_ops_only"] = enable_expand_ops_only;
json["disable_expand_ops"] = disable_expand_ops;
json["enable_cluster_ops"] = enable_cluster_ops;
json["enable_cluster_ops_only"] = enable_cluster_ops_only;
json["disable_cluster_ops"] = disable_cluster_ops;
json["enable_pass_only"] = enable_pass_only;
json["disable_pass"] = disable_pass;
return json.dump();
}
} // namespace context
} // namespace mindspore
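
A small usage sketch for DumpAllFlags(), e.g. to log the effective configuration when debugging (the printed JSON is illustrative):

#include "utils/context/graph_kernel_flags.h"
#include "utils/log_adapter.h"

void LogGraphKernelFlags() {
  const auto &flags = mindspore::context::GraphKernelFlags::GetInstance();
  MS_LOG(INFO) << "GraphKernel flags: " << flags.DumpAllFlags();
  // e.g. {"auto_tune":0,"cluster_limit":30,"disable_expand_ops":[],...,"opt_level":1,...}
}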

@@ -0,0 +1,148 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
#define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "utils/ms_context.h"
namespace mindspore {
namespace context {
class GraphKernelFlags {
public:
static const GraphKernelFlags &GetInstance() {
static std::unique_ptr<GraphKernelFlags> flags(nullptr);
auto contexts = GetGraphKernelContext();
if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_cache_) {
flags.reset(new GraphKernelFlags(contexts.first, contexts.second));
flags->Refresh();
}
return *flags;
}
// Dump all flags to json-format string
std::string DumpAllFlags() const;
// Check whether graph_kernel is enabled
bool IsEnableGraphKernel() const { return opt_level > 0; }
GraphKernelFlags(const GraphKernelFlags &flags) = delete;
~GraphKernelFlags() = default;
public:
/**
* dump_as_text, unsupported now.
*/
bool dump_as_text{false};
/**
* opt_level, value from 0 to 3.
* 0: GraphKernel disabled
* 1: GraphKernel enabled
* 2 and 3 are not supported now.
* the default value is controlled by context `enable_graph_kernel`,
* but if it's also set in `graph_kernel_flags`, then the flag will prevail.
*/
unsigned int opt_level{0};
/**
* auto_tune, unsupported now.
*/
unsigned int auto_tune{0};
/**
* cluster_limit, unsupported now.
*/
unsigned int cluster_limit{30};
/**
* Additional expanding operators (case sensitive).
* The operators to be added into the default expanding operator list.
*/
std::vector<std::string> enable_expand_ops;
/**
* Expanding operators to be enabled (case sensitive).
* Unlike the "enable_expand_ops", the default list will be overwritten by this list.
* Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set.
*/
std::vector<std::string> enable_expand_ops_only;
/**
* Expanding operators to be disabled (case sensitive).
* The behavior is undefined when this list overlaps with "enable_expand_ops".
*/
std::vector<std::string> disable_expand_ops;
/**
* enable_cluster_ops, unsupported now.
*/
std::vector<std::string> enable_cluster_ops;
/**
* enable_cluster_ops_only, unsupported now.
*/
std::vector<std::string> enable_cluster_ops_only;
/**
* disable_cluster_ops, unsupported now.
*/
std::vector<std::string> disable_cluster_ops;
/**
* enable_pass_only, unsupported now.
*/
std::vector<std::string> enable_pass_only;
/**
* disable_pass, unsupported now.
*/
std::vector<std::string> disable_pass;
private:
GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel)
: flags_cache_(graph_kernel_flags), enable_cache_(enable_graph_kernel) {
opt_level = enable_graph_kernel ? 1 : 0;
}
// get the `graph_kernel_flags` and `enable_graph_kernel`
static std::pair<std::string, bool> GetGraphKernelContext() {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
// Use the environment variable in priority
auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS");
std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL));
}
// parse and refresh the flags
void Refresh();
// register the flags defined above
void RegisterFlags(std::map<std::string, std::string> *flag_map);
// cache the flag string to check whether the flags is changed.
std::string flags_cache_;
// cache the enable_graph_kernel value to check whether the context is changed.
bool enable_cache_;
};
} // namespace context
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
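
Finally, a hypothetical consumer-side guard mirroring the AscendSession/GPUSession changes earlier in this commit. GetInstance() re-parses whenever "graph_kernel_flags" or "enable_graph_kernel" changes in MsContext, and the MS_GRAPH_KERNEL_FLAGS environment variable takes priority over the context string:

#include "utils/context/graph_kernel_flags.h"

void RunGraphKernelIfEnabled(/* const std::shared_ptr<session::KernelGraph> &graph */) {
  if (!mindspore::context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
    return;  // opt_level == 0: GraphKernel disabled
  }
  // opt::GraphKernelOptimize(graph);  // as in the session code above
}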

@@ -489,6 +489,7 @@ def _check_target_specific_cfgs(device, arg_key):
'enable_dump': ['Ascend'],
'save_dump_path': ['Ascend'],
'enable_graph_kernel': ['Ascend', 'GPU'],
'graph_kernel_flags': ['Ascend', 'GPU'],
'enable_reduce_precision': ['Ascend'],
'enable_profiling': ['Ascend'],
'profiling_options': ['Ascend'],
@@ -513,7 +514,7 @@ def _check_target_specific_cfgs(device, arg_key):
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
enable_sparse=bool, max_call_depth=int, env_config_path=str)
enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -540,14 +541,14 @@
=========================== =========================== =================
check_bprop print_file_path max_device_memory
device_id enable_dump enable_graph_kernel
device_target save_dump_path
device_target save_dump_path graph_kernel_flags
enable_sparse enable_graph_kernel
max_call_depth enable_reduce_precision
mode enable_profiling
reserve_class_name_in_scope profiling_options
save_graphs variable_memory_max_size
save_graphs_path auto_tune_mode
env_config_path
env_config_path graph_kernel_flags
grad_for_scalar
=========================== =========================== =================
@@ -566,6 +567,7 @@
`context.set_context(save_graphs_path="path/to/ir/files"+device_id)`.
enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
compiled into a fused kernel automatically. Default: False.
graph_kernel_flags (str): Set graph_kernel flags.
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
enable_dump (bool): Whether to enable dump. Default: False.

@@ -39,6 +39,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
set_param<uint32_t>(MS_CTX_TSD_REF, 0);
set_param<uint32_t>(MS_CTX_GE_REF, 0);

@@ -112,6 +112,7 @@ enum MsCtxParam : unsigned {
MS_CTX_PYTHON_EXE_PATH,
MS_CTX_ENV_CONFIG_PATH,
MS_CTX_TUNE_MODE,
MS_CTX_GRAPH_KERNEL_FLAGS,
MS_CTX_TYPE_STRING_END,
// parameter numbers of each type
