gpu optimize Nop node

pull/2094/head
limingqi107 5 years ago
parent ff0590315c
commit b83f90a8d8

@@ -384,7 +384,7 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->device_target() != kAscendDevice) {
if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) {
return false;
}
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,

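The gate above is the core of the change: nop-node handling, previously Ascend-only, now also applies when the device target is GPU. Below is a minimal standalone sketch of that check with simplified stand-in types; the real code keys off MsContext and primitive-name constants, and only kPrimReshape and kExpandDimsOpName are visible in this hunk, so the remaining names in the set are assumptions.

#include <string>
#include <unordered_set>

enum class DeviceTarget { kAscendDevice, kGPUDevice, kCPUDevice };

// Simplified stand-in for IsNopNode: device gate first, then an op-name lookup.
bool IsNopOp(DeviceTarget target, const std::string &op_name) {
  // After this commit the gate admits GPU as well as Ascend.
  if (target != DeviceTarget::kAscendDevice && target != DeviceTarget::kGPUDevice) {
    return false;
  }
  // "Reshape" and "ExpandDims" correspond to the constants in the hunk; the rest are assumed.
  static const std::unordered_set<std::string> nop_names = {"Reshape", "ExpandDims",
                                                            "Squeeze", "Flatten"};
  return nop_names.count(op_name) > 0;
}

// Example: IsNopOp(DeviceTarget::kGPUDevice, "Reshape") is now true; before this
// commit the GPU target fell through the gate and returned false.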
@@ -154,6 +154,8 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt
tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
bool IsAllNopNode(const session::KernelGraph *const graph);
bool IsNopNode(const AnfNodePtr &node);
void HideNopNode(session::KernelGraph *const graph);

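The header hunk keeps the nop helpers, among them IsAllNopNode, visible outside the pass, which the mem_reuse change below relies on. Assuming, as that use suggests, that IsAllNopNode reports whether every kernel in the graph is a nop node, here is a self-contained sketch of such a predicate over a plain list of op names; the types are placeholders, not the real KernelGraph/AnfNodePtr API.

#include <algorithm>
#include <functional>
#include <string>
#include <vector>

// Hypothetical reading of IsAllNopNode: true when every kernel in the graph is a
// nop node. The "graph" is just a list of op names and the nop test is injected,
// so the sketch stays independent of the real graph types.
bool AllKernelsAreNop(const std::vector<std::string> &op_names,
                      const std::function<bool(const std::string &)> &is_nop) {
  return std::all_of(op_names.begin(), op_names.end(), is_nop);
}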
@@ -18,6 +18,8 @@
#include <algorithm>
#include <memory>
#include "pre_activate/mem_reuse/mem_reuse_checker.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace memreuse {
bool MemReuseUtil::InitDynamicOutputKernelRef() {
@@ -324,9 +326,17 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
}
void MemReuseUtil::SetGraphOutputRefCount() {
auto is_all_nop_node = opt::IsAllNopNode(graph_);
auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem});
for (const auto &node : nodes) {
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
session::KernelWithIndex kernel_input;
if (is_all_nop_node) {
// The graph does not remove the nop node.
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false);
} else {
// The graph removes the nop node.
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
}
MS_EXCEPTION_IF_NULL(kernel_input.first);
if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
continue;

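Since the two branches above differ only in the boolean passed as the third argument, the selection can also be read as a single call. This is just an equivalent condensation of the hunk, keeping the signature exactly as shown there:

// Equivalent to the if/else above: step over nop nodes only when the graph has
// actually removed them, i.e. when it is not an all-nop graph.
session::KernelWithIndex kernel_input =
    AnfAlgo::VisitKernelWithReturnType(node, 0, !is_all_nop_node);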
@@ -75,7 +75,6 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
// opt::RemoveNopNode(kernel_graph);
runtime_instance->AssignMemory(kernel_graph);
}
@@ -84,7 +83,6 @@ void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input
MS_EXCEPTION_IF_NULL(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
// opt::RemoveNopNode(kernel_graph);
runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
}
@@ -156,14 +154,16 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
Optimize(graph);
// Assign CUDA streams
AssignStream(graph);
// Remove NoOp from execution graph
// opt::HideNopNode(graph.get());
// Hide NoOp from execution graph
opt::HideNopNode(graph.get());
// Build kernel if node is cnode
BuildKernel(graph);
// Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph
auto execution_order = graph->execution_order();
Reorder(&execution_order);
graph->set_execution_order(execution_order);
// Remove NoOp from execution graph
opt::RemoveNopNode(graph.get());
// Alloc memory, including static memory and dynamic memory
AllocateMemory(graph.get());
MS_EXCEPTION_IF_NULL(context_);
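Taken together, the CompileGraph hunk replaces the previously commented-out calls with a two-phase treatment: nop nodes are hidden before kernels are built and removed before memory is assigned. A minimal standalone sketch of that ordering follows, using a placeholder graph type and empty stubs rather than the real GPUSession/opt functions:

struct KernelGraph {};  // placeholder for session::KernelGraph

// Stubs standing in for the calls visible in the hunk above.
void HideNopNode(KernelGraph *) {}    // keep nop nodes, but exclude them from kernel build
void RemoveNopNode(KernelGraph *) {}  // physically drop nop nodes from the graph
void BuildKernel(KernelGraph *) {}
void ReorderExecutionOrder(KernelGraph *) {}
void AllocateMemory(KernelGraph *) {}

void CompileGraphOnGpu(KernelGraph *graph) {
  HideNopNode(graph);            // phase 1: hide before kernel build
  BuildKernel(graph);
  ReorderExecutionOrder(graph);  // fix the execution order before memory planning
  RemoveNopNode(graph);          // phase 2: remove before memory alloc
  AllocateMemory(graph);
}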
@@ -205,6 +205,8 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
MS_EXCEPTION_IF_NULL(kernel_graph);
SelectKernel(kernel_graph);
StartKernelRT();
// Hide NoOp from execution graph
opt::HideNopNode(kernel_graph.get());
BuildKernel(kernel_graph);
run_op_graphs_[graph_info] = kernel_graph;
}
@@ -213,6 +215,8 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
const std::vector<tensor::TensorPtr> &input_tensors) {
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
// Remove NoOp from execution graph
opt::RemoveNopNode(kernel_graph.get());
RunOpAllocateMemory(input_tensors, kernel_graph.get());
// Execute the computation
LoadInputData(kernel_graph, input_tensors);

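The BuildOp/RunOp hunks apply the same pairing on the single-op path: hide before BuildKernel, remove right before RunOpAllocateMemory. Reusing the placeholder type and stubs from the CompileGraph sketch above, that path reads roughly as:

void BuildOpOnGpu(KernelGraph *graph) {
  HideNopNode(graph);    // as in CompileGraph: hide before kernels are built
  BuildKernel(graph);
}

void RunOpOnGpu(KernelGraph *graph) {
  RemoveNopNode(graph);  // as in CompileGraph: remove before memory is assigned
  AllocateMemory(graph);
}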