|
|
|
@ -21,7 +21,6 @@ namespace kernel {
|
|
|
|
|
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
CheckParam(kernel_node);
|
|
|
|
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
|
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
|
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
@ -65,12 +64,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SliceCPUKernel::ExpandAllMemberDims() {
|
|
|
|
|
auto output_len = output_shape_.size();
|
|
|
|
|
if (output_len < 4) {
|
|
|
|
|
for (size_t i = 0; i < 4 - output_len; ++i) {
|
|
|
|
|
output_shape_.push_back(1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto input_len = input_shape_.size();
|
|
|
|
|
if (input_len < 4) {
|
|
|
|
|
for (size_t i = 0; i < 4 - input_len; ++i) {
|
|
|
|
@ -80,6 +73,15 @@ void SliceCPUKernel::ExpandAllMemberDims() {
|
|
|
|
|
end_.insert(end_.begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) {
|
|
|
|
|
if (SignOfStride(i)) {
|
|
|
|
|
int ax = (end_[i] - begin_[i]) * SignOfStride(i);
|
|
|
|
|
if (ax < 0) {
|
|
|
|
|
ax = 0;
|
|
|
|
|
}
|
|
|
|
|
output_shape_.push_back(IntToSize(ax));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
@ -87,7 +89,6 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
|
|
|
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
|
|
|
|
|
|
|
|
|
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
|
|
|
|
|
int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
|
|
|
|
|
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
|
|
|
|
|