|
|
@ -254,18 +254,17 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
auto input = inputs.at(0);
|
|
|
|
auto input = inputs.at(0);
|
|
|
|
outputs.front()->set_data_type(input->data_type());
|
|
|
|
outputs.front()->set_data_type(input->data_type());
|
|
|
|
outputs[0]->SetFormat(input->GetFormat());
|
|
|
|
outputs[0]->SetFormat(input->GetFormat());
|
|
|
|
if (!GetInferFlag()) {
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_ASSERT(input != nullptr);
|
|
|
|
MS_ASSERT(input != nullptr);
|
|
|
|
auto input_shape = input->shape();
|
|
|
|
auto input_shape = input->shape();
|
|
|
|
std::vector<int> output_shape;
|
|
|
|
auto inferflag = GetInferFlag();
|
|
|
|
|
|
|
|
|
|
|
|
if (inputs.size() == kStridedSliceInputNum) {
|
|
|
|
if (inputs.size() == kStridedSliceInputNum) {
|
|
|
|
ndim_ = static_cast<int>(GetBegin().size());
|
|
|
|
ndim_ = static_cast<int>(GetBegin().size());
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
|
|
|
if (inferflag) {
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
|
|
|
}
|
|
|
|
begins_.emplace_back((GetBegin())[i]);
|
|
|
|
begins_.emplace_back((GetBegin())[i]);
|
|
|
|
ends_.emplace_back((GetEnd())[i]);
|
|
|
|
ends_.emplace_back((GetEnd())[i]);
|
|
|
|
strides_.emplace_back((GetStride())[i]);
|
|
|
|
strides_.emplace_back((GetStride())[i]);
|
|
|
@ -282,7 +281,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ndim_ = begin_tensor->ElementsNum();
|
|
|
|
ndim_ = begin_tensor->ElementsNum();
|
|
|
|
for (int i = 0; i < ndim_; ++i) {
|
|
|
|
for (int i = 0; i < ndim_; ++i) {
|
|
|
|
|
|
|
|
if (inferflag) {
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
|
|
|
}
|
|
|
|
begins_.emplace_back(begin_data[i]);
|
|
|
|
begins_.emplace_back(begin_data[i]);
|
|
|
|
ends_.emplace_back(end_data[i]);
|
|
|
|
ends_.emplace_back(end_data[i]);
|
|
|
|
strides_.emplace_back(stride_data[i]);
|
|
|
|
strides_.emplace_back(stride_data[i]);
|
|
|
@ -310,6 +311,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
ApplyEndMask();
|
|
|
|
ApplyEndMask();
|
|
|
|
ApplyEllipsisMask();
|
|
|
|
ApplyEllipsisMask();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!inferflag) {
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<int> output_shape;
|
|
|
|
output_shape.clear();
|
|
|
|
output_shape.clear();
|
|
|
|
output_shape.resize(in_shape_.size());
|
|
|
|
output_shape.resize(in_shape_.size());
|
|
|
|
|
|
|
|
|
|
|
|