You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/common/bcast.cc

169 lines
4.5 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/common/bcast.h"

#include <algorithm>
#include <vector>

#include "common/math_util.h"
#include "common/util.h"
using domi::Status;
namespace ge {
// Computes broadcast bookkeeping for input shapes `sx` and `sy` (outermost dimension first).
// Two empty shapes (scalar op scalar) produce single-element "1" entries; output_ is left
// empty in that case. Otherwise the shapes are reversed so that trailing dimensions line up,
// padded to a common rank, and analysed per dimension. All intermediate vectors are finally
// flipped back to outermost-first order.
// Returns domi::SUCCESS, or the error reported by SetShapeDifferentInfo.
Status BCast::GenerateBcastInfo(const kVecInt &sx, const kVecInt &sy) {
  const bool is_scalar_pair = sx.empty() && sy.empty();
  if (is_scalar_pair) {
    x_reshape_.push_back(1);
    x_bcast_.push_back(1);
    y_reshape_.push_back(1);
    y_bcast_.push_back(1);
    result_.push_back(1);
  } else {
    // Work on copies so the caller's shapes are not modified.
    kVecInt x_rev(sx);
    kVecInt y_rev(sy);
    Reverse(x_rev);
    Reverse(y_rev);
    ExtendTensorDim(x_rev, y_rev);
    GE_RETURN_WITH_LOG_IF_ERROR(SetShapeDifferentInfo(x_rev, y_rev), "GenerateBcastInfo failed.");
  }
  ReverseAllIntermediateShapes();
  return domi::SUCCESS;
}
// Walks two shapes of equal rank, position by position, and records for each dimension:
//   - the broadcast output size (output_, result_),
//   - each input's size (x_reshape_, y_reshape_) and replication factor (x_bcast_, y_bcast_),
//   - the axes along which each input's gradient must be reduced (grad_*_reduce_idx_).
// The shapes are expected innermost dimension first (GenerateBcastInfo reverses them before
// calling), so position i maps to axis (rank - 1 - i) of the original shape.
// Returns domi::PARAM_INVALID when the ranks differ or when a dimension pair violates the
// broadcasting rule (sizes equal, or one of them is 1). Negative dimensions are rejected
// by GE_CHECK_GE.
Status BCast::SetShapeDifferentInfo(const kVecInt &x, const kVecInt &y) {
  // ExtendTensorDim normally equalizes the ranks. Guard anyway so that a mismatch becomes
  // an error instead of an out-of-bounds read of y.
  if (x.size() != y.size()) {
    GELOGE(domi::PARAM_INVALID, "SetShapeDifferentInfo failed. Rank of x (%zu) differs from rank of y (%zu).",
           x.size(), y.size());
    return domi::PARAM_INVALID;
  }
  const size_t rank = x.size();
  for (size_t i = 0; i < rank; ++i) {
    const int64_t x_i = x[i];
    GE_CHECK_GE(x_i, 0);
    const int64_t y_i = y[i];
    GE_CHECK_GE(y_i, 0);
    // Axis index in the original (non-reversed) shape.
    const int64_t axis = static_cast<int64_t>(rank - 1 - i);
    int64_t output_i = 0;
    int64_t x_bcast_i = 0;
    int64_t y_bcast_i = 0;
    if (x_i == y_i) {
      output_i = x_i;
      x_bcast_i = 1;
      y_bcast_i = 1;
      // A size-1 axis in both inputs still counts as a reduction axis for both gradients.
      if (x_i == 1) {
        grad_x_reduce_idx_.push_back(axis);
        grad_y_reduce_idx_.push_back(axis);
      }
    } else if (x_i == 1) {
      // x is replicated y_i times along this axis.
      output_i = y_i;
      x_bcast_i = y_i;
      y_bcast_i = 1;
      grad_x_reduce_idx_.push_back(axis);
    } else if (y_i == 1) {
      // y is replicated x_i times along this axis.
      output_i = x_i;
      x_bcast_i = 1;
      y_bcast_i = x_i;
      grad_y_reduce_idx_.push_back(axis);
    } else {
      GELOGE(domi::PARAM_INVALID,
             "SetShapeDifferentInfo failed. Two tensor shapes are not compatible "
             "according to the broadcasting rule.");
      return domi::PARAM_INVALID;
    }
    output_.push_back(output_i);
    result_.push_back(output_i);
    x_reshape_.push_back(x_i);
    x_bcast_.push_back(x_bcast_i);
    y_reshape_.push_back(y_i);
    y_bcast_.push_back(y_bcast_i);
  }
  return domi::SUCCESS;
}
// Brings v_x and v_y to the same rank by appending 1s to the shorter one. If both already
// have the same length, neither is changed. Appending at the end pads the leading
// dimensions, because callers pass shapes in reversed (innermost-first) order.
void BCast::ExtendTensorDim(kVecInt &v_x, kVecInt &v_y) {
  const bool x_is_longer = v_x.size() > v_y.size();
  kVecInt &shorter = x_is_longer ? v_y : v_x;
  const kVecInt &longer = x_is_longer ? v_x : v_y;
  shorter.resize(longer.size(), 1);
}
// Copies the dimensions of a tensor descriptor's shape into a kVecInt, preserving their
// order (index 0 is the shape's first dimension).
BCast::kVecInt BCast::TransShapeToDimVec(const GeTensorDesc &shape) {
  // GetShape() is loop-invariant. Fetch it once instead of once per dimension, which may
  // avoid copying the shape on every iteration. Binding to const& also extends the lifetime
  // of a by-value return.
  const auto &ge_shape = shape.GetShape();
  const size_t dim_num = ge_shape.GetDimNum();
  BCast::kVecInt ret(dim_num);
  for (size_t i = 0; i < dim_num; ++i) {
    ret[i] = ge_shape.GetDim(i);
  }
  return ret;
}
// Reverses the order of the elements of `shape` in place.
// std::reverse comes from <algorithm>, which must be included directly rather than relied
// on transitively.
void BCast::Reverse(kVecInt &shape) { std::reverse(shape.begin(), shape.end()); }
void BCast::ReverseAllIntermediateShapes() {
// Reverse all intermediate shape params
Reverse(x_reshape_);
Reverse(x_bcast_);
Reverse(y_reshape_);
Reverse(y_bcast_);
Reverse(result_);
Reverse(output_);
Reverse(grad_x_reduce_idx_);
Reverse(grad_y_reduce_idx_);
}
// For every element of the broadcast output, in row-major (flattened) order, computes the
// flat offset of the x element and of the y element that feed it.
// On return, x_indexes and y_indexes have one entry per output element. Any previous
// contents are discarded. Two scalar inputs (output_ empty) yield a single pair {0}, {0}.
// An output containing a zero-sized dimension yields empty vectors.
// Reads x_reshape_, y_reshape_ and output_ (outermost dimension first) without modifying them.
void BCast::BCastIndexes(kVecInt &x_indexes, kVecInt &y_indexes) {
  // The expansion below uses the current vector size as the stride, so it must start from
  // exactly the single seed entry and not from whatever the caller passed in.
  x_indexes.clear();
  y_indexes.clear();
  // Output offset 0 always maps to offset 0 of both inputs.
  x_indexes.push_back(0);
  y_indexes.push_back(0);

  // Element i counted from the back, i.e. dimension i when numbering from the innermost
  // (fastest-varying) dimension. Indexing from the back avoids temporarily reversing members.
  const auto from_back = [](const kVecInt &dims, size_t i) { return dims.at(dims.size() - 1 - i); };

  // Flat-offset step of one unit along the current dimension in x / y.
  int64_t x_bias = 1;
  int64_t y_bias = 1;
  for (size_t i = 0; i < output_.size(); ++i) {
    const int64_t x_dim = from_back(x_reshape_, i);
    const int64_t y_dim = from_back(y_reshape_, i);
    const int64_t out_dim = from_back(output_, i);
    if (out_dim == 0) {
      // A zero-sized dimension anywhere means the output has no elements at all.
      x_indexes.clear();
      y_indexes.clear();
      return;
    }
    // Replicate the block built for the inner dimensions once per extra position along this
    // dimension. A size-1 input dimension is broadcast, so its offset does not advance.
    const size_t stride = x_indexes.size();
    for (int64_t j = 1; j < out_dim; ++j) {
      const int64_t x_offset = (x_dim == 1) ? 0 : j * x_bias;
      const int64_t y_offset = (y_dim == 1) ? 0 : j * y_bias;
      for (size_t k = 0; k < stride; ++k) {
        x_indexes.push_back(x_indexes[k] + x_offset);
        y_indexes.push_back(y_indexes[k] + y_offset);
      }
    }
    x_bias *= x_dim;
    y_bias *= y_dim;
  }
}
} // namespace ge