|
|
|
@ -39,6 +39,7 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
}
|
|
|
|
|
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
|
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
|
|
|
|
|
|
|
|
|
|
if (axis_addr->isa<ValueTuple>()) {
|
|
|
|
|
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
|
|
|
|
|
if (attr_axis.size() > shape_.size()) {
|
|
|
|
@ -47,18 +48,24 @@ void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
axis_.push_back(shape_.size() - 1);
|
|
|
|
|
} else {
|
|
|
|
|
for (auto axis : attr_axis) {
|
|
|
|
|
while (axis < 0) {
|
|
|
|
|
axis += SizeToInt(shape_.size());
|
|
|
|
|
}
|
|
|
|
|
if (IntToSize(axis) >= (shape_.size())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
|
|
|
|
}
|
|
|
|
|
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
|
|
|
|
|
axis_.push_back(IntToSize(axis));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (axis_addr->isa<Int32Imm>()) {
|
|
|
|
|
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
|
|
|
|
|
if (axis >= 0 && IntToSize(axis) >= shape_.size()) {
|
|
|
|
|
while (axis < 0) {
|
|
|
|
|
axis += SizeToInt(shape_.size());
|
|
|
|
|
}
|
|
|
|
|
if (IntToSize(axis) >= shape_.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
|
|
|
|
}
|
|
|
|
|
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
|
|
|
|
|
axis_.push_back(IntToSize(axis));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
|
|
|
|
|
}
|
|
|
|
|