You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
4.5 KiB
169 lines
4.5 KiB
5 years ago
|
/**
|
||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include "graph/common/bcast.h"

#include <algorithm>
#include <vector>

#include "common/math_util.h"
#include "common/util.h"
|
||
|
|
||
|
using domi::Status;
|
||
|
|
||
|
namespace ge {
|
||
|
/// Fills the broadcast bookkeeping members (x_reshape_, x_bcast_, y_reshape_, y_bcast_, result_,
/// output_ and the grad reduce index lists) for the two input shapes.
///
/// @param sx  dimensions of the first operand.
/// @param sy  dimensions of the second operand.
/// @return domi::SUCCESS on success; if SetShapeDifferentInfo reports an error,
///         GE_RETURN_WITH_LOG_IF_ERROR is expected to log it and return early
///         (NOTE(review): macro is defined elsewhere - confirm).
Status BCast::GenerateBcastInfo(const kVecInt &sx, const kVecInt &sy) {
  const bool both_empty = sx.empty() && sy.empty();
  if (!both_empty) {
    // Work on reversed copies so that index 0 is the last dimension of each input; the shorter
    // shape is then padded with trailing 1s until both have the same number of dimensions.
    kVecInt x_dims = sx;
    kVecInt y_dims = sy;
    Reverse(x_dims);
    Reverse(y_dims);
    ExtendTensorDim(x_dims, y_dims);
    GE_RETURN_WITH_LOG_IF_ERROR(SetShapeDifferentInfo(x_dims, y_dims), "GenerateBcastInfo failed.");
  } else {
    // Neither input has any dimension: record a single dimension of size 1 everywhere.
    // output_ is intentionally not touched on this path.
    result_.push_back(1);
    x_reshape_.push_back(1);
    x_bcast_.push_back(1);
    y_reshape_.push_back(1);
    y_bcast_.push_back(1);
  }

  // The members were filled in reversed dimension order; flip them back.
  ReverseAllIntermediateShapes();
  return domi::SUCCESS;
}
|
||
|
|
||
|
/// Walks two shapes dimension by dimension and records the broadcast information for each pair.
///
/// For every index i, the pair (x[i], y[i]) must be equal or contain a 1. One entry is appended
/// to each of output_, result_, x_reshape_, x_bcast_, y_reshape_ and y_bcast_, and the index
/// (n - 1 - i) is appended to grad_x_reduce_idx_ / grad_y_reduce_idx_ whenever the corresponding
/// operand has size 1 in that dimension.
///
/// @param x  dimensions of the first operand.
/// @param y  dimensions of the second operand; must have at least as many entries as x
///           (GenerateBcastInfo pads both to the same length before calling).
/// @return domi::SUCCESS on success; domi::PARAM_INVALID when y is shorter than x or when a pair
///         of dimensions is incompatible. GE_CHECK_GE presumably also returns an error for a
///         negative dimension (NOTE(review): macro is defined elsewhere - confirm).
///         Entries appended before a failure are not rolled back.
Status BCast::SetShapeDifferentInfo(const kVecInt &x, const kVecInt &y) {
  // y is indexed below with x's indices; a shorter y would be read out of bounds.
  if (y.size() < x.size()) {
    GELOGE(domi::PARAM_INVALID, "SetShapeDifferentInfo failed. Shape y has fewer dimensions than shape x.");
    return domi::PARAM_INVALID;
  }

  const int64_t n = x.size();
  for (int64_t i = 0; i < n; ++i) {
    const int64_t x_i = x[i];
    GE_CHECK_GE(x_i, 0);
    const int64_t y_i = y[i];
    GE_CHECK_GE(y_i, 0);

    int64_t output_i = 0;
    int64_t x_bcast_i = 0;
    int64_t y_bcast_i = 0;

    if (x_i == y_i) {
      // Same size on both sides: nothing is replicated.
      output_i = x_i;
      x_bcast_i = 1;
      y_bcast_i = 1;
      if (x_i == 1) {
        // Both sizes are 1: the index is recorded for both operands.
        grad_x_reduce_idx_.push_back(n - 1 - i);
        grad_y_reduce_idx_.push_back(n - 1 - i);
      }
    } else if (x_i == 1) {
      // x is replicated y_i times along this dimension.
      output_i = y_i;
      x_bcast_i = y_i;
      y_bcast_i = 1;
      grad_x_reduce_idx_.push_back(n - 1 - i);
    } else if (y_i == 1) {
      // y is replicated x_i times along this dimension.
      output_i = x_i;
      x_bcast_i = 1;
      y_bcast_i = x_i;
      grad_y_reduce_idx_.push_back(n - 1 - i);
    } else {
      GELOGE(domi::PARAM_INVALID,
             "SetShapeDifferentInfo failed. Two tensor shapes are not compatible "
             "according to the broadcasting rule.");
      return domi::PARAM_INVALID;
    }

    output_.push_back(output_i);
    result_.push_back(output_i);
    x_reshape_.push_back(x_i);
    x_bcast_.push_back(x_bcast_i);
    y_reshape_.push_back(y_i);
    y_bcast_.push_back(y_bcast_i);
  }
  return domi::SUCCESS;
}
|
||
|
|
||
|
/// Pads the shorter of the two dimension lists with trailing 1s so that both end up with the
/// same number of entries. When the lengths already match, neither list changes.
///
/// @param v_x  first dimension list; may be extended in place.
/// @param v_y  second dimension list; may be extended in place.
void BCast::ExtendTensorDim(kVecInt &v_x, kVecInt &v_y) {
  const auto x_len = v_x.size();
  const auto y_len = v_y.size();
  if (x_len > y_len) {
    v_y.resize(x_len, 1);
    return;
  }
  // y is at least as long as x; resizing to the current length is a no-op.
  v_x.resize(y_len, 1);
}
|
||
|
|
||
|
/// Copies the dimensions of a tensor descriptor's shape into a kVecInt, in the order reported by
/// GetDim(0) .. GetDim(GetDimNum() - 1).
///
/// @param shape  tensor descriptor whose GetShape() supplies the dimensions.
/// @return a list with one entry per dimension.
BCast::kVecInt BCast::TransShapeToDimVec(const GeTensorDesc &shape) {
  // Bind the shape once; auto&& keeps whatever constness / value category GetShape() yields.
  auto &&tensor_shape = shape.GetShape();
  const size_t rank = tensor_shape.GetDimNum();

  BCast::kVecInt dims;
  dims.reserve(rank);
  for (size_t axis = 0; axis != rank; ++axis) {
    dims.push_back(tensor_shape.GetDim(axis));
  }
  return dims;
}
|
||
|
|
||
|
/// Reverses the order of the entries of `shape` in place.
///
/// @param shape  dimension list to reverse; an empty list is left unchanged.
void BCast::Reverse(kVecInt &shape) {
  // std::reverse is declared in <algorithm>.
  std::reverse(shape.begin(), shape.end());
}
|
||
|
|
||
|
void BCast::ReverseAllIntermediateShapes() {
|
||
|
// Reverse all intermediate shape params
|
||
|
Reverse(x_reshape_);
|
||
|
Reverse(x_bcast_);
|
||
|
Reverse(y_reshape_);
|
||
|
Reverse(y_bcast_);
|
||
|
Reverse(result_);
|
||
|
Reverse(output_);
|
||
|
Reverse(grad_x_reduce_idx_);
|
||
|
Reverse(grad_y_reduce_idx_);
|
||
|
}
|
||
|
|
||
|
/// Appends, for every position of the broadcast output, the flat offset of the element of x and
/// of y that the position reads. Positions are enumerated with the last dimension of output_
/// varying fastest; a dimension of size 1 in an operand always contributes offset 0.
///
/// x_reshape_, y_reshape_ and output_ are reversed while the function runs and restored before
/// it returns. The output lists are appended to, not cleared, and entries already present are
/// treated as part of the first block - callers are presumably expected to pass empty lists.
///
/// @param x_indexes  receives the offsets into x.
/// @param y_indexes  receives the offsets into y.
void BCast::BCastIndexes(kVecInt &x_indexes, kVecInt &y_indexes) {
  // Put the last dimension at index 0 for the duration of the computation.
  Reverse(x_reshape_);
  Reverse(y_reshape_);
  Reverse(output_);

  // Extents of the dimension currently being processed. The defaults of 1 are used when
  // output_ is empty (both inputs had no dimensions).
  int64_t x_extent = 1;
  int64_t y_extent = 1;
  int64_t out_extent = 1;
  if (!output_.empty()) {
    x_extent = x_reshape_.at(0);
    y_extent = y_reshape_.at(0);
    out_extent = output_.at(0);
  }

  // Number of elements of x / y spanned by one step along the next dimension.
  int64_t x_step = x_extent;
  int64_t y_step = y_extent;

  // Dimension 0: the offset is the position itself unless the operand has size 1 there.
  const bool x_fixed = (x_extent == 1);
  const bool y_fixed = (y_extent == 1);
  for (int64_t pos = 0; pos < out_extent; ++pos) {
    x_indexes.push_back(x_fixed ? 0 : pos);
    y_indexes.push_back(y_fixed ? 0 : pos);
  }

  // Every further dimension replicates the block built so far (out_extent - 1) more times,
  // shifting each copy by the operand's step along that dimension.
  for (size_t axis = 1; axis < output_.size(); ++axis) {
    x_extent = x_reshape_.at(axis);
    y_extent = y_reshape_.at(axis);
    out_extent = output_.at(axis);

    const int64_t block_len = x_indexes.size();
    for (int64_t rep = 1; rep < out_extent; ++rep) {
      const int64_t x_shift = (x_extent == 1) ? 0 : (rep * x_step);
      const int64_t y_shift = (y_extent == 1) ? 0 : (rep * y_step);
      for (int64_t src = 0; src < block_len; ++src) {
        x_indexes.push_back(x_indexes.at(src) + x_shift);
        y_indexes.push_back(y_indexes.at(src) + y_shift);
      }
    }
    x_step *= x_extent;
    y_step *= y_extent;
  }

  // Restore the members to their original order.
  Reverse(x_reshape_);
  Reverse(y_reshape_);
  Reverse(output_);
}
|
||
|
} // namespace ge
|