/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_kernels/range_kernel.h"

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <new>
#include <set>
#include <type_traits>

#include "common/debug/log.h"
#include "common/fp16_t.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
constexpr size_t kRangeInputNum = 3;
constexpr uint32_t kRangeDimNum = 0;
constexpr size_t kStartIndex = 0;
constexpr size_t kLimitIndex = 1;
constexpr size_t kDeltaIndex = 2;
const std::set<DataType> kRangeSupportedType = {DT_INT32, DT_FLOAT};

// Number of elements in [start, limit) stepping by delta, for integral T.
// The arithmetic is widened to int64_t so that `limit - start` and
// `abs(delta)` cannot overflow the narrower element type.
// Precondition (enforced by the caller): delta != 0.
template <typename T>
int64_t RangeElementCount(const T start, const T limit, const T delta, std::true_type /* is_integral */) {
  const int64_t span = std::abs(static_cast<int64_t>(limit) - static_cast<int64_t>(start));
  const int64_t step = std::abs(static_cast<int64_t>(delta));
  return (span + step - 1) / step;
}

// Number of elements in [start, limit) stepping by delta, for non-integral T.
// Precondition (enforced by the caller): delta != 0.
template <typename T>
int64_t RangeElementCount(const T start, const T limit, const T delta, std::false_type /* is_integral */) {
  return static_cast<int64_t>(std::ceil(std::abs((limit - start) / delta)));
}
}  // namespace

// Computes the Range op on the host from three constant scalar inputs.
//
// @param op_desc_ptr  Descriptor of the op; its output desc 0 seeds the output tensor.
// @param input        Expected to hold {start, limit, delta}; validated by RangeCheck().
// @param v_output     Receives the generated tensor on success.
// @return SUCCESS on success; PARAM_INVALID / MEMALLOC_FAILED on error; whatever
//         RangeCheck() returns (NOT_CHANGED for unsupported inputs) if validation fails.
Status RangeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                            std::vector<GeTensorPtr> &v_output) {
  GELOGD("RangeKernel in");
  if (op_desc_ptr == nullptr) {
    GELOGE(PARAM_INVALID, "Parameter's invalid, input opDescPtr is nullptr.");
    return PARAM_INVALID;
  }

  Status ret = RangeCheck(input);
  if (ret != SUCCESS) {
    return ret;
  }

  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  if (output_ptr == nullptr) {
    GELOGE(MEMALLOC_FAILED, "Make shared failed");
    return MEMALLOC_FAILED;
  }

  ConstGeTensorPtr start = input.at(kStartIndex);
  ConstGeTensorPtr limit = input.at(kLimitIndex);
  ConstGeTensorPtr delta = input.at(kDeltaIndex);

  // RangeCheck() has verified that all three inputs share one data type, so
  // reading it from any of them is equivalent.
  DataType data_type = delta->GetTensorDesc().GetDataType();
  if (data_type == DT_FLOAT) {
    ret = GetRange(*reinterpret_cast<const float *>(start->GetData().data()),
                   *reinterpret_cast<const float *>(limit->GetData().data()),
                   *reinterpret_cast<const float *>(delta->GetData().data()), output_ptr);
  } else if (data_type == DT_INT32) {
    ret = GetRange(*reinterpret_cast<const int32_t *>(start->GetData().data()),
                   *reinterpret_cast<const int32_t *>(limit->GetData().data()),
                   *reinterpret_cast<const int32_t *>(delta->GetData().data()), output_ptr);
  }
  // Propagate the real status so that e.g. MEMALLOC_FAILED is not reported as PARAM_INVALID.
  if (ret != SUCCESS) {
    return ret;
  }

  output_ptr->MutableTensorDesc().SetDataType(data_type);
  v_output.push_back(output_ptr);
  return SUCCESS;
}

// Validates the inputs of the Range op.
//
// @param input  Candidate {start, limit, delta} tensors.
// @return SUCCESS when there are exactly three non-null, non-empty scalar tensors
//         of one supported data type, each holding at least one full element;
//         NOT_CHANGED when any of those conditions fails (other than a null
//         tensor, which is handled by GE_CHECK_NOTNULL).
Status RangeKernel::RangeCheck(const std::vector<ConstGeTensorPtr> &input) {
  // check input number
  if (input.size() != kRangeInputNum) {
    GELOGI("The number of input for Range must be %zu.", kRangeInputNum);
    return NOT_CHANGED;
  }

  ConstGeTensorPtr start = input.at(kStartIndex);
  ConstGeTensorPtr limit = input.at(kLimitIndex);
  ConstGeTensorPtr delta = input.at(kDeltaIndex);

  GE_CHECK_NOTNULL(start);
  GE_CHECK_NOTNULL(limit);
  GE_CHECK_NOTNULL(delta);

  // check whether there is data in Tensor
  if (start->GetData().size() == 0 || limit->GetData().size() == 0 || delta->GetData().size() == 0) {
    GELOGI("Check data size fail. start: %zu, limit: %zu, delta: %zu", start->GetData().size(),
           limit->GetData().size(), delta->GetData().size());
    return NOT_CHANGED;
  }

  // check whether the data types are the same
  DataType type = start->GetTensorDesc().GetDataType();
  if ((type != limit->GetTensorDesc().GetDataType()) || (type != delta->GetTensorDesc().GetDataType())) {
    GELOGI("Data type of inputs for Range not matched.");
    return NOT_CHANGED;
  }

  // check whether are all scalars
  size_t range_dim = static_cast<size_t>(kRangeDimNum);
  bool all_scalar = (start->GetTensorDesc().MutableShape().GetDimNum() == range_dim) &&
                    (limit->GetTensorDesc().MutableShape().GetDimNum() == range_dim) &&
                    (delta->GetTensorDesc().MutableShape().GetDimNum() == range_dim);
  if (!all_scalar) {
    GELOGI("Inputs for Range are not all scalars.");
    return NOT_CHANGED;
  }

  // check if input data type is supported
  if (kRangeSupportedType.find(type) == kRangeSupportedType.end()) {
    GELOGI("Range does not support this Data type: %s", TypeUtils::DataTypeToSerialString(type).c_str());
    return NOT_CHANGED;
  }

  // Compute() reads one whole element from each buffer, so a non-empty but
  // truncated buffer must be rejected here to avoid reading past its end.
  const size_t element_size = (type == DT_FLOAT) ? sizeof(float) : sizeof(int32_t);
  if (start->GetData().size() < element_size || limit->GetData().size() < element_size ||
      delta->GetData().size() < element_size) {
    GELOGI("Check data size fail. start: %zu, limit: %zu, delta: %zu", start->GetData().size(),
           limit->GetData().size(), delta->GetData().size());
    return NOT_CHANGED;
  }

  return SUCCESS;
}

// Fills `output` with the sequence start, start + delta, ... up to but not
// including limit, and sets its shape to {element count}.
//
// @param start   First value of the sequence.
// @param limit   Exclusive end of the sequence.
// @param delta   Step; must be non-zero and point from start towards limit.
// @param output  Tensor to fill; must be non-null (Compute() guarantees this).
// @return SUCCESS, PARAM_INVALID (bad arguments or SetData failure), or
//         MEMALLOC_FAILED (scratch buffer allocation failed).
template <typename T>
Status RangeKernel::GetRange(const T start, const T limit, const T delta, GeTensorPtr &output) {
  // check whether start, limit, delta is valid
  if (delta == 0) {
    GELOGE(PARAM_INVALID, "Requires delta != 0");
    return PARAM_INVALID;
  }
  if (start > limit && delta > 0) {
    GELOGE(PARAM_INVALID, "Requires start <= limit when delta > 0");
    return PARAM_INVALID;
  }
  if (start < limit && delta < 0) {
    GELOGE(PARAM_INVALID, "Requires start >= limit when delta < 0");
    return PARAM_INVALID;
  }

  // Tag dispatch selects the integer or floating-point formula at compile
  // time; the integer one works in int64_t to stay clear of signed overflow.
  const int64_t size = RangeElementCount(start, limit, delta, std::is_integral<T>());

  // Shape used when the range is empty (size is 0); overwritten below otherwise.
  // NOTE(review): an empty range therefore keeps a default GeShape() rather than {0} - confirm intended.
  output->MutableTensorDesc().SetShape(GeShape());
  if (size > 0) {
    std::unique_ptr<T[]> buf(new (std::nothrow) T[size]);
    if (buf == nullptr) {
      GELOGE(MEMALLOC_FAILED, "New buf failed.");
      return MEMALLOC_FAILED;
    }

    T val = start;
    for (int64_t i = 0; i < size; ++i) {
      buf[i] = val;
      // Skip the step after the last element: it is never stored and could
      // overflow T when the range ends close to the type's limits.
      if (i + 1 < size) {
        val += delta;
      }
    }

    // A tensor whose shape advertises `size` elements but carries no data must
    // not be published, so a SetData failure is reported instead of ignored.
    if (output->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(T)) != GRAPH_SUCCESS) {
      GELOGE(PARAM_INVALID, "GetRange: SetData failed");
      return PARAM_INVALID;
    }
    output->MutableTensorDesc().SetShape(GeShape({size}));
  }

  return SUCCESS;
}

REGISTER_KERNEL(RANGE, RangeKernel);
}  // namespace ge