diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc index bf88fef013..f323e016fc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc @@ -18,9 +18,16 @@ #include #include "backend/kernel_compiler/common_utils.h" #include "backend/session/anf_runtime_algorithm.h" +#include "utils/ms_context.h" #include "utils/utils.h" namespace mindspore { +bool IsPyNativeMode() { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + return ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode; +} + bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_intput_shape_list) { MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list); @@ -129,10 +136,16 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vectorcast(); + if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr(anf_node, kAttrFusion)) { + block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + } else { + block_size = input_size; + } } else { - block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + block_size = + IsPyNativeMode() ? input_size : (input_size + align_size - 1 + filled_size) / align_size * align_size; } total_size = total_size + block_size; }