|
|
|
@ -49,9 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
|
|
|
|
|
|
|
|
|
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
bool in_black_list = CheckInModeBlackList(kernel_graph);
|
|
|
|
|
for (const auto &kernel_node : kernel_graph->execution_order()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
device::gpu::SetKernelInfo(kernel_node);
|
|
|
|
|
device::gpu::SetKernelInfo(kernel_node, in_black_list);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -75,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
|
|
|
|
if (context_ptr->execution_mode() != kPynativeMode) {
|
|
|
|
|
if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) {
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
|
|
|
@ -192,6 +193,28 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
|
|
|
|
auto kernels = kernel_graph->execution_order();
|
|
|
|
|
size_t conv_cnt = 0;
|
|
|
|
|
size_t bn_cnt = 0;
|
|
|
|
|
for (const auto &kernel : kernels) {
|
|
|
|
|
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
|
|
|
|
|
if (kernel_name == prim::kPrimLayerNorm->name()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_name == prim::kPrimConv2D->name()) {
|
|
|
|
|
conv_cnt++;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
|
|
|
|
bn_cnt++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
|
|
|
|
// Construct graph, if successfully, graph_sum_ + 1
|
|
|
|
|
auto graph_id = graph_sum_;
|
|
|
|
|