!6629 support negative axis with tuple type

Merge pull request !6629 from baihuawei/fixreduce
pull/6629/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d1d28fb032

@ -39,6 +39,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
if (axis_addr->isa<ValueTuple>()) { if (axis_addr->isa<ValueTuple>()) {
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS); auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
if (attr_axis.size() > shape_.size()) { if (attr_axis.size() > shape_.size()) {
@ -47,18 +48,24 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
axis_.push_back(shape_.size() - 1); axis_.push_back(shape_.size() - 1);
} else { } else {
for (auto axis : attr_axis) { for (auto axis : attr_axis) {
while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= (shape_.size())) { if (IntToSize(axis) >= (shape_.size())) {
MS_LOG(EXCEPTION) << "axis value is oversize."; MS_LOG(EXCEPTION) << "axis value is oversize.";
} }
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); axis_.push_back(IntToSize(axis));
} }
} }
} else if (axis_addr->isa<Int32Imm>()) { } else if (axis_addr->isa<Int32Imm>()) {
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS); int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
if (axis >= 0 && IntToSize(axis) >= shape_.size()) { while (axis < 0) {
axis += SizeToInt(shape_.size());
}
if (IntToSize(axis) >= shape_.size()) {
MS_LOG(EXCEPTION) << "axis value is oversize."; MS_LOG(EXCEPTION) << "axis value is oversize.";
} }
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); axis_.push_back(IntToSize(axis));
} else { } else {
MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
} }

Loading…
Cancel
Save