|
|
@ -729,62 +729,6 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
|
|
|
const std::string op_name = primitive->name();
|
|
|
|
|
|
|
|
CheckArgsSize(op_name, args_spec_list, 2);
|
|
|
|
|
|
|
|
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(x);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(x->shape());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ShapeVector x_shape = x->shape()->shape();
|
|
|
|
|
|
|
|
ShapeVector x_shape_min = x->shape()->min_shape();
|
|
|
|
|
|
|
|
if (x_shape_min.empty()) {
|
|
|
|
|
|
|
|
x_shape_min = x_shape;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ShapeVector x_shape_max = x->shape()->max_shape();
|
|
|
|
|
|
|
|
if (x_shape_max.empty()) {
|
|
|
|
|
|
|
|
x_shape_max = x_shape;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int64_t value = 0;
|
|
|
|
|
|
|
|
if (args_spec_list[1]->isa<AbstractTensor>()) { // axis is Tensor
|
|
|
|
|
|
|
|
auto axis = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
|
|
|
|
|
|
|
auto axis_value = axis->BuildValue();
|
|
|
|
|
|
|
|
if (!axis_value->isa<tensor::Tensor>()) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << axis_value << " axis_value should be tensor, but got " << axis_value->type_name();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto axis_tensor = axis_value->cast<tensor::TensorPtr>();
|
|
|
|
|
|
|
|
value = *(static_cast<int64_t *>(axis_tensor->data_c()));
|
|
|
|
|
|
|
|
} else if (args_spec_list[1]->isa<AbstractScalar>()) { // axis is Scalar
|
|
|
|
|
|
|
|
auto axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(axis);
|
|
|
|
|
|
|
|
value = GetValue<int64_t>(axis->BuildValue());
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "axis incorrect type in ExpandDims";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value
|
|
|
|
|
|
|
|
<< " and input_x.dim is" << x_shape.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (value < 0) {
|
|
|
|
|
|
|
|
value = value + SizeToInt(x_shape.size()) + 1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ShapeVector shape;
|
|
|
|
|
|
|
|
shape.insert(shape.end(), x_shape.begin(), x_shape.end());
|
|
|
|
|
|
|
|
shape.insert(shape.begin() + value, 1);
|
|
|
|
|
|
|
|
ShapeVector shape_min;
|
|
|
|
|
|
|
|
shape_min.insert(shape_min.end(), x_shape_min.begin(), x_shape_min.end());
|
|
|
|
|
|
|
|
shape_min.insert(shape_min.begin() + value, 1);
|
|
|
|
|
|
|
|
ShapeVector shape_max;
|
|
|
|
|
|
|
|
shape_max.insert(shape_max.end(), x_shape_max.begin(), x_shape_max.end());
|
|
|
|
|
|
|
|
shape_max.insert(shape_max.begin() + value, 1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
const std::string op_name = primitive->name();
|
|
|
|
const std::string op_name = primitive->name();
|
|
|
|