@@ -45,7 +45,7 @@ bool IsEllipsisMaskValid(const GeTensorDescPtr &input_desc, const uint32_t ellip
         ++ellipsis_num;
       }
       if (ellipsis_num > 1) {
-        GELOGW("Only one non-zero bit is allowed in ellipsis_mask.");
+        GELOGW("Only one non-zero bit is allowed in ellipsis_mask");
         return false;
       }
     }
@@ -84,14 +84,14 @@ void GetOriginStrideVec(const std::vector<ge::ConstGeTensorPtr> &input, vector<i
 } // namespace
 Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<ge::ConstGeTensorPtr> &input,
                                    vector<ge::GeTensorPtr> &v_output) {
-  GELOGD("StridedSliceKernel in.");
+  GELOGD("StridedSliceKernel in");
   // 1.Check input and attrs
   if (CheckAndGetAttr(attr) != SUCCESS) {
-    GELOGW("Check and get attrs failed.Ignore kernel.");
+    GELOGW("Check and get attrs failed.Ignore kernel");
     return NOT_CHANGED;
   }
   if (CheckInputParam(input) != SUCCESS) {
-    GELOGW("Check input params failed.Ignore kernel.");
+    GELOGW("Check input params failed.Ignore kernel");
     return NOT_CHANGED;
   }
   // 2.Init param with mask attrs.
@@ -100,7 +100,7 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<g
   std::vector<int64_t> output_dims;
   std::vector<int64_t> stride_vec;
   if (InitParamWithAttrs(input, input_dims, begin_vec, output_dims, stride_vec) != SUCCESS) {
-    GELOGW("Init param with mask attrs failed.Ignore kernel.");
+    GELOGW("Init param with mask attrs failed.Ignore kernel");
    return NOT_CHANGED;
   }
 
@@ -114,13 +114,13 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<g
   auto output_tensor_desc = attr->GetOutputDesc(0);
   GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
   if (output_ptr == nullptr) {
-    GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", attr->GetName().c_str());
+    GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s", attr->GetName().c_str());
     return NOT_CHANGED;
   }
   auto ret = OpUtils::SetOutputSliceData(data, static_cast<int64_t>(data_size), data_type, input_dims, begin_vec,
                                          output_dims, output_ptr.get(), stride_vec);
   if (ret != SUCCESS) {
-    GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed.");
+    GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed");
     return NOT_CHANGED;
   }
 
@@ -133,18 +133,18 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<g
   GetOutputDims(final_dim_size, output_dims, v_dims);
   t_d.SetShape(GeShape(v_dims));
   v_output.push_back(output_ptr);
-  GELOGI("StridedSliceKernel success.");
+  GELOGI("StridedSliceKernel success");
   return SUCCESS;
 }
 Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr) {
   if (attr == nullptr) {
-    GELOGE(PARAM_INVALID, "input opdescptr is nullptr.");
+    GELOGE(PARAM_INVALID, "input opdescptr is nullptr");
     return PARAM_INVALID;
   }
   // Get all op attr value of strided_slice
   for (auto &attr_2_value : attr_value_map_) {
     if (!AttrUtils::GetInt(attr, attr_2_value.first, attr_2_value.second)) {
-      GELOGE(PARAM_INVALID, "Get %s attr failed.", attr_2_value.first.c_str());
+      GELOGE(PARAM_INVALID, "Get %s attr failed", attr_2_value.first.c_str());
       return PARAM_INVALID;
     }
   }
@@ -159,7 +159,7 @@ Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr) {
 }
 Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &input) {
   if (input.size() != kStridedSliceInputSize) {
-    GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu.", kStridedSliceInputSize);
+    GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu", kStridedSliceInputSize);
     return PARAM_INVALID;
   }
 
@@ -178,11 +178,11 @@ Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &
   auto stride_tensor_desc = begin_tensor->GetTensorDesc();
   if (begin_tensor_desc.GetDataType() != end_tensor_desc.GetDataType() ||
       end_tensor_desc.GetDataType() != stride_tensor_desc.GetDataType()) {
-    GELOGW("Data type of StridedSlice OP(begin,end,strides) must be same.");
+    GELOGW("Data type of StridedSlice OP(begin,end,strides) must be same");
     return PARAM_INVALID;
   }
   if (kIndexNumberType.find(begin_tensor_desc.GetDataType()) == kIndexNumberType.end()) {
-    GELOGW("Data type of StridedSlice OP(begin,end,strides) must be int32 or int64.");
+    GELOGW("Data type of StridedSlice OP(begin,end,strides) must be int32 or int64");
     return PARAM_INVALID;
   }
 
@@ -190,7 +190,7 @@ Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &
   auto x_data_type = weight0->GetTensorDesc().GetDataType();
   auto x_data_size = GetSizeByDataType(x_data_type);
   if (x_data_size < 0) {
-    GELOGW("Data type of x input %s is not supported.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
+    GELOGW("Data type of x input %s is not supported", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
     return PARAM_INVALID;
   }
   size_t weight0_size = weight0->GetData().size() / x_data_size;
@@ -198,12 +198,12 @@ Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &
   size_t end_data_size = end_tensor->GetData().size();
   size_t stride_data_size = stride_tensor->GetData().size();
   if ((weight0_size == 0) || (begin_data_size == 0) || (end_data_size == 0) || (stride_data_size == 0)) {
-    GELOGW("Data size of inputs is 0.");
+    GELOGW("Data size of inputs is 0");
     return PARAM_INVALID;
   }
   // check dim size
   if (!((begin_data_size == end_data_size) && (end_data_size == stride_data_size))) {
-    GELOGW("The sizes of begin, end and stride is not supported.");
+    GELOGW("The sizes of begin, end and stride is not supported");
     return PARAM_INVALID;
   }
   return SUCCESS;
@@ -250,15 +250,15 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr
       end_i = x_dims.at(i);
       stride_i = 1;
     }
-    GELOGD("Before mask calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld.",
+    GELOGD("Before mask calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld",
            begin_i, end_i, stride_i, x_dims.at(i));
     auto ret = MaskCal(i, begin_i, end_i, x_dims.at(i));
     if (ret != SUCCESS) {
-      GELOGW("MaskCal failed, because of data overflow.");
+      GELOGW("MaskCal failed, because of data overflow");
       return NOT_CHANGED;
     }
     int64_t dim_final;
-    GELOGD("Before stride calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld.",
+    GELOGD("Before stride calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld",
            begin_i, end_i, stride_i, x_dims.at(i));
     (void) StrideCal(x_dims.at(i), begin_i, end_i, stride_i, dim_final);
     output_dims.push_back(dim_final);
@@ -273,7 +273,7 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten
                                                vector<int64_t> &x_dims) {
   auto begin_data_type_size = GetSizeByDataType(begin_tensor->GetTensorDesc().GetDataType());
   if (begin_data_type_size == 0) {
-    GELOGW("Param begin_data_type_size should not be zero.");
+    GELOGW("Param begin_data_type_size should not be zero");
     return;
   }
   size_t begin_vec_size = begin_tensor->GetData().size() / begin_data_type_size;