/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_kernels/add_kernel.h"

#include <memory>

#include "common/math/math_util.h"
#include "graph/common/bcast.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
// Input/output slot indexes and the expected input/output counts for Add.
const size_t kAddFirstInput = 0;
const size_t kAddSecondInput = 1;
const size_t kAddFirstOutput = 0;
const size_t kAddInputSize = 2;
const size_t kAddOutputSize = 1;

// Expands to one switch case that dispatches the typed broadcast-add.
// Assumes `ret`, `op_desc_ptr`, `input` and `v_output` are in scope at the
// expansion site (see AddKernel::Compute).
#define SET_BCAST_ADD_CASE(DTYPE, TYPE)               \
  case (DTYPE):                                       \
    ret = BCastAdd<TYPE>(op_desc_ptr, input, v_output); \
    break;
}  // namespace

// Checks whether x + y would overflow for the given element data type.
// The FMK_*_ADDCHECK macros emit an early `return` with an error status on
// overflow; types without a matching case fall through and are treated as safe.
// @return SUCCESS when the addition is safe (or unchecked for this type).
template <typename T>
Status AddKernel::OverflowCheck(const T &x, const T &y, DataType data_type) {
  switch (data_type) {
    case DT_INT8:
      FMK_INT8_ADDCHECK(x, y)
      break;
    case DT_INT16:
      FMK_INT16_ADDCHECK(x, y)
      break;
    case DT_INT32:
      FMK_INT32_ADDCHECK(x, y)
      break;
    case DT_INT64:
      FMK_INT64_ADDCHECK(x, y)
      break;
    case DT_UINT8:
      FMK_UINT8_ADDCHECK(x, y)
      break;
    case DT_UINT16:
      FMK_UINT16_ADDCHECK(x, y)
      break;
    case DT_UINT32:
      FMK_UINT32_ADDCHECK(x, y)
      break;
    case DT_UINT64:
      FMK_UINT64_ADDCHECK(x, y)
      break;
    case DT_FLOAT16:
      FMK_FP16_ADDCHECK(x, y)
      break;
    case DT_FLOAT:
      FMK_FLOAT_ADDCHECK(x, y)
      break;
    case DT_DOUBLE:
      FMK_DOUBLE_ADDCHECK(x, y)
      break;
    default:
      break;
  }
  return SUCCESS;
}

// Element-wise broadcast addition of the two constant input tensors.
// @tparam InT element type of both inputs (inputs are already validated to
//         share one data type by AddCheck).
// @param op_desc_ptr op description used to build the output tensor desc.
// @param input       two non-null constant tensors (checked by AddCheck).
// @param v_output    receives the single result tensor on success.
// @return SUCCESS, or an error status on broadcast/allocation/overflow failure.
template <typename InT>
Status AddKernel::BCastAdd(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                           std::vector<GeTensorPtr> &v_output) {
  // only broadcast shape
  BCast bcast;
  Status ret = bcast.GenerateBcastInfo(BCast::TransShapeToDimVec(input[kAddFirstInput]->GetTensorDesc()),
                                       BCast::TransShapeToDimVec(input[kAddSecondInput]->GetTensorDesc()));
  if (ret != SUCCESS) {
    // NOTE(review): message says "Greater" but this is the Add kernel —
    // looks like a copy-paste artifact; kept verbatim to preserve log output.
    GELOGE(ret, "Greater broadcasting failed.");
    return ret;
  }

  // Per-output-element source indexes into each (flattened) input.
  std::vector<int64_t> x_indexes;
  std::vector<int64_t> y_indexes;
  bcast.BCastIndexes(x_indexes, y_indexes);

  auto x1_data = reinterpret_cast<const InT *>(input[kAddFirstInput]->GetData().data());
  auto x2_data = reinterpret_cast<const InT *>(input[kAddSecondInput]->GetData().data());

  size_t data_num = x_indexes.size();
  std::unique_ptr<InT[]> buf(new (std::nothrow) InT[data_num]());
  if (buf == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed",
           static_cast<size_t>(sizeof(InT) * data_num));
    return MEMALLOC_FAILED;
  }

  DataType data_type = input[kAddFirstInput]->GetTensorDesc().GetDataType();
  for (size_t i = 0; i < data_num; i++) {
    auto x_index = *(x1_data + x_indexes[i]);
    auto y_index = *(x2_data + y_indexes[i]);
    // Reject the whole computation as soon as any element pair would overflow.
    if (OverflowCheck(x_index, y_index, data_type) != SUCCESS) {
      GELOGE(PARAM_INVALID, "Result of add is overflow.");
      return PARAM_INVALID;
    }
    *(buf.get() + i) = x_index + y_index;
  }

  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(kAddFirstOutput));
  if (output_ptr == nullptr) {
    GELOGE(MEMALLOC_FAILED, "Make shared failed");
    return MEMALLOC_FAILED;
  }
  // SetData copies raw bytes; data type and broadcast shape are set explicitly.
  output_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(InT));
  output_ptr->MutableTensorDesc().SetDataType(data_type);
  vector<int64_t> bcast_dims = bcast.GetOutputShape();
  output_ptr->MutableTensorDesc().SetShape(GeShape(bcast_dims));
  v_output.push_back(output_ptr);
  return SUCCESS;
}

// Validates the op description and inputs before folding:
// non-null op desc, exactly 2 inputs / 1 output, non-null input tensors,
// matching input data types, and non-empty input data.
// @return SUCCESS when the Add can be constant-folded, PARAM_INVALID otherwise.
Status AddKernel::AddCheck(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input) {
  if (op_desc_ptr == nullptr) {
    GELOGW("Op_desc_ptr must not be null.");
    return PARAM_INVALID;
  }
  // check how many inputs
  if ((input.size() != kAddInputSize) || (op_desc_ptr->GetOutputsSize() != kAddOutputSize)) {
    GELOGW("The number of input for add must be %zu, output number must be %zu.", kAddInputSize, kAddOutputSize);
    return PARAM_INVALID;
  }
  // input vector elements must not be null
  if ((input[kAddFirstInput] == nullptr) || (input[kAddSecondInput] == nullptr)) {
    GELOGW("Input vector elements must not be null.");
    return PARAM_INVALID;
  }
  // Inputs must have the same datatype.
  DataType data_type_0 = input[kAddFirstInput]->GetTensorDesc().GetDataType();
  DataType data_type_1 = input[kAddSecondInput]->GetTensorDesc().GetDataType();
  if (data_type_0 != data_type_1) {
    GELOGW("Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s",
           TypeUtils::DataTypeToSerialString(data_type_0).c_str(),
           TypeUtils::DataTypeToSerialString(data_type_1).c_str());
    return PARAM_INVALID;
  }
  // Checking whether the weightdef contains data
  if ((input[kAddFirstInput]->GetData().size() == 0) || (input[kAddSecondInput]->GetData().size() == 0)) {
    GELOGW("Data size of input0 is %zu, input1 is %zu.", input[kAddFirstInput]->GetData().size(),
           input[kAddSecondInput]->GetData().size());
    return PARAM_INVALID;
  }
  return SUCCESS;
}

// Host (compile-time) constant folding entry point for the Add op.
// Validates inputs, then dispatches BCastAdd on the element data type.
// @return SUCCESS with one tensor appended to v_output, or NOT_CHANGED when
//         the op cannot / should not be folded (graph is left untouched).
Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                          std::vector<GeTensorPtr> &v_output) {
  if (AddCheck(op_desc_ptr, input) != SUCCESS) {
    return NOT_CHANGED;
  }

  Status ret = NOT_CHANGED;
  DataType data_type = input[kAddFirstInput]->GetTensorDesc().GetDataType();
  switch (data_type) {
    SET_BCAST_ADD_CASE(DT_INT8, int8_t)
    SET_BCAST_ADD_CASE(DT_INT16, int16_t)
    SET_BCAST_ADD_CASE(DT_INT32, int32_t)
    SET_BCAST_ADD_CASE(DT_INT64, int64_t)
    SET_BCAST_ADD_CASE(DT_UINT8, uint8_t)
    SET_BCAST_ADD_CASE(DT_UINT16, uint16_t)
    SET_BCAST_ADD_CASE(DT_UINT32, uint32_t)
    SET_BCAST_ADD_CASE(DT_UINT64, uint64_t)
    SET_BCAST_ADD_CASE(DT_FLOAT16, fp16_t)
    SET_BCAST_ADD_CASE(DT_FLOAT, float)
    SET_BCAST_ADD_CASE(DT_DOUBLE, double)
    default:
      GELOGI("Add kernel data type %s not support.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return NOT_CHANGED;
  }

  if (ret != SUCCESS) {
    // NOTE(review): "Greater" here also looks copy-pasted from another kernel;
    // kept verbatim to preserve log output.
    GELOGW("Greater broadcasting failed.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}

REGISTER_KERNEL(ADD, AddKernel);
}  // namespace ge