You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/host_kernels/gather_v2_kernel.cc

462 lines
19 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "host_kernels/gather_v2_kernel.h"
#include <memory>
#include <set>
#include "common/fp16_t.h"
#include "common/ge_inner_error_codes.h"
#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"
namespace ge {
namespace {
// Input positions of GatherV2: 0 = x, 1 = indices, 2 = axis (3 is used as a dim/stride position only).
const size_t kGatherV2InputIndexZero = 0;
const size_t kGatherV2InputIndexOne = 1;
const size_t kGatherV2InputIndexTwo = 2;
const size_t kGatherV2InputIndexThree = 3;
const size_t kGatherV2DimOne = 1;
// Number of inputs GatherV2 must have (x, indices, axis).
const size_t kGatherV2InpotNum = 3;
const size_t kMaxIndicatesDims = 1;  // only scalar and 1-D indices are supported
// Element types of x that Process() can fold; keep in sync with the switch in Process().
// (DT_INT16 was previously listed twice.)
const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8,   DT_INT16,  DT_INT32,
                                           DT_INT64,   DT_UINT8,  DT_UINT16, DT_UINT32, DT_UINT64};
// Axis values handled by GenData(); any other axis is left unfolded.
const int64_t DIM_AXIS_0 = 0;
const int64_t DIM_AXIS_1 = 1;
const int64_t DIM_AXIS_2 = 2;
const int64_t DIM_AXIS_3 = 3;
}  // namespace
/**
 * @brief Gathers along axis 0.
 *
 * For every position i of output dim 0, copies the contiguous x slice starting at row indicates_[i]
 * (xstride_[0] elements) into output row i (offset i * ystride_[0]).
 * @param tensor_x source tensor; its element type must be T
 * @param output   output tensor whose buffer/shape were prepared by GenData()/Compute()
 * @return SUCCESS; PARAM_INVALID if the saved indices or computed offsets do not fit the buffers;
 *         MEMALLOC_FAILED if memcpy_s fails
 */
template <typename T>
Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  const T *data_ptr_x = reinterpret_cast<const T *>(tensor_x->GetData().data());
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  const size_t x_size = tensor_x->GetData().size();
  const size_t output_size = output->GetData().size();
  // The output shape is loop-invariant: query it once instead of on every iteration.
  const GeShape y_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = y_shape.GetDim(kGatherV2InputIndexZero);
  // Output dim 0 enumerates the gathered indices; any mismatch would make indicates_[i] read out of range.
  if (dim0 < 0 || static_cast<size_t>(dim0) != indicates_.size() || xstride_[kGatherV2InputIndexZero] < 0) {
    GELOGE(PARAM_INVALID, "Output dim[0] is %ld but %zu indices are saved.", dim0, indicates_.size());
    return PARAM_INVALID;
  }
  const size_t size = sizeof(T) * static_cast<size_t>(xstride_[kGatherV2InputIndexZero]);
  for (int64_t i = 0; i < dim0; i++) {
    // Element offsets of the source and destination slices.
    const int64_t x_offset = indicates_[static_cast<size_t>(i)] * xstride_[kGatherV2InputIndexZero];
    const int64_t y_offset = i * ystride_[kGatherV2InputIndexZero];
    if (x_offset < 0 || y_offset < 0 || static_cast<size_t>(x_offset) * sizeof(T) + size > x_size ||
        static_cast<size_t>(y_offset) * sizeof(T) > output_size) {
      GELOGE(PARAM_INVALID, "Gather copy out of range, x_offset:%ld, y_offset:%ld.", x_offset, y_offset);
      return PARAM_INVALID;
    }
    const size_t offset_size = static_cast<size_t>(y_offset) * sizeof(T);
    auto ret_mem = memcpy_s(static_cast<void *>(data_ptr_y + y_offset), output_size - offset_size,
                            static_cast<const void *>(data_ptr_x + x_offset), size);
    if (ret_mem != 0) {
      GELOGE(MEMALLOC_FAILED, "memcpy failed!");
      return MEMALLOC_FAILED;
    }
  }
  return SUCCESS;
}
/**
 * @brief Gathers along axis 1.
 *
 * For every (i, j) of output dims 0 and 1, copies the contiguous x slice at
 * (i, indicates_[j]) (xstride_[1] elements) to output position (i, j).
 * @param tensor_x source tensor; its element type must be T
 * @param output   output tensor whose buffer/shape were prepared by GenData()/Compute()
 * @return SUCCESS; PARAM_INVALID if the saved indices or computed offsets do not fit the buffers;
 *         MEMALLOC_FAILED if memcpy_s fails
 */
template <typename T>
Status GatherV2Kernel::ProcessAxis1(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  const T *data_ptr_x = reinterpret_cast<const T *>(tensor_x->GetData().data());
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  const size_t x_size = tensor_x->GetData().size();
  const size_t output_size = output->GetData().size();
  // The output shape is loop-invariant: query it once instead of in every loop condition.
  const GeShape y_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = y_shape.GetDim(kGatherV2InputIndexZero);
  const int64_t dim1 = y_shape.GetDim(kGatherV2InputIndexOne);
  // Output dim 1 enumerates the gathered indices; any mismatch would make indicates_[j] read out of range.
  if (dim1 < 0 || static_cast<size_t>(dim1) != indicates_.size() || xstride_[kGatherV2InputIndexOne] < 0) {
    GELOGE(PARAM_INVALID, "Output dim[1] is %ld but %zu indices are saved.", dim1, indicates_.size());
    return PARAM_INVALID;
  }
  const size_t size = sizeof(T) * static_cast<size_t>(xstride_[kGatherV2InputIndexOne]);
  for (int64_t i = 0; i < dim0; i++) {
    const int64_t x_base = i * xstride_[kGatherV2InputIndexZero];
    const int64_t y_base = i * ystride_[kGatherV2InputIndexZero];
    for (int64_t j = 0; j < dim1; j++) {
      const int64_t x_offset = x_base + indicates_[static_cast<size_t>(j)] * xstride_[kGatherV2InputIndexOne];
      const int64_t y_offset = y_base + j * ystride_[kGatherV2InputIndexOne];
      if (x_offset < 0 || y_offset < 0 || static_cast<size_t>(x_offset) * sizeof(T) + size > x_size ||
          static_cast<size_t>(y_offset) * sizeof(T) > output_size) {
        GELOGE(PARAM_INVALID, "Gather copy out of range, x_offset:%ld, y_offset:%ld.", x_offset, y_offset);
        return PARAM_INVALID;
      }
      const size_t offset_size = static_cast<size_t>(y_offset) * sizeof(T);
      auto ret_mem = memcpy_s(static_cast<void *>(data_ptr_y + y_offset), output_size - offset_size,
                              static_cast<const void *>(data_ptr_x + x_offset), size);
      if (ret_mem != 0) {
        GELOGE(MEMALLOC_FAILED, "memcpy failed!");
        return MEMALLOC_FAILED;
      }
    }
  }
  return SUCCESS;
}
/**
 * @brief Gathers along axis 2.
 *
 * For every (i, j, m) of output dims 0-2, copies the contiguous x slice at
 * (i, j, indicates_[m]) (xstride_[2] elements) to output position (i, j, m).
 * @param tensor_x source tensor; its element type must be T
 * @param output   output tensor whose buffer/shape were prepared by GenData()/Compute()
 * @return SUCCESS; PARAM_INVALID if the saved indices or computed offsets do not fit the buffers;
 *         MEMALLOC_FAILED if memcpy_s fails
 */
template <typename T>
Status GatherV2Kernel::ProcessAxis2(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  const T *data_ptr_x = reinterpret_cast<const T *>(tensor_x->GetData().data());
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  const size_t x_size = tensor_x->GetData().size();
  const size_t output_size = output->GetData().size();
  // The output shape is loop-invariant: query it once instead of in every loop condition.
  const GeShape y_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = y_shape.GetDim(kGatherV2InputIndexZero);
  const int64_t dim1 = y_shape.GetDim(kGatherV2InputIndexOne);
  const int64_t dim2 = y_shape.GetDim(kGatherV2InputIndexTwo);
  // Output dim 2 enumerates the gathered indices; any mismatch would make indicates_[m] read out of range.
  if (dim2 < 0 || static_cast<size_t>(dim2) != indicates_.size() || xstride_[kGatherV2InputIndexTwo] < 0) {
    GELOGE(PARAM_INVALID, "Output dim[2] is %ld but %zu indices are saved.", dim2, indicates_.size());
    return PARAM_INVALID;
  }
  const size_t size = sizeof(T) * static_cast<size_t>(xstride_[kGatherV2InputIndexTwo]);
  for (int64_t i = 0; i < dim0; i++) {
    for (int64_t j = 0; j < dim1; j++) {
      // Element offsets of the (i, j) prefix in x and y.
      const int64_t x_base = i * xstride_[kGatherV2InputIndexZero] + j * xstride_[kGatherV2InputIndexOne];
      const int64_t y_base = i * ystride_[kGatherV2InputIndexZero] + j * ystride_[kGatherV2InputIndexOne];
      for (int64_t m = 0; m < dim2; m++) {
        const int64_t x_offset = x_base + indicates_[static_cast<size_t>(m)] * xstride_[kGatherV2InputIndexTwo];
        const int64_t y_offset = y_base + m * ystride_[kGatherV2InputIndexTwo];
        if (x_offset < 0 || y_offset < 0 || static_cast<size_t>(x_offset) * sizeof(T) + size > x_size ||
            static_cast<size_t>(y_offset) * sizeof(T) > output_size) {
          GELOGE(PARAM_INVALID, "Gather copy out of range, x_offset:%ld, y_offset:%ld.", x_offset, y_offset);
          return PARAM_INVALID;
        }
        const size_t offset_size = static_cast<size_t>(y_offset) * sizeof(T);
        auto ret_mem = memcpy_s(static_cast<void *>(data_ptr_y + y_offset), output_size - offset_size,
                                static_cast<const void *>(data_ptr_x + x_offset), size);
        if (ret_mem != 0) {
          GELOGE(MEMALLOC_FAILED, "memcpy failed!");
          return MEMALLOC_FAILED;
        }
      }
    }
  }
  return SUCCESS;
}
/**
 * @brief Gathers along axis 3.
 *
 * For every (i, j, m, n) of output dims 0-3, copies the contiguous x slice at
 * (i, j, m, indicates_[n]) (xstride_[3] elements) to output position (i, j, m, n).
 * @param tensor_x source tensor; its element type must be T
 * @param output   output tensor whose buffer/shape were prepared by GenData()/Compute()
 * @return SUCCESS; PARAM_INVALID if the saved indices or computed offsets do not fit the buffers;
 *         MEMALLOC_FAILED if memcpy_s fails
 */
template <typename T>
Status GatherV2Kernel::ProcessAxis3(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  const T *data_ptr_x = reinterpret_cast<const T *>(tensor_x->GetData().data());
  T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  const size_t x_size = tensor_x->GetData().size();
  const size_t output_size = output->GetData().size();
  // The output shape is loop-invariant: query it once instead of in every loop condition.
  const GeShape y_shape = output->GetTensorDesc().GetShape();
  const int64_t dim0 = y_shape.GetDim(kGatherV2InputIndexZero);
  const int64_t dim1 = y_shape.GetDim(kGatherV2InputIndexOne);
  const int64_t dim2 = y_shape.GetDim(kGatherV2InputIndexTwo);
  const int64_t dim3 = y_shape.GetDim(kGatherV2InputIndexThree);
  // Output dim 3 enumerates the gathered indices; any mismatch would make indicates_[n] read out of range.
  if (dim3 < 0 || static_cast<size_t>(dim3) != indicates_.size() || xstride_[kGatherV2InputIndexThree] < 0) {
    GELOGE(PARAM_INVALID, "Output dim[3] is %ld but %zu indices are saved.", dim3, indicates_.size());
    return PARAM_INVALID;
  }
  const size_t size = sizeof(T) * static_cast<size_t>(xstride_[kGatherV2InputIndexThree]);
  for (int64_t i = 0; i < dim0; i++) {
    for (int64_t j = 0; j < dim1; j++) {
      for (int64_t m = 0; m < dim2; m++) {
        // Element offsets of the (i, j, m) prefix in x and y.
        const int64_t x_base = i * xstride_[kGatherV2InputIndexZero] + j * xstride_[kGatherV2InputIndexOne] +
                               m * xstride_[kGatherV2InputIndexTwo];
        const int64_t y_base = i * ystride_[kGatherV2InputIndexZero] + j * ystride_[kGatherV2InputIndexOne] +
                               m * ystride_[kGatherV2InputIndexTwo];
        for (int64_t n = 0; n < dim3; n++) {
          const int64_t x_offset =
              x_base + indicates_[static_cast<size_t>(n)] * xstride_[kGatherV2InputIndexThree];
          const int64_t y_offset = y_base + n * ystride_[kGatherV2InputIndexThree];
          if (x_offset < 0 || y_offset < 0 || static_cast<size_t>(x_offset) * sizeof(T) + size > x_size ||
              static_cast<size_t>(y_offset) * sizeof(T) > output_size) {
            GELOGE(PARAM_INVALID, "Gather copy out of range, x_offset:%ld, y_offset:%ld.", x_offset, y_offset);
            return PARAM_INVALID;
          }
          const size_t offset_size = static_cast<size_t>(y_offset) * sizeof(T);
          auto ret_mem = memcpy_s(static_cast<void *>(data_ptr_y + y_offset), output_size - offset_size,
                                  static_cast<const void *>(data_ptr_x + x_offset), size);
          if (ret_mem != 0) {
            GELOGE(MEMALLOC_FAILED, "memcpy failed!");
            return MEMALLOC_FAILED;
          }
        }
      }
    }
  }
  return SUCCESS;
}
/**
 * @brief Allocates the output buffer of `data_num` zero-initialised T values, attaches it to `output`,
 *        then dispatches to the per-axis gather routine.
 * @param data_num number of output elements; must be positive and data_num * sizeof(T) must fit int64
 * @param tensor_x source tensor
 * @param axis     gather axis; only 0-3 are handled
 * @param output   tensor that receives the data
 * @return SUCCESS or the ProcessAxisN status; PARAM_INVALID for a bad count; MEMALLOC_FAILED if allocation
 *         fails; INTERNAL_ERROR if SetData fails; NOT_CHANGED for an unsupported axis
 */
template <typename T>
Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x, int64_t axis, GeTensorPtr output) {
  if (!(data_num > 0)) {
    return PARAM_INVALID;
  }
  if (!CheckInt64MulOverflow(data_num, sizeof(T))) {
    GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T));
    return PARAM_INVALID;
  }
  const size_t byte_len = static_cast<size_t>(data_num * sizeof(T));
  // Scratch buffer released when this function returns.
  // NOTE(review): relies on SetData() copying the bytes rather than keeping the pointer.
  std::unique_ptr<T[]> zeroed(new (std::nothrow) T[data_num]());
  if (zeroed == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", byte_len);
    return MEMALLOC_FAILED;
  }
  if (output->SetData(reinterpret_cast<uint8_t *>(zeroed.get()), byte_len) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "set data failed");
    return INTERNAL_ERROR;
  }
  if (axis == DIM_AXIS_0) {
    return ProcessAxis0<T>(tensor_x, output);
  }
  if (axis == DIM_AXIS_1) {
    return ProcessAxis1<T>(tensor_x, output);
  }
  if (axis == DIM_AXIS_2) {
    return ProcessAxis2<T>(tensor_x, output);
  }
  if (axis == DIM_AXIS_3) {
    return ProcessAxis3<T>(tensor_x, output);
  }
  GELOGI("Only support 4 dims and below but input axis is %ld", axis);
  return NOT_CHANGED;
}
/**
 * @brief Fills `stride` with row-major element strides for `dims`.
 *
 * The innermost stride is 1; each outer stride is the next stride times the next dim.
 * @param stride output; must already hold dims.size() entries
 * @param dims   tensor dimensions (non-empty)
 * @return SUCCESS, or PARAM_INVALID on a size mismatch, empty dims, or int64 overflow
 */
Status GatherV2Kernel::CalcStride(std::vector<int64_t> &stride, std::vector<int64_t> dims) {
  const size_t rank = dims.size();
  if (rank == 0 || stride.size() != rank) {
    return PARAM_INVALID;
  }
  stride[rank - 1] = static_cast<int64_t>(kGatherV2DimOne);
  for (size_t pos = rank - 1; pos > 0; --pos) {
    // stride[pos - 1] = stride[pos] * dims[pos], guarded against int64 overflow.
    if (!CheckInt64MulOverflow(stride[pos], dims[pos])) {
      GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%ld)", stride[pos], dims[pos]);
      return PARAM_INVALID;
    }
    stride[pos - 1] = stride[pos] * dims[pos];
  }
  return SUCCESS;
}
/**
 * @brief Dispatches GenData<T> for the element type of x.
 * @param axis             normalised gather axis
 * @param data_type        element type of x (and of the output)
 * @param input_tensor_ptr x tensor
 * @param output_ptr       output tensor; its shape size gives the number of elements to generate
 * @return the GenData status, or NOT_CHANGED for an unsupported data type
 */
Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPtr input_tensor_ptr,
                               GeTensorPtr output_ptr) {
  const int64_t element_count = output_ptr->GetTensorDesc().GetShape().GetShapeSize();
  switch (data_type) {
    case DT_FLOAT16:
      return GenData<fp16_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_DOUBLE:
      return GenData<double>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_INT8:
      return GenData<int8_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_INT16:
      return GenData<int16_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_INT32:
      return GenData<int32_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_INT64:
      return GenData<int64_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_UINT8:
      return GenData<uint8_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_UINT16:
      return GenData<uint16_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_UINT32:
      return GenData<uint32_t>(element_count, input_tensor_ptr, axis, output_ptr);
    case DT_UINT64:
      return GenData<uint64_t>(element_count, input_tensor_ptr, axis, output_ptr);
    default:
      break;
  }
  GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
  return NOT_CHANGED;
}
namespace {
/**
 * Appends `count` indices read from `data` to `out`, converting each to int64_t.
 * Stops with a warning and NOT_CHANGED at the first value outside [0, dim_limit).
 */
template <typename IndexT, typename Container>
Status AppendCheckedIndices(const IndexT *data, int64_t count, int64_t dim_limit, Container &out) {
  for (int64_t i = 0; i < count; i++) {
    const int64_t value = static_cast<int64_t>(data[i]);
    if (value < 0 || value >= dim_limit) {
      GELOGW("indices %ld value is not in range [0, %ld).", i, dim_limit);
      return NOT_CHANGED;
    }
    out.push_back(value);
  }
  return SUCCESS;
}
}  // namespace

/**
 * @brief Reads the indices tensor into indicates_, validating every value against x's dim at `axis`.
 *
 * indicates_ is cleared first. A 0-D (scalar) indices tensor is read as exactly one value. Values are
 * read as int32 when indices_data_type is DT_INT32 and as int64 otherwise.
 * @param indices_tensor_ptr indices tensor
 * @param x_shape            shape of x
 * @param indices_shape      shape of indices
 * @param indices_data_type  DT_INT32 or DT_INT64 (checked by Check())
 * @param axis               normalised gather axis
 * @return SUCCESS, or NOT_CHANGED if the buffer is too small or any index is out of range
 */
Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr, GeShape &x_shape,
                                             GeShape &indices_shape, DataType indices_data_type, size_t axis) {
  indicates_.clear();
  const int64_t indices_num = indices_shape.GetDimNum() == 0 ? 1 : indices_shape.GetShapeSize();
  if (indices_num < 0) {
    GELOGW("indices shape size %ld is invalid.", indices_num);
    return NOT_CHANGED;
  }
  // Never read more elements than the buffer actually holds.
  const size_t type_size = indices_data_type == DT_INT32 ? sizeof(int32_t) : sizeof(int64_t);
  const size_t data_size = indices_tensor_ptr->GetData().size();
  if (static_cast<uint64_t>(indices_num) > data_size / type_size) {
    GELOGW("indices data size %zu is too small for %ld elements.", data_size, indices_num);
    return NOT_CHANGED;
  }
  const int64_t dim_limit = x_shape.GetDim(axis);
  if (indices_data_type == DT_INT32) {
    return AppendCheckedIndices(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()),
                                indices_num, dim_limit, indicates_);
  }
  // int64
  return AppendCheckedIndices(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()), indices_num,
                              dim_limit, indicates_);
}
/**
 * @brief Validates GatherV2 inputs before folding.
 *
 * Requires a non-null op desc and exactly kGatherV2InpotNum non-null, non-empty inputs:
 * x (index 0), indices (index 1) and axis (index 2).
 * - axis must be a scalar of DT_INT32/DT_INT64 whose buffer holds at least one value of that type.
 * - indices must be DT_INT32/DT_INT64 with at most kMaxIndicatesDims dims.
 * @param op_desc_ptr node description
 * @param input       constant inputs of the node
 * @param v_output    not used by this check
 * @return SUCCESS if folding may proceed, NOT_CHANGED otherwise
 */
Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                             vector<GeTensorPtr> &v_output) const {
  if (op_desc_ptr == nullptr) {
    GELOGW("input opdesc is nullptr.");
    return NOT_CHANGED;
  }
  if (input.size() != kGatherV2InpotNum) {
    GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InpotNum);
    return NOT_CHANGED;
  }
  bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr ||
                  input[kGatherV2InputIndexTwo] == nullptr);
  if (is_null) {
    GELOGW("some input is nullptr.");
    return NOT_CHANGED;
  }
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  bool size_is_zero =
      ((tensor0->GetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0));
  if (size_is_zero) {
    GELOGW("some input size is zero.");
    return NOT_CHANGED;
  }
  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_shape = tensor2->GetTensorDesc().GetShape();
  // axis must be scalar
  if (axis_shape.GetDimNum() != 0) {
    GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
    return NOT_CHANGED;
  }
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
  if (!is_valid_axis_data_type) {
    GELOGW("axis datatype must be DT_INT32 or DT_INT64");
    return NOT_CHANGED;
  }
  // Compute() dereferences the axis buffer as one int32/int64 value, so it must be at least that large.
  const size_t axis_type_size = axis_data_type == DT_INT32 ? sizeof(int32_t) : sizeof(int64_t);
  if (tensor2->GetData().size() < axis_type_size) {
    GELOGW("axis data size %zu is smaller than its type size %zu.", tensor2->GetData().size(), axis_type_size);
    return NOT_CHANGED;
  }
  // check indices data_type && dims && every element
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
  if (!is_valid_indices_data_type) {
    GELOGW("indices datatype must be DT_INT32 or DT_INT64.");
    return NOT_CHANGED;
  }
  if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
    GELOGW("indices input only support 0 or 1 dims.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}
/**
 * @brief Emits debug-level logs for one fold.
 *
 * Logs the axis and the ranks of x, indices and y, then every dim of each shape,
 * then every saved index in indicates_.
 */
void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
                                const std::vector<int64_t> &y_shape) {
  const size_t x_rank = x_shape.GetDimNum();
  const size_t indices_rank = indices_shape.GetDimNum();
  GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu.", axis, x_rank, indices_rank,
         y_shape.size());
  for (size_t dim_idx = 0; dim_idx < x_rank; ++dim_idx) {
    GELOGD("GatherV2Kernel x_shape[%zu]: %ld.", dim_idx, x_shape.GetDim(dim_idx));
  }
  for (size_t dim_idx = 0; dim_idx < indices_rank; ++dim_idx) {
    GELOGD("GatherV2Kernel indices_shape[%zu]: %ld.", dim_idx, indices_shape.GetDim(dim_idx));
  }
  size_t y_idx = 0;
  for (const int64_t dim_value : y_shape) {
    GELOGD("GatherV2Kernel y_shape[%zu]: %ld.", y_idx, dim_value);
    ++y_idx;
  }
  for (const auto &index_value : indicates_) {
    GELOGD("GatherV2Kernel indices:%ld.", index_value);
  }
}
/**
 * @brief Constant-folds GatherV2: y = gather(x, indices, axis).
 *
 * Inputs (validated by Check()): x at index 0, indices at index 1 (scalar or 1-D, DT_INT32/DT_INT64),
 * and axis at index 2 (scalar DT_INT32/DT_INT64; negative values count from the last dim).
 * The output shape is x.shape[:axis] + indices.shape + x.shape[axis + 1:]. Only axis 0-3 (see GenData)
 * and the element types in supported_type are folded.
 * @param op_desc_ptr node description; output desc 0 seeds the result tensor
 * @param input       constant inputs of the node
 * @param v_output    receives the folded tensor on success
 * @return SUCCESS when folded; NOT_CHANGED when folding does not apply; otherwise the stride/data error
 */
Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                               vector<GeTensorPtr> &v_output) {
  GELOGI("Enter GatherV2Kernel Process.");
  Status ret = Check(op_desc_ptr, input, v_output);
  if (ret != SUCCESS) {
    GELOGW("param check failed");
    return NOT_CHANGED;
  }
  GELOGI("GatherV2Kernel[%s] start Process", op_desc_ptr->GetName().c_str());
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  auto x_shape = tensor0->GetTensorDesc().GetShape();
  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  int64_t axis = axis_data_type == DT_INT32
                     ? static_cast<int64_t>(*(reinterpret_cast<const int32_t *>(tensor2->GetData().data())))
                     : *(reinterpret_cast<const int64_t *>(tensor2->GetData().data()));
  const int64_t x_rank = static_cast<int64_t>(x_shape.GetDimNum());
  // A negative axis counts from the last dimension.
  if (axis < 0) {
    axis += x_rank;
  }
  // check axis value
  if (axis < 0 || axis >= x_rank) {
    GELOGW("axis is invalid!");
    return NOT_CHANGED;
  }
  // Start from an empty index list in case this kernel instance has been used before.
  indicates_.clear();
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast<size_t>(axis));
  if (ret != SUCCESS) {
    GELOGW("Save indeices by data type failed!");
    return ret;
  }
  // A 0-D (scalar) indices tensor holds one value but contributes no output dimension.
  const bool is_scalar_indices = indices_shape.GetDimNum() == 0;
  const int64_t indices_num = is_scalar_indices ? 1 : indices_shape.GetShapeSize();
  // ProcessAxisN index indicates_ by output position, so its length must match exactly.
  if (indices_num <= 0 || indicates_.size() != static_cast<size_t>(indices_num)) {
    GELOGW("GatherV2Kernel expects %ld indices but %zu were saved.", indices_num, indicates_.size());
    return NOT_CHANGED;
  }
  // check input data type
  auto x_data_type = tensor0->GetTensorDesc().GetDataType();
  if (supported_type.find(x_data_type) == supported_type.end()) {
    GELOGI("GatherV2Kernel does not support this Data type:%s.",
           TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return NOT_CHANGED;
  }
  // calc output shape: x.shape[:axis] + indices.shape + x.shape[axis + 1:]
  std::vector<int64_t> y_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
    y_shape.push_back(indices_shape.GetDim(i));
  }
  for (size_t i = static_cast<size_t>(axis) + 1; i < x_shape.GetDimNum(); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  // ProcessAxisN walk x and y with strides of equal rank and read the gathered dim at `axis`.
  // For scalar indices, compute on a shape with a size-1 dim at `axis` (same memory layout) and
  // only publish the squeezed y_shape once the data has been generated.
  std::vector<int64_t> calc_shape = y_shape;
  if (is_scalar_indices) {
    calc_shape.insert(calc_shape.begin() + axis, int64_t{1});
  }
  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  if (output_ptr == nullptr) {
    GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }
  output_ptr->MutableTensorDesc().SetShape(GeShape(calc_shape));
  output_ptr->MutableTensorDesc().SetDataType(x_data_type);
  // added for debug
  DebugPrint(axis, x_shape, indices_shape, y_shape);
  // calc stride
  xstride_ = std::vector<int64_t>(x_shape.GetDimNum());
  ystride_ = std::vector<int64_t>(calc_shape.size());
  auto ret_x = CalcStride(xstride_, x_shape.GetDims());
  auto ret_y = CalcStride(ystride_, calc_shape);
  ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED;
  if (ret != SUCCESS) {
    GELOGE(ret, "CalcStride Failed");
    return ret;
  }
  ret = Process(axis, x_data_type, tensor0, output_ptr);
  if (ret != SUCCESS) {
    GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return ret;
  }
  // Same element count and layout; this only drops the size-1 placeholder dim used for scalar indices.
  output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape));
  GELOGI("GatherV2Kernel Process Success.");
  v_output.push_back(output_ptr);
  return SUCCESS;
}
// Registers GatherV2Kernel as the host kernel for the GATHERV2 op type.
// NOTE(review): REGISTER_KERNEL presumably comes from inc/kernel_factory.h (included above) — confirm.
REGISTER_KERNEL(GATHERV2, GatherV2Kernel);
} // namespace ge