|
|
|
@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
|
|
|
|
|
return shape1;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> dims;
|
|
|
|
|
bool has_dynamic_shape = false;
|
|
|
|
|
dims.resize(shape1->shape().size());
|
|
|
|
|
for (std::size_t i = 0; i < shape1->shape().size(); i++) {
|
|
|
|
|
if (shape1->shape()[i] == shape2->shape()[i]) {
|
|
|
|
|
dims[i] = shape1->shape()[i];
|
|
|
|
|
if (shape1->shape()[i] == Shape::SHP_ANY) {
|
|
|
|
|
has_dynamic_shape = true;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
dims[i] = Shape::SHP_ANY;
|
|
|
|
|
has_dynamic_shape = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return std::make_shared<Shape>(dims);
|
|
|
|
|
if (!has_dynamic_shape) {
|
|
|
|
|
return std::make_shared<Shape>(dims);
|
|
|
|
|
}
|
|
|
|
|
// calculate dynamic shape
|
|
|
|
|
std::vector<int> min_dims(dims.size());
|
|
|
|
|
std::vector<int> max_dims(dims.size());
|
|
|
|
|
for (size_t i = 0; i < dims.size(); ++i) {
|
|
|
|
|
if (dims[i] != Shape::SHP_ANY) {
|
|
|
|
|
min_dims[i] = max_dims[i] = dims[i];
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
|
|
|
|
|
min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
|
|
|
|
|
max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
|
|
|
|
|
if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
|
|
|
|
|
<< " has dynamic shape, but does not have min/max shape info.";
|
|
|
|
|
}
|
|
|
|
|
min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
|
|
|
|
|
max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
|
|
|
|
|
if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
|
|
|
|
|
<< " has dynamic shape, but does not have min/max shape info.";
|
|
|
|
|
}
|
|
|
|
|
min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
|
|
|
|
|
max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// both shapes contains dynamic shape
|
|
|
|
|
if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
|
|
|
|
|
<< " has dynamic shape, but does not have min/max shape info.";
|
|
|
|
|
}
|
|
|
|
|
if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
|
|
|
|
|
<< " has dynamic shape, but does not have min/max shape info.";
|
|
|
|
|
}
|
|
|
|
|
min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
|
|
|
|
|
max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
|
|
|
|
|
}
|
|
|
|
|
return std::make_shared<Shape>(dims, min_dims, max_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|