/**
 * Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "host_kernels/range_kernel.h"

#include <cmath>
#include <cstdint>
#include <memory>
#include <set>
#include <type_traits>

#include "common/debug/log.h"
#include "common/fp16_t.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"
namespace ge {
namespace {
constexpr size_t kRangeInputNum = 3;
constexpr uint32_t kRangeDimNum = 0;
constexpr size_t kStartIndex = 0;
constexpr size_t kLimitIndex = 1;
constexpr size_t kDeltaIndex = 2;
const std::set<DataType> kRangeSupportedType = {DT_INT32, DT_FLOAT};
} // namespace
/// @brief Folds a constant Range(start, limit, delta) node at graph-compile time.
/// @param op_desc_ptr op descriptor of the Range node; must not be null
/// @param input constant input tensors in the order [start, limit, delta]
/// @param v_output receives the single generated output tensor on success
/// @return SUCCESS when folded; NOT_CHANGED when the node is not foldable;
///         PARAM_INVALID / MEMALLOC_FAILED on error
Status RangeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                            std::vector<GeTensorPtr> &v_output) {
  GELOGD("RangeKernel in");
  if (op_desc_ptr == nullptr) {
    GELOGE(PARAM_INVALID, "Parameter's invalid, input opDescPtr is nullptr.");
    return PARAM_INVALID;
  }
  Status ret = RangeCheck(input);
  if (ret != SUCCESS) {
    return ret;
  }

  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  if (output_ptr == nullptr) {
    GELOGE(MEMALLOC_FAILED, "Make shared failed");
    return MEMALLOC_FAILED;
  }

  ConstGeTensorPtr start = input.at(kStartIndex);
  ConstGeTensorPtr limit = input.at(kLimitIndex);
  ConstGeTensorPtr delta = input.at(kDeltaIndex);
  // RangeCheck guarantees all three inputs share one data type, so delta's type
  // stands for all of them.
  DataType data_type = delta->GetTensorDesc().GetDataType();
  if (data_type == DT_FLOAT) {
    if (GetRange(*reinterpret_cast<const float *>(start->GetData().data()),
                 *reinterpret_cast<const float *>(limit->GetData().data()),
                 *reinterpret_cast<const float *>(delta->GetData().data()), output_ptr) != SUCCESS) {
      return PARAM_INVALID;
    }
  } else if (data_type == DT_INT32) {
    if (GetRange(*reinterpret_cast<const int32_t *>(start->GetData().data()),
                 *reinterpret_cast<const int32_t *>(limit->GetData().data()),
                 *reinterpret_cast<const int32_t *>(delta->GetData().data()), output_ptr) != SUCCESS) {
      return PARAM_INVALID;
    }
  } else {
    // Defensive guard: RangeCheck already rejects unsupported types, but never
    // emit an output tensor whose data was not generated.
    GELOGW("Range does not support this Data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str());
    return NOT_CHANGED;
  }

  output_ptr->MutableTensorDesc().SetDataType(data_type);
  v_output.push_back(output_ptr);
  return SUCCESS;
}
/// @brief Validates the inputs of a Range node before constant folding.
/// @param input candidate constant inputs, expected as [start, limit, delta]
/// @return SUCCESS when foldable; NOT_CHANGED when the node should be left as-is;
///         PARAM_INVALID when an input pointer is null (via GE_CHECK_NOTNULL)
Status RangeKernel::RangeCheck(const std::vector<ConstGeTensorPtr> &input) {
  // check input number
  if (input.size() != kRangeInputNum) {
    GELOGI("The number of input for Range must be %zu.", kRangeInputNum);
    return NOT_CHANGED;
  }

  ConstGeTensorPtr start = input.at(kStartIndex);
  ConstGeTensorPtr limit = input.at(kLimitIndex);
  ConstGeTensorPtr delta = input.at(kDeltaIndex);
  GE_CHECK_NOTNULL(start);
  GE_CHECK_NOTNULL(limit);
  GE_CHECK_NOTNULL(delta);

  // check whether there is data in Tensor
  if (start->GetData().size() == 0 || limit->GetData().size() == 0 || delta->GetData().size() == 0) {
    GELOGI("Check data size fail. start: %zu, limit: %zu, delta: %zu", start->GetData().size(), limit->GetData().size(),
           delta->GetData().size());
    return NOT_CHANGED;
  }

  // check whether the data types are the same
  DataType type = start->GetTensorDesc().GetDataType();
  if ((type != limit->GetTensorDesc().GetDataType()) || (type != delta->GetTensorDesc().GetDataType())) {
    GELOGI("Data type of inputs for Range not matched.");
    return NOT_CHANGED;
  }

  // check whether are all scalars
  size_t range_dim = static_cast<size_t>(kRangeDimNum);
  bool all_scalar = (start->GetTensorDesc().MutableShape().GetDimNum() == range_dim) &&
                    (limit->GetTensorDesc().MutableShape().GetDimNum() == range_dim) &&
                    (delta->GetTensorDesc().MutableShape().GetDimNum() == range_dim);
  if (!all_scalar) {
    GELOGI("Inputs for Range are not all scalars.");
    return NOT_CHANGED;
  }

  // check if input data type is supported
  if (kRangeSupportedType.find(type) == kRangeSupportedType.end()) {
    GELOGI("Range does not support this Data type: %s", TypeUtils::DataTypeToSerialString(type).c_str());
    return NOT_CHANGED;
  }

  // Both supported types (DT_INT32 and DT_FLOAT) are 4 bytes wide. Require a full
  // element in every input so the reinterpret_cast reads in Compute stay in bounds.
  const size_t element_size = sizeof(int32_t);
  if (start->GetData().size() < element_size || limit->GetData().size() < element_size ||
      delta->GetData().size() < element_size) {
    GELOGI("Data size of inputs for Range is smaller than one element. start: %zu, limit: %zu, delta: %zu",
           start->GetData().size(), limit->GetData().size(), delta->GetData().size());
    return NOT_CHANGED;
  }
  return SUCCESS;
}
/// @brief Generates the sequence [start, start+delta, ...) stopping before limit,
///        and writes the data plus a 1-D shape into `output`.
/// @tparam T element type; instantiated with int32_t and float by Compute
/// @param start first value of the sequence
/// @param limit exclusive upper (or lower, for negative delta) bound
/// @param delta step between consecutive values; must be non-zero
/// @param output tensor that receives the generated data and shape
/// @return SUCCESS on success; PARAM_INVALID / MEMALLOC_FAILED on error
template <typename T>
Status RangeKernel::GetRange(const T start, const T limit, const T delta, GeTensorPtr &output) {
  // check whether start, limit, delta is valid
  if (delta == 0) {
    GELOGE(PARAM_INVALID, "Requires delta != 0");
    return PARAM_INVALID;
  }
  if (start > limit && delta > 0) {
    GELOGE(PARAM_INVALID, "Requires start <= limit when delta > 0");
    return PARAM_INVALID;
  }
  if (start < limit && delta < 0) {
    GELOGE(PARAM_INVALID, "Requires start >= limit when delta < 0");
    return PARAM_INVALID;
  }

  int64_t size = 0;
  if (std::is_integral<T>::value) {
    // Widen to int64_t before subtracting: (limit - start) in 32-bit signed
    // arithmetic can overflow (UB), e.g. start = INT32_MIN, limit = INT32_MAX.
    const int64_t start64 = static_cast<int64_t>(start);
    const int64_t limit64 = static_cast<int64_t>(limit);
    const int64_t delta64 = static_cast<int64_t>(delta);
    const int64_t diff = (limit64 >= start64) ? (limit64 - start64) : (start64 - limit64);
    const int64_t step = (delta64 >= 0) ? delta64 : -delta64;
    size = (diff + step - 1) / step;  // ceil(diff / step) in integer arithmetic
  } else {
    size = static_cast<int64_t>(std::ceil(std::abs((limit - start) / delta)));
  }

  output->MutableTensorDesc().SetShape(GeShape());  // scalar shape covers the size == 0 case
  if (size > 0) {
    std::unique_ptr<T[]> buf(new (std::nothrow) T[size]);
    if (buf == nullptr) {
      GELOGE(MEMALLOC_FAILED, "New buf failed.");
      return MEMALLOC_FAILED;
    }
    T val = start;
    for (int64_t i = 0; i < size; ++i) {
      buf[i] = val;
      val += delta;
    }
    if (output->SetData(reinterpret_cast<uint8_t *>(buf.get()), static_cast<size_t>(size) * sizeof(T)) !=
        GRAPH_SUCCESS) {
      // Best-effort: keep the original warn-and-continue behavior on SetData failure.
      GELOGW("GetRange: SetData failed");
    }
    output->MutableTensorDesc().SetShape(GeShape({size}));
  }
  return SUCCESS;
}
// Register this kernel with the host-kernel factory under the RANGE op type.
REGISTER_KERNEL(RANGE, RangeKernel);
}  // namespace ge