@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -121,6 +121,94 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
   return ret;
 }
 
+AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(input_x);
+  MS_EXCEPTION_IF_NULL(input_x->element());
+
+  ValuePtr keep_dims = primitive->GetAttr("keep_dims");
+  MS_EXCEPTION_IF_NULL(keep_dims);
+  if (!keep_dims->isa<BoolImm>()) {
+    MS_LOG(EXCEPTION) << "Keep_dims should be Bool.";
+  }
+  bool keep_dims_value = GetValue<bool>(keep_dims);
+
+  ValuePtr axis = primitive->GetAttr("axis");
+  MS_EXCEPTION_IF_NULL(axis);
+
+  auto check_axis = [](int64_t &axis, const size_t dim) -> void {
+    int64_t dim_ = static_cast<int64_t>(dim);
+    if (axis < -dim_ || axis >= dim_) {
+      MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
+    }
+    if (axis >= -dim_ && axis < 0) {
+      axis += dim_;
+    }
+    return;
+  };
+
+  auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void {
+    if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
+      auto axis_ptr_list =
+        axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
+      if (!axis_ptr_list.size()) {
+        if (keep_dims_value) shape.insert(shape.end(), x_shape.size(), 1);
+      } else {
+        shape.insert(shape.end(), x_shape.begin(), x_shape.end());
+        ValuePtrList axis_items = axis_ptr_list;
+        ValuePtrList::iterator it;
+        ValuePtrList::reverse_iterator it_re;
+        int64_t axis_value;
+        if (keep_dims_value) {
+          for (it = axis_items.begin(); it != axis_items.end(); ++it) {
+            axis_value = GetValue<int64_t>(*it);
+            check_axis(axis_value, x_shape.size());
+            shape[axis_value] = 1;
+          }
+        } else {
+          std::sort(axis_items.begin(), axis_items.end());
+          for (it_re = axis_items.rbegin(); it_re != axis_items.rend(); ++it_re) {
+            axis_value = GetValue<int64_t>(*it_re);
+            check_axis(axis_value, x_shape.size());
+            shape.erase(std::begin(shape) + axis_value);
+          }
+        }
+      }
+    } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
+      shape.insert(shape.end(), x_shape.begin(), x_shape.end());
+      int64_t axis_value = GetValue<int64_t>(axis);
+      check_axis(axis_value, x_shape.size());
+      if (keep_dims_value) {
+        shape[axis_value] = 1;
+      } else {
+        shape.erase(std::begin(shape) + axis_value);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
+    }
+    return;
+  };
+
+  ShapeVector shape = {};
+  ShapeVector x_shape = input_x->shape()->shape();
+  cal_shape(shape, x_shape);
+
+  bool x_is_dyn = (!input_x->shape()->min_shape().empty() && !input_x->shape()->max_shape().empty());
+  if (x_is_dyn) {
+    ShapeVector shape_min = {};
+    ShapeVector shape_max = {};
+    ShapeVector x_shape_min = input_x->shape()->min_shape();
+    ShapeVector x_shape_max = input_x->shape()->max_shape();
+    cal_shape(shape_min, x_shape_min);
+    cal_shape(shape_max, x_shape_max);
+    return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
+  }
+  return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
+}
+
 AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
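
For reference, the output-shape rule that the new InferImplReduceSum encodes can be tried in isolation. Below is a minimal standalone C++ sketch, not MindSpore code: it uses plain std::vector<int64_t> in place of ShapeVector, a vector of integers in place of the ValuePtr "axis" attribute, and the name reduce_sum_out_shape is illustrative only.

// Minimal sketch (not MindSpore code): compute the output shape of a
// ReduceSum over `axes`, mirroring the cal_shape logic in the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<int64_t> reduce_sum_out_shape(std::vector<int64_t> x_shape, std::vector<int64_t> axes, bool keep_dims) {
  const int64_t dim = static_cast<int64_t>(x_shape.size());
  // Normalize and validate every axis, as check_axis does.
  for (auto &a : axes) {
    if (a < -dim || a >= dim) throw std::out_of_range("axis out of range");
    if (a < 0) a += dim;
  }
  if (axes.empty()) {
    // An empty axis tuple reduces every dimension.
    return keep_dims ? std::vector<int64_t>(x_shape.size(), 1) : std::vector<int64_t>{};
  }
  std::vector<int64_t> out = x_shape;
  if (keep_dims) {
    for (int64_t a : axes) out[a] = 1;  // reduced dims collapse to 1
  } else {
    // Erase from the largest index down so earlier erases do not shift later ones.
    std::sort(axes.begin(), axes.end());
    for (auto it = axes.rbegin(); it != axes.rend(); ++it) out.erase(out.begin() + *it);
  }
  return out;
}

int main() {
  // ReduceSum over axes (1, -1) of a [2, 3, 4, 5] tensor.
  for (int64_t d : reduce_sum_out_shape({2, 3, 4, 5}, {1, -1}, false)) std::cout << d << ' ';  // 2 4
  std::cout << '\n';
  for (int64_t d : reduce_sum_out_shape({2, 3, 4, 5}, {1, -1}, true)) std::cout << d << ' ';   // 2 1 4 1
  std::cout << '\n';
}

The patch applies this same rule three times: once to the static shape and, for dynamic inputs, once each to min_shape and max_shape before building the resulting AbstractTensor.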