|
|
@ -278,7 +278,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
|
|
|
|
auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
|
|
|
|
auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
|
|
|
|
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
|
|
|
|
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
|
|
|
|
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
|
|
|
|
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
|
|
|
|
GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
|
|
|
|
GELOGW("indices %ld value is not in range [0, %ld).", i, x_shape.GetDim(axis));
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
indicates_.push_back(*(indices_ptr + i));
|
|
|
|
indicates_.push_back(*(indices_ptr + i));
|
|
|
@ -288,7 +288,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
|
|
|
|
auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
|
|
|
|
auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
|
|
|
|
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
|
|
|
|
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
|
|
|
|
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
|
|
|
|
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
|
|
|
|
GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
|
|
|
|
GELOGW("indices %ld value is not in range [0, %ld).", i, x_shape.GetDim(axis));
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
indicates_.push_back(*(indices_ptr + i));
|
|
|
|
indicates_.push_back(*(indices_ptr + i));
|
|
|
@ -344,42 +344,42 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeT
|
|
|
|
auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
|
|
|
|
auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
|
|
|
|
bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
|
|
|
|
bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
|
|
|
|
if (!is_valid_indices_data_type) {
|
|
|
|
if (!is_valid_indices_data_type) {
|
|
|
|
GELOGW("indices datatype must be DT_INT32 or DT_INT64");
|
|
|
|
GELOGW("indices datatype must be DT_INT32 or DT_INT64.");
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
|
|
|
|
if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
|
|
|
|
GELOGW("indices input only support 0 or 1 dims");
|
|
|
|
GELOGW("indices input only support 0 or 1 dims.");
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
|
|
|
|
void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
|
|
|
|
const std::vector<int64_t> &y_shape) {
|
|
|
|
const std::vector<int64_t> &y_shape) {
|
|
|
|
GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu", axis, x_shape.GetDimNum(),
|
|
|
|
GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu.", axis, x_shape.GetDimNum(),
|
|
|
|
indices_shape.GetDimNum(), y_shape.size());
|
|
|
|
indices_shape.GetDimNum(), y_shape.size());
|
|
|
|
for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
|
|
|
|
for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
|
|
|
|
GELOGD("GatherV2Kernel x_shape[%zu]: %ld", i, x_shape.GetDim(i));
|
|
|
|
GELOGD("GatherV2Kernel x_shape[%zu]: %ld.", i, x_shape.GetDim(i));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
|
|
|
|
for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
|
|
|
|
GELOGD("GatherV2Kernel indices_shape[%zu]: %ld", i, indices_shape.GetDim(i));
|
|
|
|
GELOGD("GatherV2Kernel indices_shape[%zu]: %ld.", i, indices_shape.GetDim(i));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < y_shape.size(); i++) {
|
|
|
|
for (size_t i = 0; i < y_shape.size(); i++) {
|
|
|
|
GELOGD("GatherV2Kernel y_shape[%zu]: %ld", i, y_shape[i]);
|
|
|
|
GELOGD("GatherV2Kernel y_shape[%zu]: %ld.", i, y_shape[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto ele : indicates_) {
|
|
|
|
for (auto ele : indicates_) {
|
|
|
|
GELOGD("GatherV2Kernel indices:%ld", ele);
|
|
|
|
GELOGD("GatherV2Kernel indices:%ld.", ele);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
|
|
|
|
Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
|
|
|
|
vector<GeTensorPtr> &v_output) {
|
|
|
|
vector<GeTensorPtr> &v_output) {
|
|
|
|
GELOGI("Enter GatherV2Kernel Process.");
|
|
|
|
GELOGI("Enter GatherV2Kernel Process");
|
|
|
|
Status ret = Check(op_desc_ptr, input, v_output);
|
|
|
|
Status ret = Check(op_desc_ptr, input, v_output);
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
GELOGW("param check failed.");
|
|
|
|
GELOGW("param check failed");
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
GELOGI("GatherV2Kernel[%s] start Process.", op_desc_ptr->GetName().c_str());
|
|
|
|
GELOGI("GatherV2Kernel[%s] start Process", op_desc_ptr->GetName().c_str());
|
|
|
|
ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
|
|
|
|
ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
|
|
|
|
ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
|
|
|
|
ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
|
|
|
|
ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
|
|
|
|
ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
|
|
|
@ -394,7 +394,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
|
|
|
|
axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
|
|
|
|
axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
|
|
|
|
// check axis value
|
|
|
|
// check axis value
|
|
|
|
if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
|
|
|
|
if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
|
|
|
|
GELOGW("axis is invalid");
|
|
|
|
GELOGW("axis is invalid!");
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
|
|
|
|
auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
|
|
|
@ -407,7 +407,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
|
|
|
|
// check input data type
|
|
|
|
// check input data type
|
|
|
|
auto x_data_type = tensor0->GetTensorDesc().GetDataType();
|
|
|
|
auto x_data_type = tensor0->GetTensorDesc().GetDataType();
|
|
|
|
if (supported_type.find(x_data_type) == supported_type.end()) {
|
|
|
|
if (supported_type.find(x_data_type) == supported_type.end()) {
|
|
|
|
GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
|
|
|
|
GELOGI("GatherV2Kernel does not support this Data type:%s.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
|
|
|
|
return NOT_CHANGED;
|
|
|
|
return NOT_CHANGED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// calc output shape
|
|
|
|
// calc output shape
|
|
|
|