|
|
|
/**
|
|
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#ifndef GE_GRAPH_COMMON_BCAST_H_
|
|
|
|
#define GE_GRAPH_COMMON_BCAST_H_
|
|
|
|
|
|
|
|
#include <stdint.h>
|
|
|
|
#include <functional>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
#include "common/debug/log.h"
|
|
|
|
#include "common/types.h"
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
|
|
#include "framework/common/ge_inner_error_codes.h"
|
|
|
|
#include "graph/attr_value.h"
|
|
|
|
#include "graph/ge_tensor.h"
|
|
|
|
#include "graph/utils/tensor_adapter.h"
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
static const size_t kMinDimNum = 2;
|
|
|
|
class BCast {
|
|
|
|
public:
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief define kVecInt
|
|
|
|
///
|
|
|
|
typedef std::vector<int64_t> kVecInt;
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief constructor
|
|
|
|
///
|
|
|
|
BCast() {}
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief destructor
|
|
|
|
///
|
|
|
|
~BCast() {}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief Not optimize intermediate shapes
|
|
|
|
/// @decrease dims, more efficient, set by user
|
|
|
|
/// @param [in] x first Tensor dim
|
|
|
|
/// @param [in] y second Tensor dim
|
|
|
|
/// @return SUCCESS broadcast message successfully generated
|
|
|
|
/// @return other broadcast message failed to generate
|
|
|
|
///
|
|
|
|
ge::Status GenerateBcastInfo(const kVecInt &x, const kVecInt &y);
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief get x_reshape
|
|
|
|
///
|
|
|
|
const kVecInt &GetXReshape() const { return x_reshape_; }
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief get x_bcast
|
|
|
|
///
|
|
|
|
const kVecInt &GetXBcast() const { return x_bcast_; }
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief get y_reshape
|
|
|
|
///
|
|
|
|
const kVecInt &GetYReshape() const { return y_reshape_; }
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief get y_bcast
|
|
|
|
///
|
|
|
|
const kVecInt &GetYBcast() const { return y_bcast_; }
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief get result_shape
|
|
|
|
///
|
|
|
|
const kVecInt &GetResultShape() const { return result_; }
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief get result_shape
|
|
|
|
///
|
|
|
|
const kVecInt &GetOutputShape() const { return output_; }
|
|
|
|
const kVecInt &GetGradXReduceIdx() const { return grad_x_reduce_idx_; }
|
|
|
|
const kVecInt &GetGradYReduceIdx() const { return grad_y_reduce_idx_; }
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief convert TensorDescriptor to kVecInt
|
|
|
|
/// @param [in] shape Tensor descriptor
|
|
|
|
/// @return kVecInt dim info
|
|
|
|
///
|
|
|
|
static kVecInt TransShapeToDimVec(const GeTensorDesc &shape);
|
|
|
|
|
|
|
|
void BCastIndexes(kVecInt &x_indexes, kVecInt &y_indexes);
|
|
|
|
template <typename InT, typename OutT>
|
|
|
|
Status BCastCompute(const std::vector<ConstGeTensorPtr> &input, std::vector<OutT> &v_output,
|
|
|
|
const std::function<OutT(InT const &, InT const &)> &func) {
|
|
|
|
Status ret;
|
|
|
|
if (func == nullptr) {
|
|
|
|
REPORT_INNER_ERROR("E19999", "Check param func nullptr");
|
|
|
|
GELOGE(domi::PARAM_INVALID, "Param func is null");
|
|
|
|
return domi::PARAM_INVALID;
|
|
|
|
}
|
|
|
|
// Min input num is 2
|
|
|
|
if (input.size() < kMinDimNum) {
|
|
|
|
REPORT_INNER_ERROR("E19999", "Param input.size():%zu < %zu, check invalid",
|
|
|
|
input.size(), kMinDimNum);
|
|
|
|
GELOGE(domi::PARAM_INVALID, "Input size is smaller than two.");
|
|
|
|
return domi::PARAM_INVALID;
|
|
|
|
}
|
|
|
|
// Only broadcast shape
|
|
|
|
ret =
|
|
|
|
GenerateBcastInfo(TransShapeToDimVec(input[0]->GetTensorDesc()), TransShapeToDimVec(input[1]->GetTensorDesc()));
|
|
|
|
if (ret != domi::SUCCESS) {
|
|
|
|
GELOGE(ret, "Greater broadcasting failed.");
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
kVecInt x_indexes;
|
|
|
|
kVecInt y_indexes;
|
|
|
|
BCastIndexes(x_indexes, y_indexes);
|
|
|
|
|
|
|
|
const void *x1_data = input[0]->GetData().data();
|
|
|
|
const void *x2_data = input[1]->GetData().data();
|
|
|
|
|
|
|
|
for (size_t i = 0; i < x_indexes.size(); i++) {
|
|
|
|
int64_t x_index = x_indexes[i];
|
|
|
|
int64_t y_index = y_indexes[i];
|
|
|
|
auto value = func((*(reinterpret_cast<const InT *>(x1_data) + x_index)),
|
|
|
|
(*(reinterpret_cast<const InT *>(x2_data) + y_index)));
|
|
|
|
v_output.push_back(value);
|
|
|
|
}
|
|
|
|
|
|
|
|
return domi::SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename InT, typename OutT>
|
|
|
|
Status BCastComputeCheck(const std::vector<ConstGeTensorPtr> &input, std::vector<OutT> &v_output,
|
|
|
|
const std::function<OutT(InT const &, InT const &, DataType &type, Status &)> &func) {
|
|
|
|
if (func == nullptr) {
|
|
|
|
REPORT_INNER_ERROR("E19999", "Check param func nullptr");
|
|
|
|
GELOGE(PARAM_INVALID, "Param func is null");
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
// Min input num is 2
|
|
|
|
if (input.size() < kMinDimNum) {
|
|
|
|
REPORT_INNER_ERROR("E19999", "Param input.size():%zu < %zu, check invalid",
|
|
|
|
input.size(), kMinDimNum);
|
|
|
|
GELOGE(PARAM_INVALID, "Input size is smaller than two.");
|
|
|
|
return PARAM_INVALID;
|
|
|
|
}
|
|
|
|
// Only broadcast shape
|
|
|
|
Status ret =
|
|
|
|
GenerateBcastInfo(TransShapeToDimVec(input[0]->GetTensorDesc()), TransShapeToDimVec(input[1]->GetTensorDesc()));
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
GELOGE(ret, "Greater broadcasting failed.");
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
DataType data_type = input[0]->GetTensorDesc().GetDataType();
|
|
|
|
kVecInt x_indexes;
|
|
|
|
kVecInt y_indexes;
|
|
|
|
BCastIndexes(x_indexes, y_indexes);
|
|
|
|
|
|
|
|
const void *x1_data = input[0]->GetData().data();
|
|
|
|
const void *x2_data = input[1]->GetData().data();
|
|
|
|
|
|
|
|
for (size_t i = 0; i < x_indexes.size(); i++) {
|
|
|
|
int64_t x_index = x_indexes[i];
|
|
|
|
int64_t y_index = y_indexes[i];
|
|
|
|
auto value = func((*(reinterpret_cast<const InT *>(x1_data) + x_index)),
|
|
|
|
(*(reinterpret_cast<const InT *>(x2_data) + y_index)), data_type, ret);
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
REPORT_INNER_ERROR("E19999", "BCastComputeCheck func execute failed, datatype is %d.", data_type);
|
|
|
|
GELOGE(ret, "BCastComputeCheck func execute failed, datatype is %d.", data_type);
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
v_output.push_back(value);
|
|
|
|
}
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief reverse elements in kVecInt
|
|
|
|
/// @param [in] shape dim info
|
|
|
|
/// @return null
|
|
|
|
///
|
|
|
|
static void Reverse(kVecInt &shape);
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief two Tensor with different shape, set broadcast info
|
|
|
|
/// @param [in] x first input Tensor dim info
|
|
|
|
/// @param [in] y second input Tensor dim info
|
|
|
|
/// @return null
|
|
|
|
///
|
|
|
|
ge::Status SetShapeDifferentInfo(const kVecInt &x, const kVecInt &y);
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief extend Tensor dim
|
|
|
|
/// @param [in] x first input Tensor dim info
|
|
|
|
/// @param [in] y second input Tensor dim info
|
|
|
|
/// @return null
|
|
|
|
///
|
|
|
|
void ExtendTensorDim(kVecInt &x, kVecInt &y);
|
|
|
|
///
|
|
|
|
/// @ingroup domi_calibration
|
|
|
|
/// @brief reverse all intermediate shape params
|
|
|
|
/// @param [in] void
|
|
|
|
/// @return null
|
|
|
|
///
|
|
|
|
void ReverseAllIntermediateShapes();
|
|
|
|
|
|
|
|
kVecInt x_reshape_;
|
|
|
|
kVecInt x_bcast_;
|
|
|
|
kVecInt y_reshape_;
|
|
|
|
kVecInt y_bcast_;
|
|
|
|
kVecInt result_;
|
|
|
|
kVecInt output_;
|
|
|
|
kVecInt grad_x_reduce_idx_;
|
|
|
|
kVecInt grad_y_reduce_idx_;
|
|
|
|
};
|
|
|
|
} // namespace ge
|
|
|
|
|
|
|
|
#endif // GE_GRAPH_COMMON_BCAST_H_
|