|
|
|
@ -156,8 +156,9 @@ int StridedSlice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int kStridedSliceOutputNum = 1;
|
|
|
|
|
constexpr int kStridedSliceInputNum = 1;
|
|
|
|
|
constexpr size_t kStridedSliceOutputNum = 1;
|
|
|
|
|
constexpr size_t kStridedSliceInputNum = 1;
|
|
|
|
|
constexpr size_t kStridedSliceMultiInputNum = 4;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void StridedSlice::ApplyNewAxisMask() {
|
|
|
|
@ -231,7 +232,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (inputs.size() != kStridedSliceInputNum) {
|
|
|
|
|
if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
@ -244,13 +245,33 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
MS_ASSERT(input != nullptr);
|
|
|
|
|
auto input_shape = input->shape();
|
|
|
|
|
std::vector<int> output_shape;
|
|
|
|
|
ndim_ = static_cast<int>(GetBegin().size());
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
begins_.emplace_back((GetBegin())[i]);
|
|
|
|
|
ends_.emplace_back((GetEnd())[i]);
|
|
|
|
|
strides_.emplace_back((GetStride())[i]);
|
|
|
|
|
if (inputs.size() == kStridedSliceInputNum) {
|
|
|
|
|
ndim_ = static_cast<int>(GetBegin().size());
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
begins_.emplace_back((GetBegin())[i]);
|
|
|
|
|
ends_.emplace_back((GetEnd())[i]);
|
|
|
|
|
strides_.emplace_back((GetStride())[i]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto begin_tensor = inputs.at(1);
|
|
|
|
|
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
|
|
|
|
|
auto end_tensor = inputs.at(2);
|
|
|
|
|
int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
|
|
|
|
|
auto stride_tensor = inputs.at(3);
|
|
|
|
|
int *stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
|
|
|
|
|
if (begin_data == nullptr || end_data == nullptr || stride_data == nullptr) {
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
ndim_ = begin_tensor->ElementsNum();
|
|
|
|
|
for (int i=0; i< ndim_; ++i) {
|
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
begins_.emplace_back(begin_data[i]);
|
|
|
|
|
ends_.emplace_back(end_data[i]);
|
|
|
|
|
strides_.emplace_back(stride_data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set all mask to original input shape
|
|
|
|
|