From 458f0e7c5868095c80f3325bb711640d63c0f1d4 Mon Sep 17 00:00:00 2001 From: zhouyuanshen Date: Sat, 28 Nov 2020 09:15:00 +0800 Subject: [PATCH] dynamic shape adapting for allreduce and reducesum --- .../gpu/arrays/array_reduce_gpu_kernel.h | 3 +- .../gpu/cuda_impl/check_valid_impl.cu | 2 - .../gpu/nccl/nccl_collective_gpu_kernel.h | 27 +++--- .../pass/convert_const_input_to_attr.cc | 4 +- mindspore/core/abstract/infer_functions.h | 2 + mindspore/core/abstract/prim_maths.cc | 90 ++++++++++++++++++- .../core/abstract/primitive_infer_map.cc | 1 + mindspore/ops/operations/math_ops.py | 9 ++ 8 files changed, 122 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h index f82ec1f556..aea437b0b6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -133,6 +133,7 @@ class ArrayReduceGpuKernel : public GpuKernel { input_size_ = 0; output_size_ = 0; workspace_size_ = 0; + axis_.clear(); input_size_list_.clear(); output_size_list_.clear(); workspace_size_list_.clear(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu index ac5180d971..1698a0f142 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu @@ -40,8 +40,6 @@ template __global__ void CheckValidKernel(const size_t size, const unsigned char *box, const unsigned char *img_metas, S *valid) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - const size_t left_x = i * 4; - const size_t left_y = i * 4 + 1; const size_t right_x = i * 4 + 2; const size_t right_y = i * 4 + 3; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h index d04753e888..8af1c84c91 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h @@ -43,14 +43,7 @@ const std::map kNcclTypeMap = { template class NcclCollectiveGpuKernel : public NcclGpuKernel { public: - NcclCollectiveGpuKernel() - : nccl_kernel_type_(NCCL_INVALID_TYPE), - nccl_reduce_type_(ncclSum), - input_size_(0), - output_size_(0), - root_(0), - collective_handle_(nullptr), - comm_stream_(nullptr) {} + NcclCollectiveGpuKernel() { ResetResource(); } ~NcclCollectiveGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -109,6 +102,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { } return true; } + bool Init(const CNodePtr &kernel_node) override { 
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); InferCommType(kernel_node); @@ -116,7 +110,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); for (size_t i = 0; i < input_num; ++i) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + auto shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, i); size_t size = sizeof(T); for (size_t j = 0; j < shape.size(); j++) { size *= IntToSize(shape[j]); @@ -126,7 +120,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { input_size_ += aligned_size; } for (size_t i = 0; i < output_num; ++i) { - auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); + auto shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, i); size_t size = sizeof(T); for (size_t j = 0; j < shape.size(); j++) { size *= IntToSize(shape[j]); @@ -149,6 +143,19 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel { return true; } + void ResetResource() noexcept override { + nccl_kernel_type_ = NCCL_INVALID_TYPE; + nccl_reduce_type_ = ncclSum; + input_size_ = 0; + output_size_ = 0; + root_ = 0; + collective_handle_ = nullptr; + comm_stream_ = nullptr; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + protected: void InitSizeLists() override { return; } diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index 3ccdc89128..04d2ea723b 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -43,8 +43,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An todos.push_back(node); } - std::set DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName, kReshapeOpName, - 
kEmbeddingLookupOpName, kTransposeOpName}; + std::set<std::string> DynamicShapeConstInputToAttr = { + kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName}; for (auto &t : todos) { CNodePtr cnode = t->cast<CNodePtr>(); ConstInputToAttrInfoRegister reg; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 1f684ef556..5b84f49387 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -251,6 +251,8 @@ AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &prim const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 3bd5e1b686..af139f64db 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -121,6 +121,94 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr return ret; } +AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_x->element()); + + ValuePtr keep_dims = primitive->GetAttr("keep_dims"); + MS_EXCEPTION_IF_NULL(keep_dims); + if (!keep_dims->isa<BoolImm>()) { + MS_LOG(EXCEPTION) << "Keep_dims should be Bool."; + } + bool keep_dims_value = GetValue<bool>(keep_dims); + + ValuePtr axis = primitive->GetAttr("axis"); + MS_EXCEPTION_IF_NULL(axis); + + auto check_axis = [](int64_t &axis, const size_t dim) -> void { + int64_t dim_ = static_cast<int64_t>(dim); + if (axis < -dim_ || axis >= dim_) { + MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis; + } + if (axis >= -dim_ && axis < 0) { + axis += dim_; + } + return; + }; + + auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void { + if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) { + auto axis_ptr_list = + axis->isa<ValueTuple>() ? 
axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value(); + if (!axis_ptr_list.size()) { + if (keep_dims_value) shape.insert(shape.end(), x_shape.size(), 1); + } else { + shape.insert(shape.end(), x_shape.begin(), x_shape.end()); + ValuePtrList axis_items = axis_ptr_list; + ValuePtrList::iterator it; + ValuePtrList::reverse_iterator it_re; + int64_t axis_value; + if (keep_dims_value) { + for (it = axis_items.begin(); it != axis_items.end(); ++it) { + axis_value = GetValue<int64_t>(*it); + check_axis(axis_value, x_shape.size()); + shape[axis_value] = 1; + } + } else { + std::sort(axis_items.begin(), axis_items.end(), [](const ValuePtr &a, const ValuePtr &b) { return GetValue<int64_t>(a) < GetValue<int64_t>(b); }); + for (it_re = axis_items.rbegin(); it_re != axis_items.rend(); ++it_re) { + axis_value = GetValue<int64_t>(*it_re); + check_axis(axis_value, x_shape.size()); + shape.erase(std::begin(shape) + axis_value); + } + } + } + } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) { + shape.insert(shape.end(), x_shape.begin(), x_shape.end()); + int64_t axis_value = GetValue<int64_t>(axis); + check_axis(axis_value, x_shape.size()); + if (keep_dims_value) { + shape[axis_value] = 1; + } else { + shape.erase(std::begin(shape) + axis_value); + } + } else { + MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list]."; + } + return; + }; + + ShapeVector shape = {}; + ShapeVector x_shape = input_x->shape()->shape(); + cal_shape(shape, x_shape); + + bool x_is_dyn = (!input_x->shape()->min_shape().empty() && !input_x->shape()->max_shape().empty()); + if (x_is_dyn) { + ShapeVector shape_min = {}; + ShapeVector shape_max = {}; + ShapeVector x_shape_min = input_x->shape()->min_shape(); + ShapeVector x_shape_max = input_x->shape()->max_shape(); + cal_shape(shape_min, x_shape_min); + cal_shape(shape_max, x_shape_max); + return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, shape_min, shape_max)); + } + return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape)); +} + AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive, const
AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 17aa4bbc6e..f460a65c91 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -44,6 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, {prim::kPrimSub, {InferImplSub, true}}, {prim::kPrimEqual, {InferImplEqual, true}}, + {prim::kPrimReduceSum, {InferImplReduceSum, true}}, {prim::kPrimMinimum, {InferImplMinimum, true}}, {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, {prim::kPrimLinSpace, {InferImplLinSpace, true}}, diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index c6258639ee..ab9a4fdf8b 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -320,7 +320,16 @@ class _Reduce(PrimitiveWithInfer): value = np_reduce_func(value, axis_v, keepdims=self.keep_dims) value = np.array(value) value = Tensor(value) + if 'max_shape' in input_x and 'min_shape' in input_x: + output_max_shape = _infer_shape_reduce(input_x['max_shape'], axis_v, self.keep_dims, self.name) + output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name) + else: + output_max_shape = input_shp + output_min_shape = input_shp + return {'shape': input_shp, + 'min_shape': output_min_shape, + 'max_shape': output_max_shape, 'dtype': input_x['dtype'], 'value': value}