Refactor ms_context implementation

pull/5352/head
fary86 5 years ago
parent 0c5f7377bd
commit fcbb3e0edc

@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || hccl_data_type_list_.empty()) {

@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || hccl_data_type_list_.empty()) {

@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {

@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {

@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL) &&
IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
kernel_type = KernelType::AKG_KERNEL;
}

@ -328,7 +328,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice);
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
<< ", current op num: " << op_info_.size();

@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
if (context_ptr->execution_mode() == kPynativeMode) {
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else {
@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get());
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
ConfigManager::GetInstance().iter_num() > 1) {
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->ir_fusion_flag()) {
if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
MS_LOG(INFO) << "IRFusion is not enable, skip";
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other2_pm = std::make_shared<PassManager>("other2_pm");
other2_pm->AddPass(std::make_shared<GetitemTuple>());
other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
ConfigManager::GetInstance().iter_num() > 1) {
other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
}
other2_pm->AddPass(std::make_shared<CheckConsistency>());
@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &ke
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->ir_fusion_flag()) {
if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
MS_LOG(INFO) << "UBFusion is not enable, skip";
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}

@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}

@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto cnode = node->cast<CNodePtr>();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return RectifyKernelInfoInPynativeProcess(node);
}
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {

@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}

@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) {
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return false;
}
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,

@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}

@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->device_target() == kAscendDevice) {
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
if (!CheckAttrs(strided_slice_grad)) {
MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed";
return nullptr;

@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto save_graphs_path = context_ptr->save_graphs_path();
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
if (context_ptr->save_graphs_flag()) {
if (context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir";
DumpIR(file_path, root_graph.get());
}

@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
debugger_->PreExecute(graph);
}
#endif
if (ms_context->precompile_only()) {
if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "Precompile only, stop in build kernel step";
} else {
// alloc memory, including static memory and dynamic memory
@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
child_graph->SetExecOrderByDefault();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kGraphMode) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There has " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!";
@ -481,8 +481,8 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs)
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (!save_graphs) {
return;
}
auto save_graphs_path = context_ptr->save_graphs_path();
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() {
if (graph_order.size() > 1) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->enable_task_sink()) {
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
}
}
@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs) {
if (save_graphs_path.empty()) {
save_graphs_path = ".";
@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kGraphMode) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There are " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!";
@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs) {
if (save_graphs_path.empty()) {
save_graphs_path = ".";

@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) {
if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true;
}
@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
// Prepare ms context info for dump .pb graph
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag();
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Optimize
Optimize(graph);
// Select kernel build info
@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
// Summary
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_gpu_summary()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) {
Summary(kernel_graph.get());
}
#ifdef ENABLE_DEBUGGER

@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() {
return;
}
Initialized = true;
MsContext::GetInstance()->set_execution_mode(kGraphMode);
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
Py_Initialize();
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
if (c_expression == nullptr) {
@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
ms_context->set_execution_mode(kGraphMode);
ms_context->set_device_id(device_id);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
return FAILED;
}
ms_context->set_device_target(device);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;

@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
// in pynative mode, data is only copied to host when the user wants to print it
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_need_sync(true);
}
if (ms_context->execution_mode() != kPynativeMode) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
tensor->SetNeedWait(true);
}
tensor->set_dirty(false);
@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
if (ms_context->enable_pynative_infer()) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
}
if (tensor->is_dirty()) {
@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->execution_mode() == kPynativeMode ||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if (backend_anf != nullptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->execution_mode() == kPynativeMode) {
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return backend_anf;
}

@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
debugger_ = Debugger::GetInstance();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
debugger_->Init(device_id_, ms_context->device_target());
debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif

@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
if (context->execution_mode() == kPynativeMode) {
if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump";
}
return true;

@ -142,7 +142,7 @@ void Debugger::EnableDebugger() {
// switch memory reuse on or off
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
context_ptr->set_enable_mem_reuse(partial_memory_);
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
// print some message about memory reuse to user
if (partial_memory_) {
MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "

@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
MS_LOG(ERROR) << "ms_context is nullptr";
return;
}
auto save_graphs_path = ms_context->save_graphs_path();
auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}

@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// dump_enable_ is true, close mem reuse
context_ptr->set_enable_mem_reuse(!dump_enable_);
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, !dump_enable_);
trans_flag_ = trans_flag;
dump_mode_ = mode;
dump_path_ = path;
@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() {
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto id = context_ptr->device_id();
auto id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
char real_path[PATH_MAX] = {0};
if (nullptr == realpath(config_path_str, real_path)) {
MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str;

@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
manager_ptr->AddFuncGraph(func_graph);
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
if (MsContext::GetInstance()->is_multi_graph_sink()) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
}

@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool check_bprop_flag = context->check_bprop_flag();
bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
// Skip checking if check_bprop not set
if (!check_bprop_flag) {
return;

Some files were not shown because too many files have changed in this diff. Show More

Loading…
Cancel
Save