From 178fb01e538a36ba9c2c3fba678619ccca876e43 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Mon, 21 Sep 2020 12:15:25 +0800 Subject: [PATCH] support negative axis with tuple type --- .../kernel_compiler/cpu/reduce_cpu_kernel.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc index ef5b06ad2f..5e646d5720 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc @@ -39,6 +39,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { } shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); + if (axis_addr->isa()) { auto attr_axis = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); if (attr_axis.size() > shape_.size()) { @@ -47,18 +48,24 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { axis_.push_back(shape_.size() - 1); } else { for (auto axis : attr_axis) { + while (axis < 0) { + axis += SizeToInt(shape_.size()); + } if (IntToSize(axis) >= (shape_.size())) { MS_LOG(EXCEPTION) << "axis value is oversize."; } - axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + axis_.push_back(IntToSize(axis)); } } } else if (axis_addr->isa()) { int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis >= 0 && IntToSize(axis) >= shape_.size()) { + while (axis < 0) { + axis += SizeToInt(shape_.size()); + } + if (IntToSize(axis) >= shape_.size()) { MS_LOG(EXCEPTION) << "axis value is oversize."; } - axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + axis_.push_back(IntToSize(axis)); } else { MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; }