|
|
|
@ -274,7 +274,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
|
|
|
|
|
auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
|
|
|
|
|
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
|
|
|
|
|
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
|
|
|
|
|
GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
|
|
|
|
|
GELOGE(NOT_CHANGED, "indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
indicates_.push_back(*(indices_ptr + i));
|
|
|
|
@ -284,7 +284,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
|
|
|
|
|
auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
|
|
|
|
|
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
|
|
|
|
|
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
|
|
|
|
|
GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
|
|
|
|
|
GELOGE(NOT_CHANGED, "indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
indicates_.push_back(*(indices_ptr + i));
|
|
|
|
@ -296,19 +296,19 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
|
|
|
|
|
Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeTensorPtr> &input,
|
|
|
|
|
vector<GeTensorPtr> &v_output) const {
|
|
|
|
|
if (op_desc_ptr == nullptr) {
|
|
|
|
|
GELOGW("input opdesc is nullptr.");
|
|
|
|
|
GELOGE(NOT_CHANGED, "input opdesc is nullptr.");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (input.size() != kGatherV2InpotNum) {
|
|
|
|
|
GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InpotNum);
|
|
|
|
|
GELOGE(NOT_CHANGED, "The number of input for GatherV2 must be %zu.", kGatherV2InpotNum);
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr ||
|
|
|
|
|
input[kGatherV2InputIndexTwo] == nullptr);
|
|
|
|
|
if (is_null) {
|
|
|
|
|
GELOGW("some input is nullptr.");
|
|
|
|
|
GELOGE(NOT_CHANGED, "some input is nullptr.");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
|
|
|
|
@ -318,7 +318,7 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeT
|
|
|
|
|
bool size_is_zero =
|
|
|
|
|
((tensor0->GetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0));
|
|
|
|
|
if (size_is_zero) {
|
|
|
|
|
GELOGW("some input size is zero.");
|
|
|
|
|
GELOGE(NOT_CHANGED, "some input size is zero.");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -326,13 +326,13 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeT
|
|
|
|
|
auto axis_shape = tensor2->GetTensorDesc().GetShape();
|
|
|
|
|
// axis must be scalar
|
|
|
|
|
if (axis_shape.GetDimNum() != 0) {
|
|
|
|
|
GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
|
|
|
|
|
GELOGE(NOT_CHANGED, "axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
|
|
|
|
|
bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
|
|
|
|
|
if (!is_valid_axis_data_type) {
|
|
|
|
|
GELOGW("axis datatype must be DT_INT32 or DT_INT64");
|
|
|
|
|
GELOGE(NOT_CHANGED, "axis datatype must be DT_INT32 or DT_INT64");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -340,11 +340,11 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeT
|
|
|
|
|
auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
|
|
|
|
|
bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
|
|
|
|
|
if (!is_valid_indices_data_type) {
|
|
|
|
|
GELOGW("indices datatype must be DT_INT32 or DT_INT64");
|
|
|
|
|
GELOGE(NOT_CHANGED, "indices datatype must be DT_INT32 or DT_INT64");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
|
|
|
|
|
GELOGW("indices input only support 0 or 1 dims");
|
|
|
|
|
GELOGE(NOT_CHANGED, "indices input only support 0 or 1 dims");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -372,7 +372,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
|
|
|
|
|
GELOGI("Enter GatherV2Kernel Process.");
|
|
|
|
|
Status ret = Check(op_desc_ptr, input, v_output);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGW("param check failed.");
|
|
|
|
|
GELOGE(NOT_CHANGED, "param check failed.");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
GELOGI("GatherV2Kernel[%s] start Process.", op_desc_ptr->GetName().c_str());
|
|
|
|
@ -390,13 +390,13 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
|
|
|
|
|
axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
|
|
|
|
|
// check axis value
|
|
|
|
|
if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
|
|
|
|
|
GELOGW("axis is invalid");
|
|
|
|
|
GELOGE(NOT_CHANGED, "axis is invalid");
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
|
|
|
|
|
ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast<size_t>(axis));
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGW("Save indeices by data type failed!");
|
|
|
|
|
GELOGE(NOT_CHANGED, "Save indeices by data type failed!");
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -420,7 +420,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
|
|
|
|
|
|
|
|
|
|
GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
|
|
|
|
|
if (output_ptr == nullptr) {
|
|
|
|
|
GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
|
|
|
|
|
GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
|
|
|
|
|
return NOT_CHANGED;
|
|
|
|
|
}
|
|
|
|
|
output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape));
|
|
|
|
|