You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
203 lines
6.7 KiB
203 lines
6.7 KiB
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "host_kernels/add_kernel.h"
|
|
|
|
#include <cfloat>
|
|
|
|
#include "common/math/math_util.h"
|
|
#include "graph/common/bcast.h"
|
|
#include "graph/utils/type_utils.h"
|
|
#include "inc/kernel_factory.h"
|
|
|
|
namespace ge {
|
|
namespace {
// Positions of the two operands inside the kernel's input vector.
constexpr size_t kAddFirstInput = 0;
constexpr size_t kAddSecondInput = 1;
// Position of the single result inside the op's output descriptors.
constexpr size_t kAddFirstOutput = 0;
// Arity that AddCheck requires before the kernel runs.
constexpr size_t kAddInputSize = 2;
constexpr size_t kAddOutputSize = 1;

// Expands to one `case` of the data-type switch in AddKernel::Compute: it
// dispatches to BCastAdd instantiated for TYPE and stores the status in the
// caller's local `ret`. It relies on `op_desc_ptr`, `input`, `v_output` and
// `ret` being in scope at the expansion site.
#define SET_BCAST_ADD_CASE(DTYPE, TYPE)                 \
  case (DTYPE):                                         \
    ret = BCastAdd<TYPE>(op_desc_ptr, input, v_output); \
    break;
}  // namespace
|
|
|
|
template <typename T>
|
|
Status AddKernel::OverflowCheck(const T &x, const T &y, DataType data_type) {
|
|
switch (data_type) {
|
|
case DT_INT8:
|
|
FMK_INT8_ADDCHECK(x, y)
|
|
break;
|
|
case DT_INT16:
|
|
FMK_INT16_ADDCHECK(x, y)
|
|
break;
|
|
case DT_INT32:
|
|
FMK_INT32_ADDCHECK(x, y)
|
|
break;
|
|
case DT_INT64:
|
|
FMK_INT64_ADDCHECK(x, y)
|
|
break;
|
|
case DT_UINT8:
|
|
FMK_UINT8_ADDCHECK(x, y)
|
|
break;
|
|
case DT_UINT16:
|
|
FMK_UINT16_ADDCHECK(x, y)
|
|
break;
|
|
case DT_UINT32:
|
|
FMK_UINT32_ADDCHECK(x, y)
|
|
break;
|
|
case DT_UINT64:
|
|
FMK_UINT64_ADDCHECK(x, y)
|
|
break;
|
|
case DT_FLOAT16:
|
|
FMK_FP16_ADDCHECK(x, y)
|
|
break;
|
|
case DT_FLOAT:
|
|
FMK_FLOAT_ADDCHECK(x, y)
|
|
break;
|
|
case DT_DOUBLE:
|
|
FMK_DOUBLE_ADDCHECK(x, y)
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
return SUCCESS;
|
|
}
|
|
|
|
template <typename InT>
|
|
Status AddKernel::BCastAdd(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
|
|
std::vector<GeTensorPtr> &v_output) {
|
|
// only broadcast shape
|
|
BCast bcast;
|
|
Status ret = bcast.GenerateBcastInfo(BCast::TransShapeToDimVec(input[kAddFirstInput]->GetTensorDesc()),
|
|
BCast::TransShapeToDimVec(input[kAddSecondInput]->GetTensorDesc()));
|
|
if (ret != SUCCESS) {
|
|
GELOGE(ret, "Greater broadcasting failed.");
|
|
return ret;
|
|
}
|
|
|
|
std::vector<int64_t> x_indexes;
|
|
std::vector<int64_t> y_indexes;
|
|
bcast.BCastIndexes(x_indexes, y_indexes);
|
|
|
|
auto x1_data = reinterpret_cast<const InT *>(input[kAddFirstInput]->GetData().data());
|
|
auto x2_data = reinterpret_cast<const InT *>(input[kAddSecondInput]->GetData().data());
|
|
|
|
size_t data_num = x_indexes.size();
|
|
std::unique_ptr<InT[]> buf(new (std::nothrow) InT[data_num]());
|
|
if (buf == nullptr) {
|
|
GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast<size_t>(sizeof(InT) * data_num));
|
|
return MEMALLOC_FAILED;
|
|
}
|
|
|
|
DataType data_type = input[kAddFirstInput]->GetTensorDesc().GetDataType();
|
|
for (size_t i = 0; i < data_num; i++) {
|
|
auto x_index = *(x1_data + x_indexes[i]);
|
|
auto y_index = *(x2_data + y_indexes[i]);
|
|
if (OverflowCheck<InT>(x_index, y_index, data_type) != SUCCESS) {
|
|
GELOGE(PARAM_INVALID, "Result of add is overflow.");
|
|
return PARAM_INVALID;
|
|
}
|
|
*(buf.get() + i) = x_index + y_index;
|
|
}
|
|
|
|
GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(kAddFirstOutput));
|
|
if (output_ptr == nullptr) {
|
|
GELOGE(MEMALLOC_FAILED, "Make shared failed");
|
|
return MEMALLOC_FAILED;
|
|
}
|
|
output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(InT));
|
|
output_ptr->MutableTensorDesc().SetDataType(data_type);
|
|
vector<int64_t> bcast_dims = bcast.GetOutputShape();
|
|
output_ptr->MutableTensorDesc().SetShape(GeShape(bcast_dims));
|
|
v_output.push_back(output_ptr);
|
|
|
|
return SUCCESS;
|
|
}
|
|
|
|
// Validates the arguments of an Add fold before any data is touched.
//
// Requires, in order: a non-null op descriptor; exactly kAddInputSize inputs
// and kAddOutputSize outputs; non-null input tensors; identical data types on
// both inputs; and non-empty data on both inputs.
//
// Returns SUCCESS when every requirement holds, PARAM_INVALID (after logging a
// warning) at the first one that does not.
Status AddKernel::AddCheck(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input) {
  if (op_desc_ptr == nullptr) {
    GELOGW("Op_desc_ptr must not be null.");
    return PARAM_INVALID;
  }

  // Arity: two operands in, one result out.
  const bool arity_ok = (input.size() == kAddInputSize) && (op_desc_ptr->GetOutputsSize() == kAddOutputSize);
  if (!arity_ok) {
    GELOGW("The number of input for add must be %zu, output number must be %zu.", kAddInputSize,
           kAddOutputSize);
    return PARAM_INVALID;
  }

  // Safe to index now that the size is known.
  const ConstGeTensorPtr &first = input[kAddFirstInput];
  const ConstGeTensorPtr &second = input[kAddSecondInput];
  if ((first == nullptr) || (second == nullptr)) {
    GELOGW("Input vector elements must not be null.");
    return PARAM_INVALID;
  }

  // Both operands have to carry the same element type.
  const DataType first_type = first->GetTensorDesc().GetDataType();
  const DataType second_type = second->GetTensorDesc().GetDataType();
  if (first_type != second_type) {
    GELOGW("Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s",
           TypeUtils::DataTypeToSerialString(first_type).c_str(),
           TypeUtils::DataTypeToSerialString(second_type).c_str());
    return PARAM_INVALID;
  }

  // Both operands must actually hold data.
  const auto first_size = first->GetData().size();
  const auto second_size = second->GetData().size();
  if ((first_size == 0) || (second_size == 0)) {
    GELOGW("Data size of input0 is %zu, input1 is %zu.", first_size, second_size);
    return PARAM_INVALID;
  }

  return SUCCESS;
}
|
|
|
|
// Entry point of the Add kernel: computes the sum of two input tensors and
// appends the result to v_output.
//
// op_desc_ptr  Descriptor of the Add op being computed.
// input        The op's input tensors; validated by AddCheck.
// v_output     Receives the result tensor on success.
//
// Returns SUCCESS when the result was produced. Every failure - invalid
// arguments, an unsupported data type, or an error inside BCastAdd - is
// reported as NOT_CHANGED rather than as an error status.
Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                          std::vector<GeTensorPtr> &v_output) {
  if (AddCheck(op_desc_ptr, input) != SUCCESS) {
    return NOT_CHANGED;
  }

  Status ret = NOT_CHANGED;
  DataType data_type = input[kAddFirstInput]->GetTensorDesc().GetDataType();
  // Each SET_BCAST_ADD_CASE expands to a `case` that assigns `ret` from the
  // BCastAdd instantiation for the matching element type.
  switch (data_type) {
    SET_BCAST_ADD_CASE(DT_INT8, int8_t)
    SET_BCAST_ADD_CASE(DT_INT16, int16_t)
    SET_BCAST_ADD_CASE(DT_INT32, int32_t)
    SET_BCAST_ADD_CASE(DT_INT64, int64_t)
    SET_BCAST_ADD_CASE(DT_UINT8, uint8_t)
    SET_BCAST_ADD_CASE(DT_UINT16, uint16_t)
    SET_BCAST_ADD_CASE(DT_UINT32, uint32_t)
    SET_BCAST_ADD_CASE(DT_UINT64, uint64_t)
    SET_BCAST_ADD_CASE(DT_FLOAT16, fp16_t)
    SET_BCAST_ADD_CASE(DT_FLOAT, float)
    SET_BCAST_ADD_CASE(DT_DOUBLE, double)
    default:
      GELOGI("Add kernel data type %s not support.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return NOT_CHANGED;
  }

  if (ret != SUCCESS) {
    // BCastAdd can fail for broadcasting, allocation or overflow reasons, so
    // the message stays generic instead of naming one cause.
    GELOGW("Add kernel compute failed.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}
|
|
|
|
// NOTE(review): presumably registers AddKernel with the kernel factory
// (inc/kernel_factory.h) under the ADD op type - confirm against the macro.
REGISTER_KERNEL(ADD, AddKernel);
|
|
} // namespace ge
|