/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_kernels/rsqrt_kernel.h"

#include <cfloat>
#include <cmath>
#include <memory>
#include <new>

#include "common/debug/ge_log.h"
#include "common/debug/log.h"
#include "common/ge_inner_error_codes.h"
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "host_kernels/kernel_utils.h"
#include "inc/kernel_factory.h"
#include "common/math/math_util.h"
#include "framework/common/types.h"

namespace ge {
namespace {
// Rsqrt is a unary op: exactly one input tensor, read from index 0.
const size_t kRsqrtInputSize = 1;
const size_t kRsqrtInputIndex0 = 0;

/// Validates a single input element before 1/sqrt(x) is evaluated.
///
/// Dispatches to the FMK_*_ZEROCHECK macro that matches `data_type`. The
/// macros are defined outside this file; they are expected to return a
/// non-SUCCESS status from this function when the value is rejected
/// (NOTE(review): presumably when x <= 0 -- confirm against the macro
/// definitions). Data types without a dedicated macro are not checked here.
///
/// @param x          element value to validate
/// @param data_type  runtime data type of the tensor the element came from
/// @return SUCCESS when no check macro bailed out.
template <typename T>
Status ZeroCheck(T x, const DataType &data_type) {
  switch (data_type) {
    case DT_FLOAT16:
      FMK_FP16_ZEROCHECK(static_cast<double>(x))
      break;
    case DT_FLOAT:
      FMK_FLOAT_ZEROCHECK(static_cast<float>(x))
      break;
    case DT_DOUBLE:
      FMK_DOUBLE_ZEROCHECK(static_cast<double>(x))
      break;
    default:
      break;
  }
  return SUCCESS;
}

// Expands to one `case` of the dtype switch in RsqrtKernel::Compute. It relies
// on the locals `ret`, `input_ptr` and `output_ptr` being in scope there.
#define SET_RSQRT_CASE(DTYPE, TYPE)                               \
  case (DTYPE):                                                   \
    ret = RsqrtKernel::RsqrtCompute<TYPE>(input_ptr, output_ptr); \
    break;
}  // namespace

/// Computes the element-wise reciprocal square root of `input_tensor_ptr`
/// into `output_tensor_ptr`.
///
/// The input bytes are viewed as an array of T. If it holds at least one whole
/// element, the results are stored in the output tensor and the output
/// descriptor takes the input's data type and shape. With zero whole elements
/// the output is left untouched and SUCCESS is returned.
///
/// @tparam T  element storage type (instantiated below for fp16_t, float, double)
/// @param input_tensor_ptr   tensor to read; must not be null
/// @param output_tensor_ptr  tensor to fill; must not be null
/// @return SUCCESS on success; NOT_CHANGED when allocation fails, an element
///         fails ZeroCheck, the data type is unsupported, or SetData fails.
template <typename T>
Status RsqrtKernel::RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr &output_tensor_ptr) {
  GE_CHECK_NOTNULL(input_tensor_ptr);
  GE_CHECK_NOTNULL(output_tensor_ptr);

  const size_t data_size = input_tensor_ptr->GetData().size();
  const size_t data_count = data_size / sizeof(T);
  const auto data_type = input_tensor_ptr->GetTensorDesc().GetDataType();
  if (data_count == 0) {
    return SUCCESS;
  }

  std::unique_ptr<T[]> buf(new (std::nothrow) T[data_count]());
  if (buf == nullptr) {
    GELOGW("New buf failed");
    return NOT_CHANGED;
  }

  // The input is only read, so a const view is sufficient; fetch it once
  // instead of on every iteration.
  const T *input_data = reinterpret_cast<const T *>(input_tensor_ptr->GetData().data());
  for (size_t i = 0; i < data_count; i++) {
    if (ZeroCheck(input_data[i], data_type) != SUCCESS) {
      GELOGW("Rsqrt: The input data can not less than or equal to zero, rsqrt folding failed.");
      return NOT_CHANGED;
    }
    switch (data_type) {
      case DT_FLOAT16: {
        // Half precision is widened to double for the computation.
        const double val = static_cast<double>(*(reinterpret_cast<const fp16_t *>(input_data) + i));
        const double rsqrt_val = 1.0 / std::sqrt(val);
        buf[i] = rsqrt_val;
        break;
      }
      case DT_FLOAT: {
        const float denominator = std::sqrt(*(reinterpret_cast<const float *>(input_data) + i));
        buf[i] = static_cast<float>(1 / denominator);
        break;
      }
      case DT_DOUBLE: {
        const double denominator = std::sqrt(*(reinterpret_cast<const double *>(input_data) + i));
        buf[i] = static_cast<double>(1 / denominator);
        break;
      }
      default:
        GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
        return NOT_CHANGED;
    }
  }

  // Hand over exactly the bytes held by `buf`. `data_size` can exceed that when
  // the input size is not a multiple of sizeof(T), and copying `data_size`
  // bytes would then read past the end of the allocation.
  const size_t output_size = data_count * sizeof(T);
  GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), output_size) != GRAPH_SUCCESS,
                  GELOGW("Set data failed");
                  return NOT_CHANGED);
  output_tensor_ptr->MutableTensorDesc().SetDataType(data_type);
  output_tensor_ptr->MutableTensorDesc().SetShape(input_tensor_ptr->GetTensorDesc().GetShape());
  return SUCCESS;
}

/// Constant-folds an Rsqrt node.
///
/// Expects exactly one input tensor of type DT_FLOAT16, DT_FLOAT or DT_DOUBLE.
/// On success the computed tensor is appended to `v_output`.
///
/// @param op_desc_ptr  descriptor of the node being folded; must not be null
/// @param input        input tensors of the node (exactly one expected)
/// @param v_output     receives the folded result tensor on success
/// @return SUCCESS when folding succeeded; NOT_CHANGED when the node cannot be
///         folded (wrong input count, unsupported data type, allocation or
///         computation failure). The GE_CHECK_NOTNULL macros may return their
///         own error status for null arguments.
Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
                            std::vector<GeTensorPtr> &v_output) {
  GELOGI("RsqrtKernel in.");
  GE_CHECK_NOTNULL(op_desc_ptr);

  // check input size
  if (input.size() != kRsqrtInputSize) {
    GELOGW("The number of input for rsqrt must be %zu.", kRsqrtInputSize);
    return NOT_CHANGED;
  }

  ConstGeTensorPtr input_ptr = input.at(kRsqrtInputIndex0);
  GE_CHECK_NOTNULL(input_ptr);

  // Index 0 always yields a GeTensorDesc object from any OpDescPtr.
  auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0);
  GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
  if (output_ptr == nullptr) {
    GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }

  // Select the RsqrtCompute instantiation matching the input's element type.
  Status ret = NOT_CHANGED;
  auto dtype = input_ptr->GetTensorDesc().GetDataType();
  switch (dtype) {
    SET_RSQRT_CASE(DT_FLOAT16, fp16_t)
    SET_RSQRT_CASE(DT_FLOAT, float)
    SET_RSQRT_CASE(DT_DOUBLE, double)
    default:
      GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
      return NOT_CHANGED;
  }
  if (ret != SUCCESS) {
    GELOGW("Rsqrt folding failed.");
    return NOT_CHANGED;
  }

  v_output.push_back(output_ptr);
  GELOGI("RsqrtKernel success.");
  return SUCCESS;
}

REGISTER_KERNEL(RSQRT, RsqrtKernel);
}  // namespace ge