diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
index a1a9eeecdb..0ac125321e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h
@@ -130,8 +130,9 @@ class NcclGpuKernel : public GpuKernel {
       for (size_t j = 0; j < shape.size(); j++) {
         size *= IntToSize(shape[j]);
       }
-      input_size_list_.push_back(size);
-      input_size_ += size;
+      size_t aligned_size = AlignMemorySize(size);
+      input_size_list_.push_back(aligned_size);
+      input_size_ += aligned_size;
     }
     for (size_t i = 0; i < output_num; ++i) {
       auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
@@ -139,8 +140,9 @@ class NcclGpuKernel : public GpuKernel {
       for (size_t j = 0; j < shape.size(); j++) {
         size *= IntToSize(shape[j]);
       }
-      output_size_list_.push_back(size);
-      output_size_ += size;
+      size_t aligned_size = AlignMemorySize(size);
+      output_size_list_.push_back(aligned_size);
+      output_size_ += aligned_size;
     }
     InferCommType(kernel_node);
@@ -193,6 +195,13 @@ class NcclGpuKernel : public GpuKernel {
     return;
   }
 
+  size_t AlignMemorySize(size_t size) const {
+    if (size == 0) {
+      return COMMUNICATION_MEM_ALIGN_SIZE;
+    }
+    return ((size + COMMUNICATION_MEM_ALIGN_SIZE - 1) / COMMUNICATION_MEM_ALIGN_SIZE) * COMMUNICATION_MEM_ALIGN_SIZE;
+  }
+
   NcclKernelType nccl_kernel_type_;
   ncclRedOp_t nccl_reduce_type_;
   ncclDataType_t nccl_data_type_;
@@ -205,6 +214,8 @@ class NcclGpuKernel : public GpuKernel {
   std::vector<size_t> workspace_size_list_;
   const void *collective_handle_;
   cudaStream_t comm_stream_;
+
+  static const size_t COMMUNICATION_MEM_ALIGN_SIZE = 16;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
index 01c1079a86..14fc721889 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
@@ -75,7 +75,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
       MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1.";
       return false;
     }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(input_shape);
     if (is_null_input_) {
       MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null.";
@@ -89,9 +89,15 @@ class ActivationGpuFwdKernel : public GpuKernel {
     const int split_dim = 4;
     if (input_shape.size() <= split_dim) {
       ShapeNdTo4d(input_shape, &shape);
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                             shape[0], shape[1], shape[2], shape[3]),
-                                  "cudnnSetTensor4dDescriptor failed");
+      if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
+        CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
+                                                               shape[0], shape[3], shape[1], shape[2]),
+                                    "cudnnSetTensor4dDescriptor failed");
+      } else {
+        CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                               shape[0], shape[1], shape[2], shape[3]),
+                                    "cudnnSetTensor4dDescriptor failed");
+      }
     } else {
       CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
     }
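Note on the NHWC branch above: cudnnSetTensor4dDescriptor always takes its dimension arguments in (n, c, h, w) order, and the layout enum only tells cuDNN how the data is strided, which is why the NHWC case passes shape[0], shape[3], shape[1], shape[2]. A minimal Python sketch of that mapping (the helper name is made up for illustration):

```python
def cudnn_4d_descriptor_args(shape, data_format):
    """Map a 4-d tensor shape to cuDNN's (n, c, h, w) descriptor arguments.

    cuDNN takes logical dimensions as (n, c, h, w) regardless of layout;
    the format enum (NCHW/NHWC) only describes the memory strides.
    """
    if data_format == "NHWC":
        n, h, w, c = shape  # channel last: c sits at index 3
    else:
        n, c, h, w = shape  # channel first
    return n, c, h, w

assert cudnn_4d_descriptor_args([32, 224, 224, 3], "NHWC") == (32, 3, 224, 224)
assert cudnn_4d_descriptor_args([32, 3, 224, 224], "NCHW") == (32, 3, 224, 224)
```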
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
index 2483e8171a..f594320d91 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
@@ -28,7 +28,7 @@ namespace mindspore {
 namespace opt {
 const BaseRef ReplaceBNCastFusion::DefinePattern() const {
   VectorRef in_cast = VectorRef({prim::kPrimCast, x_});
-  VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_});
+  VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNormEx, in_cast, scale_, bias_, mean_, var_});
   VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_});
   return tupleget;
 }
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
index eb78e7280f..ef5c698a13 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
@@ -28,7 +28,7 @@ namespace mindspore {
 namespace opt {
 const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
   VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_});
-  VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_});
+  VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGradEx, dy_cast, x_, scale_, mean_, var_, reserve_});
   VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_});
   return tupleget;
 }
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h
index b937aa25bf..968ed52848 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h
@@ -33,6 +33,7 @@ class ReplaceBNGradCastFusion : public PatternProcessPass {
     bn_scale_ = std::make_shared<Var>();
     bn_bias_ = std::make_shared<Var>();
     index_ = std::make_shared<Var>();
+    reserve_ = std::make_shared<Var>();
   }
   ~ReplaceBNGradCastFusion() override = default;
   const BaseRef DefinePattern() const override;
@@ -48,6 +49,7 @@ class ReplaceBNGradCastFusion : public PatternProcessPass {
   VarPtr bn_scale_;
   VarPtr bn_bias_;
   VarPtr index_;
+  VarPtr reserve_;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index fc8b2dcfdf..42d47445e8 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -28,6 +28,9 @@
 #include "backend/optimizer/gpu/adam_fusion.h"
 #include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
 #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
+#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
+#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
+#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
 #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
 #include "backend/optimizer/gpu/replace_addn_fusion.h"
 #include "backend/optimizer/gpu/insert_format_transform_op.h"
@@ -70,6 +73,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
+  pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
+  pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
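For context on what the three new passes buy: batch normalization followed by ReLU (optionally with a residual add in between) does only per-element work after the normalization, so it can collapse into one kernel launch without changing results. A NumPy reference of the math being fused, illustrative only, with function names invented here rather than taken from the codebase:

```python
import numpy as np

def batch_norm_train(x, scale, bias, eps=1e-5):
    # Normalize an NCHW tensor with per-channel batch statistics.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)

def bn_relu(x, scale, bias):
    # Subgraph targeted by BatchNormReluFusion: ReLU(BN(x)).
    return np.maximum(batch_norm_train(x, scale, bias), 0.0)

def bn_add_relu(x, z, scale, bias):
    # Subgraph targeted by BatchNormAddReluFusion: ReLU(BN(x) + z),
    # the residual pattern seen in ResNet-style blocks.
    return np.maximum(batch_norm_train(x, scale, bias) + z, 0.0)

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
z = np.random.randn(2, 3, 4, 4).astype(np.float32)
scale, bias = np.ones(3, np.float32), np.zeros(3, np.float32)
assert bn_relu(x, scale, bias).min() >= 0.0
assert bn_add_relu(x, z, scale, bias).min() >= 0.0
```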
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc
index c7fbda2dad..40662a334f 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc
@@ -40,17 +40,21 @@ bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
     return ret;
   }
   if (size != size_) {
-    MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_;
-    return true;
+    // NCCL kernel input and output memory sizes are aligned, so the sync size may differ from the device size.
+    MS_LOG(INFO) << "Sync memory size mismatch, host size: " << size << ", device size " << size_;
   }
-  return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size_);
+  return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size);
 }
 
-bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t, TypeId, const void *host_ptr) const {
+bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t size, TypeId, const void *host_ptr) const {
   MS_EXCEPTION_IF_NULL(host_ptr);
   auto &stream = GPUDeviceManager::GetInstance().default_stream();
   MS_EXCEPTION_IF_NULL(stream);
-  if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size_, stream)) {
+  if (size != size_) {
+    // NCCL kernel input and output memory sizes are aligned, so the sync size may differ from the device size.
+    MS_LOG(INFO) << "Sync memory size mismatch, host size: " << size << ", device size " << size_;
+  }
+  if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size, stream)) {
     MS_LOG(ERROR) << "CopyHostMemToDeviceAsync failed";
     return false;
   }
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index eca69aa8dc..255d7b46d8 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -1001,7 +1001,10 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) {
   size_t total_size = 0;
   std::vector<size_t> size_list;
   DeviceAddressPtrList addr_list;
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  auto input_sizes = kernel_mod->GetInputSizeList();
+  for (size_t i = 0; i < input_sizes.size(); ++i) {
     DeviceAddressPtr device_address;
     if (mem_reuse_util_->is_all_nop_node()) {
       // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
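The INFO-level size mismatch tolerated above follows directly from AlignMemorySize in nccl_gpu_kernel.h: every NCCL input and output slot is rounded up to a 16-byte multiple, so the device allocation can legitimately exceed the host tensor's byte count, and the copy must use the host size. A quick Python model of the rounding rule, mirroring the C++ above:

```python
COMMUNICATION_MEM_ALIGN_SIZE = 16

def align_memory_size(size):
    # Round up to the next multiple of the alignment; a zero-sized buffer
    # still gets one aligned block, matching AlignMemorySize above.
    if size == 0:
        return COMMUNICATION_MEM_ALIGN_SIZE
    return ((size + COMMUNICATION_MEM_ALIGN_SIZE - 1)
            // COMMUNICATION_MEM_ALIGN_SIZE) * COMMUNICATION_MEM_ALIGN_SIZE

# A float32 tensor of shape (3, 3) is 36 bytes on the host but occupies a
# 48-byte slot in the fused communication buffer, hence size != size_.
assert align_memory_size(36) == 48
assert align_memory_size(0) == 16
assert [align_memory_size(s) for s in (36, 4, 64)] == [48, 16, 64]
```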
@@ -1016,8 +1019,8 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) {
     } else {
       is_need_free_memory = true;
     }
-    total_size += device_address->size_;
-    size_list.emplace_back(device_address->size_);
+    total_size += input_sizes[i];
+    size_list.emplace_back(input_sizes[i]);
     addr_list.emplace_back(device_address);
   }
   AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list);
diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
index 3fb2d89d19..f231f9cdc3 100644
--- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
@@ -180,7 +180,7 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 843e09b2ca..41dba64373 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -86,6 +86,7 @@ class _BatchNorm(Cell):
         self.dtype = P.DType()
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
+        self.is_gpu = context.get_context("device_target") == "GPU"
         self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
@@ -96,6 +97,10 @@
 
         if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
+        elif self.is_gpu:
+            self.bn_train = P.FusedBatchNormEx(mode=1,
+                                               epsilon=self.eps,
+                                               momentum=self.momentum)
         else:
             self.bn_train = P.FusedBatchNorm(mode=1,
                                              epsilon=self.eps,
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index f61454eb20..cdba616310 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -535,6 +535,24 @@ def get_bprop_fused_batch_norm(self):
     return bprop
 
 
+@bprop_getters.register(P.FusedBatchNormEx)
+def get_bprop_fused_batch_norm_ex(self):
+    """Grad definition for `FusedBatchNormEx` operation."""
+    input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum)
+
+    def bprop(x, scale, b, mean, variance, out, dout):
+        saved_mean = out[3]
+        saved_variance = out[4]
+        reserve = out[5]
+        out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
+        dx = out[0]
+        dscale = out[1]
+        dbias = out[2]
+        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
+
+    return bprop
+
+
 @bprop_getters.register(P.BatchNorm)
 def get_bprop_batch_norm(self):
     """Grad definition for `BatchNorm` operation."""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index ca03ad2edf..112fc07689 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
                      DropoutDoMask, DropoutGrad, Dropout,
-                     DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate,
+                     DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2,
                      LogSoftmax,
@@ -118,6 +118,7 @@ __all__ = [
     'Flatten',
     'MaxPoolWithArgmax',
     'FusedBatchNorm',
+    'FusedBatchNormEx',
     'BNTrainingReduce',
     'BNTrainingUpdate',
     'BatchNorm',
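One subtlety in the normalization.py hunk: `_BatchNorm` stores `self.momentum = 1.0 - momentum`, while FusedBatchNormEx (per its docstring below) weights the existing running statistic by its momentum argument. The layer-level momentum therefore weights the current batch statistic, and the flip reconciles the two conventions. A small sketch of the two update rules, illustrative rather than MindSpore API:

```python
def update_running_mean_op(running_mean, batch_mean, op_momentum):
    # Op-level convention (FusedBatchNormEx docstring):
    # new = momentum * running + (1 - momentum) * current.
    return op_momentum * running_mean + (1.0 - op_momentum) * batch_mean

def update_running_mean_layer(running_mean, batch_mean, layer_momentum):
    # nn.BatchNorm convention: momentum weights the new batch statistic,
    # hence the `self.momentum = 1.0 - momentum` flip in _BatchNorm.
    return update_running_mean_op(running_mean, batch_mean, 1.0 - layer_momentum)

# A layer momentum of 0.9 behaves like an op momentum of 0.1.
assert update_running_mean_layer(0.0, 1.0, 0.9) == update_running_mean_op(0.0, 1.0, 0.1)
```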
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 417441fb41..101fe15483 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -491,6 +491,22 @@ class FusedBatchNormGrad(Primitive):
         raise NotImplementedError
 
 
+class FusedBatchNormGradEx(PrimitiveWithInfer):
+    """Gradients of FusedBatchNormEx operation."""
+
+    @prim_attr_register
+    def __init__(self, epsilon=0.0, momentum=0.1):
+        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'],
+                                outputs=['dx', 'bn_scale', 'bn_bias'])
+        self.add_prim_attr('data_format', "NCHW")
+
+    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape):
+        return (x_shape, scale_shape, scale_shape)
+
+    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type):
+        return (x_type, scale_type, scale_type)
+
+
 class UniqueGrad(Primitive):
     """Gradients of Unique operation."""
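FusedBatchNormGradEx above only declares the shape and dtype contract; the actual math runs inside the cuDNN kernel. For reference, here is the standard training-mode batch-norm backward such an op computes for NCHW inputs, written from the textbook derivation rather than from MindSpore internals (the reserve input is a cuDNN workspace with no NumPy analogue, so it is omitted):

```python
import numpy as np

def fused_batch_norm_grad_reference(dy, x, scale, save_mean, save_inv_std):
    # save_mean / save_inv_std are the batch statistics cached by the
    # forward pass (the op's 'save_mean' / 'save_inv_variance' inputs).
    n = x.shape[0] * x.shape[2] * x.shape[3]  # elements per channel
    mean = save_mean.reshape(1, -1, 1, 1)
    inv_std = save_inv_std.reshape(1, -1, 1, 1)
    x_hat = (x - mean) * inv_std

    dbias = dy.sum(axis=(0, 2, 3))             # shape (C,)
    dscale = (dy * x_hat).sum(axis=(0, 2, 3))  # shape (C,)
    # dx folds the gradient paths through the batch mean and variance back in.
    dx = (scale.reshape(1, -1, 1, 1) * inv_std / n) * (
        n * dy - dbias.reshape(1, -1, 1, 1) - x_hat * dscale.reshape(1, -1, 1, 1))
    return dx, dscale, dbias  # matches the op's (x_shape, scale_shape, scale_shape)
```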
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 50316fab8f..5e7a1da92a 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -623,6 +623,73 @@ class FusedBatchNorm(Primitive):
         self._update_parameter = True
 
 
+class FusedBatchNormEx(PrimitiveWithInfer):
+    r"""
+    FusedBatchNormEx is an extension of FusedBatchNorm. Compared with FusedBatchNorm, it additionally
+    returns a reserve Tensor that the gradient operation (FusedBatchNormGradEx) consumes.
+
+    Args:
+        mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
+        epsilon (float): A small value added for numerical stability. Default: 1e-5.
+        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
+            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
+            Momentum value must be in [0, 1]. Default: 0.1.
+
+    Inputs:
+        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
+        - **scale** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **bias** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
+
+    Outputs:
+        Tuple of 6 Tensors, the normalized input and the updated parameters.
+
+        - **output_x** (Tensor) - The same type and shape as the `input_x`.
+        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **reserve** (Tensor) - Tensor of shape :math:`(C,)`.
+
+    Examples:
+        >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
+        >>> scale = Tensor(np.ones([64]), mindspore.float32)
+        >>> bias = Tensor(np.ones([64]), mindspore.float32)
+        >>> mean = Tensor(np.ones([64]), mindspore.float32)
+        >>> variance = Tensor(np.ones([64]), mindspore.float32)
+        >>> op = P.FusedBatchNormEx()
+        >>> output = op(input_x, scale, bias, mean, variance)
+    """
+
+    @prim_attr_register
+    def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
+        self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
+                                outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
+        self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
+        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
+        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self._update_parameter = True
+        self.add_prim_attr('data_format', "NCHW")
+
+    def infer_shape(self, input_x, scale, bias, mean, variance):
+        validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
+        validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
+        validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
+        validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
+        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
+        validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
+        return (input_x, scale, scale, scale, scale, scale)
+
+    def infer_dtype(self, input_x, scale, bias, mean, variance):
+        validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
+        args = {"scale": scale, "bias": bias}
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        args_moving = {"mean": mean, "variance": variance}
+        valid_types = [mstype.tensor_type(mstype.float32)]
+        validator.check_type_same(args_moving, valid_types, self.name)
+        return (input_x, scale, scale, scale, scale, scale)
+
+
 class BNTrainingReduce(PrimitiveWithInfer):
     """
     reduce sum at axis [0, 2, 3].