|
|
|
@ -50,11 +50,15 @@ class ReshapeSameShapeEliminater : public AnfVisitor {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto src_shape = src_shape_abs->GetShapeTrack();
|
|
|
|
|
auto tgt_shape = GetValueNode(shape_);
|
|
|
|
|
if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa<Shape>()) {
|
|
|
|
|
auto elements = GetValue<std::vector<int>>(tgt_shape);
|
|
|
|
|
auto tgt_shape_abs = node->abstract();
|
|
|
|
|
if (tgt_shape_abs == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto tgt_shape = tgt_shape_abs->GetShapeTrack();
|
|
|
|
|
if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa<Shape>() && tgt_shape->isa<Shape>()) {
|
|
|
|
|
auto elements = tgt_shape->cast<ShapePtr>();
|
|
|
|
|
auto shape = src_shape->cast<ShapePtr>();
|
|
|
|
|
if (shape->shape() == elements) {
|
|
|
|
|
if (shape->shape() == elements->shape()) {
|
|
|
|
|
return x_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|