|
|
|
/**
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include "host_kernels/range_kernel.h"

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <set>
#include <type_traits>

#include "common/debug/log.h"
#include "common/fp16_t.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
namespace {
|
|
|
|
constexpr size_t kRangeInputNum = 3;
|
|
|
|
constexpr uint32_t kRangeDimNum = 0;
|
|
|
|
constexpr size_t kStartIndex = 0;
|
|
|
|
constexpr size_t kLimitIndex = 1;
|
|
|
|
constexpr size_t kDeltaIndex = 2;
|
|
|
|
const std::set<DataType> kRangeSupportedType = {DT_INT32, DT_FLOAT};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
Status RangeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
|
|
|
|
std::vector<GeTensorPtr> &v_output) {
|
|
|
|
GELOGD("RangeKernel in");
|
|
|
|
if (op_desc_ptr == nullptr) {
|
|
|
|
GELOGE(PARAM_INVALID, "Parameter's invalid, input opDescPtr is nullptr.");
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
Status ret = RangeCheck(input);
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
|
|
|
|
if (output_ptr == nullptr) {
|
|
|
|
GELOGE(MEMALLOC_FAILED, "Make shared failed");
|
|
|
|
return MEMALLOC_FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
ConstGeTensorPtr start = input.at(kStartIndex);
|
|
|
|
ConstGeTensorPtr limit = input.at(kLimitIndex);
|
|
|
|
ConstGeTensorPtr delta = input.at(kDeltaIndex);
|
|
|
|
DataType data_type = delta->GetTensorDesc().GetDataType();
|
|
|
|
if (data_type == DT_FLOAT) {
|
|
|
|
if (GetRange(*reinterpret_cast<const float *>(start->GetData().data()),
|
|
|
|
*reinterpret_cast<const float *>(limit->GetData().data()),
|
|
|
|
*reinterpret_cast<const float *>(delta->GetData().data()), output_ptr) != SUCCESS) {
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
} else if (data_type == DT_INT32) {
|
|
|
|
if (GetRange(*reinterpret_cast<const int32_t *>(start->GetData().data()),
|
|
|
|
*reinterpret_cast<const int32_t *>(limit->GetData().data()),
|
|
|
|
*reinterpret_cast<const int32_t *>(delta->GetData().data()), output_ptr) != SUCCESS) {
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
output_ptr->MutableTensorDesc().SetDataType(data_type);
|
|
|
|
v_output.push_back(output_ptr);
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Decides whether the Range inputs can be folded on the host.
// Returns SUCCESS only for exactly three non-empty scalar tensors that share one supported
// data type. Every other case returns NOT_CHANGED (folding is skipped), except a null input
// tensor, which is handled by GE_CHECK_NOTNULL (presumably logs and returns an error status).
Status RangeKernel::RangeCheck(const std::vector<ConstGeTensorPtr> &input) {
  if (input.size() != kRangeInputNum) {
    GELOGI("The number of input for Range must be %zu.", kRangeInputNum);
    return NOT_CHANGED;
  }

  const ConstGeTensorPtr &start = input.at(0);
  const ConstGeTensorPtr &limit = input.at(1);
  const ConstGeTensorPtr &delta = input.at(2);
  GE_CHECK_NOTNULL(start);
  GE_CHECK_NOTNULL(limit);
  GE_CHECK_NOTNULL(delta);

  // Every input must actually carry data.
  const size_t start_bytes = start->GetData().size();
  const size_t limit_bytes = limit->GetData().size();
  const size_t delta_bytes = delta->GetData().size();
  if (start_bytes == 0 || limit_bytes == 0 || delta_bytes == 0) {
    GELOGI("Check data size fail. start: %zu, limit: %zu, delta: %zu", start_bytes, limit_bytes, delta_bytes);
    return NOT_CHANGED;
  }

  // All inputs must share start's data type.
  const DataType type = start->GetTensorDesc().GetDataType();
  const bool same_type =
      (limit->GetTensorDesc().GetDataType() == type) && (delta->GetTensorDesc().GetDataType() == type);
  if (!same_type) {
    GELOGI("Data type of inputs for Range not matched.");
    return NOT_CHANGED;
  }

  // All inputs must be rank-kRangeDimNum (scalar) tensors.
  const size_t scalar_rank = static_cast<size_t>(kRangeDimNum);
  auto is_scalar = [scalar_rank](const ConstGeTensorPtr &tensor) {
    return tensor->GetTensorDesc().MutableShape().GetDimNum() == scalar_rank;
  };
  if (!is_scalar(start) || !is_scalar(limit) || !is_scalar(delta)) {
    GELOGI("Inputs for Range are not all scalars.");
    return NOT_CHANGED;
  }

  // The shared data type must be one this kernel can fold.
  if (kRangeSupportedType.count(type) == 0) {
    GELOGI("Range does not support this Data type: %s", TypeUtils::DataTypeToSerialString(type).c_str());
    return NOT_CHANGED;
  }

  return SUCCESS;
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
Status RangeKernel::GetRange(const T start, const T limit, const T delta, GeTensorPtr &output) {
|
|
|
|
// check whether start, limit, delta is valid
|
|
|
|
if (delta == 0) {
|
|
|
|
GELOGE(PARAM_INVALID, "Requires delta != 0");
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
if (start > limit && delta > 0) {
|
|
|
|
GELOGE(PARAM_INVALID, "Requires start <= limit when delta > 0");
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
if (start < limit && delta < 0) {
|
|
|
|
GELOGE(PARAM_INVALID, "Requires start >= limit when delta < 0");
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
|
|
|
|
int64_t size = (std::is_integral<T>::value ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
|
|
|
|
: std::ceil(std::abs((limit - start) / delta)));
|
|
|
|
output->MutableTensorDesc().SetShape(GeShape()); // when size is 0
|
|
|
|
|
|
|
|
if (size > 0) {
|
|
|
|
unique_ptr<T[]> buf(new (std::nothrow) T[size]);
|
|
|
|
if (buf == nullptr) {
|
|
|
|
GELOGE(MEMALLOC_FAILED, "New buf failed.");
|
|
|
|
return MEMALLOC_FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
T val = start;
|
|
|
|
for (int64_t i = 0; i < size; ++i) {
|
|
|
|
buf[i] = val;
|
|
|
|
val += delta;
|
|
|
|
}
|
|
|
|
if (output->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(T)) != GRAPH_SUCCESS) {
|
|
|
|
GELOGW("GetRange: SetData failed");
|
|
|
|
}
|
|
|
|
output->MutableTensorDesc().SetShape(GeShape({size}));
|
|
|
|
}
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Registers RangeKernel as the host kernel for the RANGE op type. REGISTER_KERNEL is presumably
// provided by inc/kernel_factory.h and makes the kernel discoverable by op type — confirm there.
REGISTER_KERNEL(RANGE, RangeKernel);
|
|
|
|
} // namespace ge
|