|
|
|
@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
|
|
|
|
|
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
|
|
|
|
|
TransArg();
|
|
|
|
|
for (size_t i = 0; i < begin_.size(); i++) {
|
|
|
|
|
while (begin_[i] < 0) {
|
|
|
|
|
begin_[i] = begin_[i] + input_shape_[i];
|
|
|
|
|
}
|
|
|
|
|
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
|
|
|
|
begin_[i] = input_shape_[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ClipBegin();
|
|
|
|
|
} else {
|
|
|
|
|
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
|
|
|
|
|
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < begin_.size(); i++) {
|
|
|
|
|
while (begin_[i] < 0) {
|
|
|
|
|
begin_[i] = begin_[i] + input_shape_[i];
|
|
|
|
|
}
|
|
|
|
|
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
|
|
|
|
begin_[i] = input_shape_[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ClipBegin();
|
|
|
|
|
for (size_t i = 0; i < sizes.size(); ++i) {
|
|
|
|
|
while (sizes[i] < 0) {
|
|
|
|
|
sizes[i] = sizes[i] + input_shape_[i];
|
|
|
|
|
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
|
|
|
|
|
}
|
|
|
|
|
strides_.emplace_back(1);
|
|
|
|
|
end_.emplace_back(begin_[i] + sizes[i]);
|
|
|
|
@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
|
|
|
|
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SliceCPUKernel::ClipBegin() {
|
|
|
|
|
for (size_t i = 0; i < begin_.size(); i++) {
|
|
|
|
|
if (begin_[i] < 0) {
|
|
|
|
|
auto k = begin_[i] + SizeToInt(input_shape_[i]);
|
|
|
|
|
begin_[i] = k < 0 ? 0 : k;
|
|
|
|
|
}
|
|
|
|
|
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
|
|
|
|
begin_[i] = SizeToInt(input_shape_[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void SliceCPUKernel::ExpandAllMemberDims() {
|
|
|
|
|
auto input_len = input_shape_.size();
|
|
|
|
|
if (input_len < 4) {
|
|
|
|
@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() {
|
|
|
|
|
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
|
|
|
|
|
}
|
|
|
|
|
if (end_[i] == 0 && begin_[i] < 0) {
|
|
|
|
|
end_[i] = end_[i] + input_shape_[i];
|
|
|
|
|
end_[i] = end_[i] + SizeToInt(input_shape_[i]);
|
|
|
|
|
}
|
|
|
|
|
while (end_[i] < 0) {
|
|
|
|
|
end_[i] = end_[i] + input_shape_[i];
|
|
|
|
|
if (end_[i] < 0) {
|
|
|
|
|
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
|
|
|
|
|
}
|
|
|
|
|
if (end_[i] > SizeToInt(input_shape_[i])) {
|
|
|
|
|
end_[i] = input_shape_[i];
|
|
|
|
|
end_[i] = SizeToInt(input_shape_[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|