|
|
|
@ -16,8 +16,6 @@
|
|
|
|
|
|
|
|
|
|
#include "host_kernels/slice_kernel.h"
|
|
|
|
|
|
|
|
|
|
#include <set>
|
|
|
|
|
|
|
|
|
|
#include "common/ge_inner_error_codes.h"
|
|
|
|
|
#include "common/op/ge_op_utils.h"
|
|
|
|
|
#include "common/types.h"
|
|
|
|
@ -33,30 +31,6 @@ const size_t kSliceInputSize = 3;
|
|
|
|
|
const size_t kSliceInputIndexX = 0;
|
|
|
|
|
const size_t kSliceInputIndexBegin = 1;
|
|
|
|
|
const size_t kSliceInputIndexSize = 2;
|
|
|
|
|
const std::set<ge::DataType> kSupportedDataTypeToLength = {
|
|
|
|
|
DT_BOOL,
|
|
|
|
|
DT_INT64,
|
|
|
|
|
DT_UINT64,
|
|
|
|
|
DT_FLOAT,
|
|
|
|
|
DT_INT32,
|
|
|
|
|
DT_UINT32,
|
|
|
|
|
DT_INT8,
|
|
|
|
|
DT_UINT8,
|
|
|
|
|
DT_INT16,
|
|
|
|
|
DT_UINT16,
|
|
|
|
|
DT_FLOAT16,
|
|
|
|
|
DT_DOUBLE,
|
|
|
|
|
DT_DUAL,
|
|
|
|
|
DT_DUAL_SUB_INT8,
|
|
|
|
|
DT_DUAL_SUB_UINT8,
|
|
|
|
|
DT_COMPLEX64,
|
|
|
|
|
DT_COMPLEX128,
|
|
|
|
|
DT_QINT8,
|
|
|
|
|
DT_QINT16,
|
|
|
|
|
DT_QINT32,
|
|
|
|
|
DT_QUINT8,
|
|
|
|
|
DT_QUINT16,
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTensorPtr> &input,
|
|
|
|
@ -79,18 +53,9 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
|
|
|
|
|
GELOGW("input tensor is nullptr.");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// data type in input_x
|
|
|
|
|
auto data_type = x_->GetTensorDesc().GetDataType();
|
|
|
|
|
// check supported
|
|
|
|
|
if (kSupportedDataTypeToLength.count(data_type) == 0) {
|
|
|
|
|
GELOGW("input_x data_type is [%s], does not supported!", TypeUtils::DataTypeToSerialString(data_type).c_str());
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
uint32_t type_size = 0;
|
|
|
|
|
bool is_success = TypeUtils::GetDataTypeLength(data_type, type_size);
|
|
|
|
|
if (!is_success) {
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
// check data type of begin and size
|
|
|
|
|
if (begin->GetTensorDesc().GetDataType() != DT_INT32 || size->GetTensorDesc().GetDataType() != DT_INT32) {
|
|
|
|
|
GELOGW("Data type of begin and size for slice are not DT_INT32.");
|
|
|
|
@ -104,7 +69,7 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso
|
|
|
|
|
GE_CHECK_NOTNULL(begin_data);
|
|
|
|
|
GE_CHECK_NOTNULL(size_data);
|
|
|
|
|
|
|
|
|
|
size_t data_size = x_->GetData().size() / type_size;
|
|
|
|
|
size_t data_size = x_->GetData().size() / sizeof(int32_t);
|
|
|
|
|
size_t begin_size = begin->GetData().size() / sizeof(int32_t);
|
|
|
|
|
size_t size_size = size->GetData().size() / sizeof(int32_t);
|
|
|
|
|
const ge::GeShape &x_shape = x_->GetTensorDesc().GetShape();
|
|
|
|
|