/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_kernels/gather_v2_kernel.h"

#include <memory>
#include <set>

#include "common/fp16_t.h"
#include "common/ge_inner_error_codes.h"
#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
const size_t kGatherV2InputIndexZero = 0;
const size_t kGatherV2InputIndexOne = 1;
const size_t kGatherV2InputIndexTwo = 2;
const size_t kGatherV2InputIndexThree = 3;
const size_t kGatherV2DimOne = 1;
const size_t kGatherV2InputNum = 3;
const size_t kMaxIndicatesDims = 1;  // only scalar and 1-D indices are accepted by Check()
// Data types of input x that Compute() is willing to fold; must stay in sync with the switch in Process().
const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32,
                                           DT_INT64,   DT_UINT8,  DT_UINT16, DT_UINT32, DT_UINT64};
const int64_t DIM_AXIS_0 = 0;
const int64_t DIM_AXIS_1 = 1;
const int64_t DIM_AXIS_2 = 2;
const int64_t DIM_AXIS_3 = 3;

///
/// @brief Copy one contiguous run of elements into the output buffer with bounds checking.
/// @param [in] dst_base    start of the output buffer
/// @param [in] dst         destination inside the output buffer
/// @param [in] src         first source element
/// @param [in] elem_count  number of elements of type T to copy
/// @param [in] output_size total size of the output buffer in bytes
/// @return SUCCESS, PARAM_INVALID if dst lies outside [dst_base, dst_base + output_size],
///         MEMALLOC_FAILED if memcpy_s reports an error
///
template <typename T>
Status CopySlice(T *dst_base, T *dst, T *src, int64_t elem_count, size_t output_size) {
  if (dst - dst_base < 0) {
    GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
    return PARAM_INVALID;
  }
  const size_t offset_size = static_cast<size_t>(dst - dst_base) * sizeof(T);
  // Without this check "output_size - offset_size" below would wrap around to a huge unsigned value.
  if (offset_size > output_size) {
    GELOGE(PARAM_INVALID, "output offset %zu exceeds output size %zu", offset_size, output_size);
    return PARAM_INVALID;
  }
  const size_t copy_size = sizeof(T) * static_cast<size_t>(elem_count);
  auto ret_mem =
    memcpy_s(reinterpret_cast<void *>(dst), output_size - offset_size, reinterpret_cast<void *>(src), copy_size);
  if (ret_mem != 0) {
    GELOGE(MEMALLOC_FAILED, "memcpy failed!");
    return MEMALLOC_FAILED;
  }
  return SUCCESS;
}
}  // namespace

///
/// @brief Gather along axis 0: for every output row i, copy xstride_[0] elements starting at
///        row indicates_[i] of x.
/// @note Reads indicates_[i] for i in [0, output dim 0); Compute() only calls into this path after
///       verifying that this extent equals the number of stored indices.
///
template <typename T>
Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<uint8_t *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<uint8_t *>(output->GetData().data()));
  const size_t output_size = output->GetData().size();
  // Hoisted: the shape does not change while gathering.
  const auto out_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = out_shape.GetDim(kGatherV2InputIndexZero);

  for (int64_t i = 0; i < dim0; i++) {
    T *src = data_ptr_x + indicates_[i] * xstride_[kGatherV2InputIndexZero];
    T *dst = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    Status ret = CopySlice(data_ptr_y, dst, src, xstride_[kGatherV2InputIndexZero], output_size);
    if (ret != SUCCESS) {
      return ret;
    }
  }
  return SUCCESS;
}

///
/// @brief Gather along axis 1: dimension 0 is walked identically in x and y, dimension 1 of the
///        output selects x entries through indicates_.
///
template <typename T>
Status GatherV2Kernel::ProcessAxis1(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<uint8_t *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<uint8_t *>(output->GetData().data()));
  const size_t output_size = output->GetData().size();
  const auto out_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = out_shape.GetDim(kGatherV2InputIndexZero);
  const int64_t dim1 = out_shape.GetDim(kGatherV2InputIndexOne);

  for (int64_t i = 0; i < dim0; i++) {
    T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < dim1; j++) {
      T *src = data_ptr_x_i + indicates_[j] * xstride_[kGatherV2InputIndexOne];
      T *dst = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
      Status ret = CopySlice(data_ptr_y, dst, src, xstride_[kGatherV2InputIndexOne], output_size);
      if (ret != SUCCESS) {
        return ret;
      }
    }
  }
  return SUCCESS;
}

///
/// @brief Gather along axis 2: dimensions 0 and 1 are walked identically in x and y, dimension 2
///        of the output selects x entries through indicates_.
///
template <typename T>
Status GatherV2Kernel::ProcessAxis2(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<uint8_t *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<uint8_t *>(output->GetData().data()));
  const size_t output_size = output->GetData().size();
  const auto out_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = out_shape.GetDim(kGatherV2InputIndexZero);
  const int64_t dim1 = out_shape.GetDim(kGatherV2InputIndexOne);
  const int64_t dim2 = out_shape.GetDim(kGatherV2InputIndexTwo);

  for (int64_t i = 0; i < dim0; i++) {
    T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < dim1; j++) {
      T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
      T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
      for (int64_t m = 0; m < dim2; m++) {
        T *src = data_ptr_x_j + indicates_[m] * xstride_[kGatherV2InputIndexTwo];
        T *dst = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
        Status ret = CopySlice(data_ptr_y, dst, src, xstride_[kGatherV2InputIndexTwo], output_size);
        if (ret != SUCCESS) {
          return ret;
        }
      }
    }
  }
  return SUCCESS;
}

///
/// @brief Gather along axis 3: dimensions 0..2 are walked identically in x and y, dimension 3 of
///        the output selects x entries through indicates_.
///
template <typename T>
Status GatherV2Kernel::ProcessAxis3(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  T *data_ptr_x = reinterpret_cast<T *>(const_cast<uint8_t *>(tensor_x->GetData().data()));
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<uint8_t *>(output->GetData().data()));
  const size_t output_size = output->GetData().size();
  const auto out_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = out_shape.GetDim(kGatherV2InputIndexZero);
  const int64_t dim1 = out_shape.GetDim(kGatherV2InputIndexOne);
  const int64_t dim2 = out_shape.GetDim(kGatherV2InputIndexTwo);
  const int64_t dim3 = out_shape.GetDim(kGatherV2InputIndexThree);

  for (int64_t i = 0; i < dim0; i++) {
    T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
    T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < dim1; j++) {
      T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
      T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
      for (int64_t m = 0; m < dim2; m++) {
        T *data_ptr_x_m = data_ptr_x_j + m * xstride_[kGatherV2InputIndexTwo];
        T *data_ptr_y_m = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
        for (int64_t n = 0; n < dim3; n++) {
          T *src = data_ptr_x_m + indicates_[n] * xstride_[kGatherV2InputIndexThree];
          T *dst = data_ptr_y_m + n * ystride_[kGatherV2InputIndexThree];
          Status ret = CopySlice(data_ptr_y, dst, src, xstride_[kGatherV2InputIndexThree], output_size);
          if (ret != SUCCESS) {
            return ret;
          }
        }
      }
    }
  }
  return SUCCESS;
}

///
/// @brief Allocate a zero-initialised output buffer of data_num elements of T, attach it to
///        output and run the gather for the requested axis.
/// @param [in] data_num  number of elements in the output
/// @param [in] tensor_x  tensor to gather from
/// @param [in] axis      already normalised (non-negative) gather axis
/// @param [in,out] output tensor receiving the data
/// @return SUCCESS on success; NOT_CHANGED if the axis is above 3 or the input buffer is smaller
///         than its shape requires; PARAM_INVALID / MEMALLOC_FAILED / INTERNAL_ERROR otherwise
///
template <typename T>
Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x, int64_t axis, GeTensorPtr output) {
  if (data_num <= 0) {
    return PARAM_INVALID;
  }
  if (!CheckInt64MulOverflow(data_num, sizeof(T))) {
    GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T));
    return PARAM_INVALID;
  }

  // The gather loops read x through raw pointers using offsets derived from its shape, so the
  // buffer must really hold that many elements.
  const int64_t x_elem_num = tensor_x->GetTensorDesc().GetShape().GetShapeSize();
  if (x_elem_num < 0 || (tensor_x->GetData().size() / sizeof(T)) < static_cast<uint64_t>(x_elem_num)) {
    GELOGW("input x data size %zu is smaller than its shape requires, element num %ld.",
           static_cast<size_t>(tensor_x->GetData().size()), x_elem_num);
    return NOT_CHANGED;
  }

  std::unique_ptr<T[]> buf(new (std::nothrow) T[data_num]());
  if (buf == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast<size_t>(sizeof(T) * data_num));
    return MEMALLOC_FAILED;
  }
  GE_IF_BOOL_EXEC(
    output->SetData(reinterpret_cast<uint8_t *>(buf.get()), static_cast<size_t>(data_num * sizeof(T))) != GRAPH_SUCCESS,
    GELOGE(INTERNAL_ERROR, "set data failed");
    return INTERNAL_ERROR);

  Status ret = SUCCESS;
  switch (axis) {
    case DIM_AXIS_0:
      ret = ProcessAxis0<T>(tensor_x, output);
      break;
    case DIM_AXIS_1:
      ret = ProcessAxis1<T>(tensor_x, output);
      break;
    case DIM_AXIS_2:
      ret = ProcessAxis2<T>(tensor_x, output);
      break;
    case DIM_AXIS_3:
      ret = ProcessAxis3<T>(tensor_x, output);
      break;
    default:
      GELOGI("Only support 4 dims and below but input axis is %ld", axis);
      return NOT_CHANGED;
  }
  return ret;
}

///
/// @brief Fill stride so that stride[last] == 1 and stride[i] == stride[i + 1] * dims[i + 1].
/// @param [out] stride  must already have the same size as dims
/// @param [in] dims     dimension sizes
/// @return SUCCESS, or PARAM_INVALID if the sizes differ, dims is empty or a product overflows
///
Status GatherV2Kernel::CalcStride(std::vector<int64_t> &stride, std::vector<int64_t> dims) {
  if (stride.size() != dims.size() || dims.empty()) {
    return PARAM_INVALID;
  }
  const size_t last = dims.size() - kGatherV2DimOne;
  stride[last] = static_cast<int64_t>(kGatherV2DimOne);
  // Walk from the innermost dimension outwards.
  for (size_t index = last; index > 0; --index) {
    if (!CheckInt64MulOverflow(stride[index], dims[index])) {
      GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%ld)", stride[index], dims[index]);
      return PARAM_INVALID;
    }
    stride[index - 1] = stride[index] * dims[index];
  }
  return SUCCESS;
}

///
/// @brief Dispatch GenData on the element type of x. The gather only copies raw elements, so the
///        C++ type chosen per DataType just has to have the matching size.
/// @return result of GenData, or NOT_CHANGED for a data type that is not handled
///
Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPtr input_tensor_ptr,
                               GeTensorPtr output_ptr) {
  Status ret = SUCCESS;
  int64_t data_num = output_ptr->GetTensorDesc().GetShape().GetShapeSize();
  switch (data_type) {
    case DT_FLOAT16:
      ret = GenData<fp16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_DOUBLE:
      ret = GenData<double>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT8:
      ret = GenData<int8_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT16:
      ret = GenData<int16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT32:
      ret = GenData<int32_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT64:
      ret = GenData<int64_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT8:
      ret = GenData<uint8_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT16:
      ret = GenData<uint16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT32:
      ret = GenData<uint32_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT64:
      ret = GenData<uint64_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    default:
      GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return NOT_CHANGED;
  }
  return ret;
}

///
/// @brief Read all indices (int32 or int64) into indicates_, rejecting any value outside
///        [0, x_shape.GetDim(axis)).
/// @return SUCCESS, or NOT_CHANGED if an index is out of range or the indices buffer is too small
///
Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr, GeShape &x_shape,
                                             GeShape &indices_shape, DataType indices_data_type, size_t axis) {
  // Start from a clean state so a reused kernel object never gathers with indices of a previous call.
  indicates_.clear();

  const int64_t indices_num = indices_shape.GetShapeSize();
  const int64_t axis_dim = x_shape.GetDim(axis);
  const size_t indices_bytes = indices_tensor_ptr->GetData().size();
  const size_t index_type_size = (indices_data_type == DT_INT32) ? sizeof(int32_t) : sizeof(int64_t);
  // The loops below read indices_num raw elements, make sure the buffer really holds them.
  if (indices_num > 0 && (indices_bytes / index_type_size) < static_cast<uint64_t>(indices_num)) {
    GELOGW("indices data size %zu is smaller than its shape requires, element num %ld.", indices_bytes, indices_num);
    return NOT_CHANGED;
  }

  if (indices_data_type == DT_INT32) {
    auto indices_ptr = reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data());
    for (int64_t i = 0; i < indices_num; i++) {
      const int64_t index = indices_ptr[i];
      if (index < 0 || index >= axis_dim) {
        GELOGW("indices %ld value is not in range [0, %ld).", i, axis_dim);
        return NOT_CHANGED;
      }
      indicates_.push_back(index);
    }
  } else {
    // int64
    auto indices_ptr = reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data());
    for (int64_t i = 0; i < indices_num; i++) {
      const int64_t index = indices_ptr[i];
      if (index < 0 || index >= axis_dim) {
        GELOGW("indices %ld value is not in range [0, %ld).", i, axis_dim);
        return NOT_CHANGED;
      }
      indicates_.push_back(index);
    }
  }
  return SUCCESS;
}

///
/// @brief Validate the operator inputs: three non-null, non-empty tensors (x, indices, axis),
///        a scalar int32/int64 axis and int32/int64 indices with at most one dimension.
/// @return SUCCESS if folding may proceed, NOT_CHANGED otherwise
///
Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                             vector<GeTensorPtr> &v_output) const {
  if (op_desc_ptr == nullptr) {
    GELOGW("input opdesc is nullptr.");
    return NOT_CHANGED;
  }

  if (input.size() != kGatherV2InputNum) {
    GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InputNum);
    return NOT_CHANGED;
  }

  bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr ||
                  input[kGatherV2InputIndexTwo] == nullptr);
  if (is_null) {
    GELOGW("some input is nullptr.");
    return NOT_CHANGED;
  }
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);

  bool size_is_zero =
    ((tensor0->GetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0));
  if (size_is_zero) {
    GELOGW("some input size is zero.");
    return NOT_CHANGED;
  }

  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_shape = tensor2->GetTensorDesc().GetShape();
  // axis must be scalar
  if (axis_shape.GetDimNum() != 0) {
    GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
    return NOT_CHANGED;
  }
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
  if (!is_valid_axis_data_type) {
    GELOGW("axis datatype must be DT_INT32 or DT_INT64");
    return NOT_CHANGED;
  }

  // check indices data_type && dims && every element
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
  if (!is_valid_indices_data_type) {
    GELOGW("indices datatype must be DT_INT32 or DT_INT64.");
    return NOT_CHANGED;
  }
  if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
    GELOGW("indices input only support 0 or 1 dims.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}

///
/// @brief Dump axis, all shapes and the stored indices at debug log level.
///
void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
                                const std::vector<int64_t> &y_shape) {
  GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu.", axis, x_shape.GetDimNum(),
         indices_shape.GetDimNum(), y_shape.size());
  for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
    GELOGD("GatherV2Kernel x_shape[%zu]: %ld.", i, x_shape.GetDim(i));
  }
  for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
    GELOGD("GatherV2Kernel indices_shape[%zu]: %ld.", i, indices_shape.GetDim(i));
  }
  for (size_t i = 0; i < y_shape.size(); i++) {
    GELOGD("GatherV2Kernel y_shape[%zu]: %ld.", i, y_shape[i]);
  }
  for (auto ele : indicates_) {
    GELOGD("GatherV2Kernel indices:%ld.", ele);
  }
}

///
/// @brief Kernel entry point: validate inputs, normalise the axis, collect the indices, derive
///        the output shape and strides, then gather x into a new output tensor.
/// @param [in] op_desc_ptr  descriptor of the GatherV2 node
/// @param [in] input        tensors x, indices, axis (in that order)
/// @param [out] v_output    receives the computed tensor on success
/// @return SUCCESS on success, NOT_CHANGED when the node cannot be handled here, or the error
///         returned by Process()
///
Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                               vector<GeTensorPtr> &v_output) {
  GELOGI("Enter GatherV2Kernel Process.");
  Status ret = Check(op_desc_ptr, input, v_output);
  if (ret != SUCCESS) {
    GELOGW("param check failed");
    return NOT_CHANGED;
  }
  GELOGI("GatherV2Kernel[%s] start Process", op_desc_ptr->GetName().c_str());
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);

  auto x_shape = tensor0->GetTensorDesc().GetShape();
  auto indices_shape = tensor1->GetTensorDesc().GetShape();

  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  // Check() only guarantees a non-empty axis buffer; make sure a whole value can be read from it.
  const size_t axis_type_size = (axis_data_type == DT_INT32) ? sizeof(int32_t) : sizeof(int64_t);
  if (tensor2->GetData().size() < axis_type_size) {
    GELOGW("axis data size is smaller than its data type.");
    return NOT_CHANGED;
  }
  int64_t axis = (axis_data_type == DT_INT32)
                   ? static_cast<int64_t>(*(reinterpret_cast<const int32_t *>(tensor2->GetData().data())))
                   : *(reinterpret_cast<const int64_t *>(tensor2->GetData().data()));
  const int64_t x_dim_num = static_cast<int64_t>(x_shape.GetDimNum());
  // A negative axis counts from the last dimension.
  if (axis < 0) {
    axis += x_dim_num;
  }
  // check axis value
  if (axis < 0 || axis >= x_dim_num) {
    GELOGW("axis is invalid!");
    return NOT_CHANGED;
  }
  const size_t axis_index = static_cast<size_t>(axis);

  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, axis_index);
  if (ret != SUCCESS) {
    GELOGW("Save indeices by data type failed!");
    return ret;
  }

  // check input data type
  auto x_data_type = tensor0->GetTensorDesc().GetDataType();
  if (supported_type.find(x_data_type) == supported_type.end()) {
    GELOGI("GatherV2Kernel does not support this Data type:%s.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return NOT_CHANGED;
  }

  // calc output shape: x dims before axis + indices dims + x dims after axis
  std::vector<int64_t> y_shape;
  for (size_t i = 0; i < axis_index; i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
    y_shape.push_back(indices_shape.GetDim(i));
  }
  for (size_t i = axis_index + 1; i < x_shape.GetDimNum(); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }

  // The gather loops use the output extent on the gather axis as the number of indices to read.
  // With 0-dim indices the output has no such dimension, so the loops would read past the stored
  // indices or produce a wrong result; leave the node untouched in that case.
  if (axis_index >= y_shape.size() || y_shape[axis_index] != static_cast<int64_t>(indicates_.size())) {
    GELOGW("indices count %zu does not match output shape on axis %ld.", indicates_.size(), axis);
    return NOT_CHANGED;
  }

  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  if (output_ptr == nullptr) {
    GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }
  output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape));
  output_ptr->MutableTensorDesc().SetDataType(x_data_type);

  // added for debug
  DebugPrint(axis, x_shape, indices_shape, y_shape);

  // calc stride
  xstride_.assign(x_shape.GetDimNum(), 0);
  ystride_.assign(y_shape.size(), 0);
  auto ret_x = CalcStride(xstride_, x_shape.GetDims());
  auto ret_y = CalcStride(ystride_, y_shape);
  ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED;
  if (ret != SUCCESS) {
    GELOGE(ret, "CalcStride Failed");
    return ret;
  }

  ret = Process(axis, x_data_type, tensor0, output_ptr);
  if (ret != SUCCESS) {
    GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return ret;
  }

  GELOGI("GatherV2Kernel Process Success.");
  v_output.push_back(output_ptr);
  return SUCCESS;
}
REGISTER_KERNEL(GATHERV2, GatherV2Kernel);
}  // namespace ge