You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
graphengine/ge/host_kernels/add_kernel.cc

203 lines
6.7 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "host_kernels/add_kernel.h"
#include <cfloat>
#include "common/math/math_util.h"
#include "graph/common/bcast.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"
namespace ge {
namespace {
// Positional indices of the ADD op's inputs/outputs, and the expected counts.
const size_t kAddFirstInput = 0;
const size_t kAddSecondInput = 1;
const size_t kAddFirstOutput = 0;
const size_t kAddInputSize = 2;
const size_t kAddOutputSize = 1;
// Expands to one switch case that runs BCastAdd<TYPE> for tensors of DTYPE.
// NOTE: relies on `ret`, `op_desc_ptr`, `input` and `v_output` being in scope
// at the expansion site (see AddKernel::Compute).
#define SET_BCAST_ADD_CASE(DTYPE, TYPE) \
case (DTYPE): \
ret = BCastAdd<TYPE>(op_desc_ptr, input, v_output); \
break;
} // namespace
// Checks whether x + y would overflow/underflow when computed as data_type.
// NOTE(review): the FMK_*_ADDCHECK macros (common/math/math_util.h) appear to
// perform an early return with a failure status when the addition would
// overflow -- presumed from callers treating a non-SUCCESS return as
// overflow; confirm against the macro definitions.
// Data types with no case below (e.g. bool, string) are not checked and fall
// through to SUCCESS.
template <typename T>
Status AddKernel::OverflowCheck(const T &x, const T &y, DataType data_type) {
switch (data_type) {
case DT_INT8:
FMK_INT8_ADDCHECK(x, y)
break;
case DT_INT16:
FMK_INT16_ADDCHECK(x, y)
break;
case DT_INT32:
FMK_INT32_ADDCHECK(x, y)
break;
case DT_INT64:
FMK_INT64_ADDCHECK(x, y)
break;
case DT_UINT8:
FMK_UINT8_ADDCHECK(x, y)
break;
case DT_UINT16:
FMK_UINT16_ADDCHECK(x, y)
break;
case DT_UINT32:
FMK_UINT32_ADDCHECK(x, y)
break;
case DT_UINT64:
FMK_UINT64_ADDCHECK(x, y)
break;
case DT_FLOAT16:
FMK_FP16_ADDCHECK(x, y)
break;
case DT_FLOAT:
FMK_FLOAT_ADDCHECK(x, y)
break;
case DT_DOUBLE:
FMK_DOUBLE_ADDCHECK(x, y)
break;
default:
// Unlisted types are not overflow-checked.
break;
}
return SUCCESS;
}
// Broadcast-adds the two constant input tensors element-wise and appends the
// result tensor to v_output. InT is the element type matching the inputs'
// DataType.
// @param op_desc_ptr  op descriptor; its first output desc seeds the result.
// @param input        two constant tensors, already validated by AddCheck.
// @param v_output     receives the folded result tensor on success.
// @return SUCCESS, or PARAM_INVALID / MEMALLOC_FAILED / a broadcast error.
template <typename InT>
Status AddKernel::BCastAdd(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                           std::vector<GeTensorPtr> &v_output) {
  // only broadcast shape
  BCast bcast;
  Status ret = bcast.GenerateBcastInfo(BCast::TransShapeToDimVec(input[kAddFirstInput]->GetTensorDesc()),
                                       BCast::TransShapeToDimVec(input[kAddSecondInput]->GetTensorDesc()));
  if (ret != SUCCESS) {
    // Fixed log text: previously said "Greater", a copy/paste from the greater kernel.
    GELOGE(ret, "Add broadcasting failed.");
    return ret;
  }
  // Per-element index mappings from the broadcast output back into each input.
  std::vector<int64_t> x_indexes;
  std::vector<int64_t> y_indexes;
  bcast.BCastIndexes(x_indexes, y_indexes);
  auto x1_data = reinterpret_cast<const InT *>(input[kAddFirstInput]->GetData().data());
  auto x2_data = reinterpret_cast<const InT *>(input[kAddSecondInput]->GetData().data());
  // Guard against absent tensor buffers before dereferencing them below.
  if ((x1_data == nullptr) || (x2_data == nullptr)) {
    GELOGE(PARAM_INVALID, "Input tensor data must not be null.");
    return PARAM_INVALID;
  }
  size_t data_num = x_indexes.size();
  std::unique_ptr<InT[]> buf(new (std::nothrow) InT[data_num]());
  if (buf == nullptr) {
    // Fixed log text: said sizeof(T) but the template parameter is InT.
    GELOGE(MEMALLOC_FAILED, "New sizeof(InT) * data_num(%zu) memory failed",
           static_cast<size_t>(sizeof(InT) * data_num));
    return MEMALLOC_FAILED;
  }
  DataType data_type = input[kAddFirstInput]->GetTensorDesc().GetDataType();
  for (size_t i = 0; i < data_num; i++) {
    auto x_index = *(x1_data + x_indexes[i]);
    auto y_index = *(x2_data + y_indexes[i]);
    // Reject element pairs whose sum would overflow the destination type.
    if (OverflowCheck<InT>(x_index, y_index, data_type) != SUCCESS) {
      GELOGE(PARAM_INVALID, "Result of add is overflow.");
      return PARAM_INVALID;
    }
    *(buf.get() + i) = x_index + y_index;
  }
  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(kAddFirstOutput));
  if (output_ptr == nullptr) {
    GELOGE(MEMALLOC_FAILED, "Make shared failed");
    return MEMALLOC_FAILED;
  }
  output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(InT));
  output_ptr->MutableTensorDesc().SetDataType(data_type);
  std::vector<int64_t> bcast_dims = bcast.GetOutputShape();
  output_ptr->MutableTensorDesc().SetShape(GeShape(bcast_dims));
  v_output.push_back(output_ptr);
  return SUCCESS;
}
// Validates the preconditions for constant-folding an ADD op: a non-null op
// descriptor, exactly two non-null inputs that share a data type and carry
// non-empty weight data, and exactly one output.
// @return SUCCESS when foldable, PARAM_INVALID otherwise (with a warning log).
Status AddKernel::AddCheck(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input) {
  if (op_desc_ptr == nullptr) {
    GELOGW("Op_desc_ptr must not be null.");
    return PARAM_INVALID;
  }
  // Exactly two inputs and one output are required.
  const bool io_counts_ok = (input.size() == kAddInputSize) && (op_desc_ptr->GetOutputsSize() == kAddOutputSize);
  if (!io_counts_ok) {
    GELOGW("The number of input for add must be %zu, output number must be %zu.", kAddInputSize,
           kAddOutputSize);
    return PARAM_INVALID;
  }
  const ConstGeTensorPtr &first = input[kAddFirstInput];
  const ConstGeTensorPtr &second = input[kAddSecondInput];
  if ((first == nullptr) || (second == nullptr)) {
    GELOGW("Input vector elements must not be null.");
    return PARAM_INVALID;
  }
  // Both operands must share one element type.
  const DataType first_type = first->GetTensorDesc().GetDataType();
  const DataType second_type = second->GetTensorDesc().GetDataType();
  if (first_type != second_type) {
    GELOGW("Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s",
           TypeUtils::DataTypeToSerialString(first_type).c_str(),
           TypeUtils::DataTypeToSerialString(second_type).c_str());
    return PARAM_INVALID;
  }
  // Both weights must actually contain data.
  if ((first->GetData().size() == 0) || (second->GetData().size() == 0)) {
    GELOGW("Data size of input0 is %zu, input1 is %zu.", first->GetData().size(), second->GetData().size());
    return PARAM_INVALID;
  }
  return SUCCESS;
}
// Host-side (compile-time) folding of the ADD op: validates the inputs, then
// dispatches BCastAdd<T> on the inputs' element type via SET_BCAST_ADD_CASE.
// @param op_desc_ptr  op descriptor of the ADD node.
// @param input        the node's constant input tensors.
// @param v_output     receives the folded result tensor on success.
// @return SUCCESS when the add was folded; NOT_CHANGED when the inputs are
//         invalid or the data type is unsupported, leaving the op for runtime.
Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                          std::vector<GeTensorPtr> &v_output) {
  if (AddCheck(op_desc_ptr, input) != SUCCESS) {
    return NOT_CHANGED;
  }
  Status ret = NOT_CHANGED;
  DataType data_type = input[kAddFirstInput]->GetTensorDesc().GetDataType();
  switch (data_type) {
    SET_BCAST_ADD_CASE(DT_INT8, int8_t)
    SET_BCAST_ADD_CASE(DT_INT16, int16_t)
    SET_BCAST_ADD_CASE(DT_INT32, int32_t)
    SET_BCAST_ADD_CASE(DT_INT64, int64_t)
    SET_BCAST_ADD_CASE(DT_UINT8, uint8_t)
    SET_BCAST_ADD_CASE(DT_UINT16, uint16_t)
    SET_BCAST_ADD_CASE(DT_UINT32, uint32_t)
    SET_BCAST_ADD_CASE(DT_UINT64, uint64_t)
    SET_BCAST_ADD_CASE(DT_FLOAT16, fp16_t)
    SET_BCAST_ADD_CASE(DT_FLOAT, float)
    SET_BCAST_ADD_CASE(DT_DOUBLE, double)
    default:
      GELOGI("Add kernel data type %s not support.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return NOT_CHANGED;
  }
  if (ret != SUCCESS) {
    // Fixed log text: previously said "Greater", a copy/paste from the greater kernel.
    GELOGW("Add broadcasting failed.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}
// Register this kernel with the kernel factory so it is constructed for ADD ops.
REGISTER_KERNEL(ADD, AddKernel);
} // namespace ge