|
|
|
@ -23,6 +23,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "ngraph/ngraph.hpp"
|
|
|
|
|
#include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
#include "paddle/fluid/platform/ngraph_helper.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -60,20 +61,16 @@ static void BuildReshapeNode(
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> shape =
|
|
|
|
|
platform::GetInputNode(op, "Shape", ngb_node_map);
|
|
|
|
|
PADDLE_ENFORCE_EQ(shape, nullptr,
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Support for Shape input is not implemented"));
|
|
|
|
|
|
|
|
|
|
auto op_attrs = framework::AttrReader(op->Attrs());
|
|
|
|
|
std::vector<int> v_shape = op_attrs.Get<std::vector<int>>("shape");
|
|
|
|
|
auto out = input;
|
|
|
|
|
if (shape != nullptr) {
|
|
|
|
|
ngraph::Shape new_shape;
|
|
|
|
|
for (auto& it : shape->get_shape()) {
|
|
|
|
|
new_shape.push_back(it);
|
|
|
|
|
}
|
|
|
|
|
out = platform::NgReshaper(input, shape->get_shape());
|
|
|
|
|
} else {
|
|
|
|
|
auto out_shape = calc_output_shape(input_shape, v_shape);
|
|
|
|
|
out = platform::NgReshaper(input, out_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_shape = calc_output_shape(input_shape, v_shape);
|
|
|
|
|
auto out = platform::NgReshaper(input, out_shape);
|
|
|
|
|
platform::SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
|
|
|
|
|
|
|
if (is_v2) {
|
|
|
|
|
ngraph::Shape input_xshape(input_shape.size() + 1);
|
|
|
|
@ -83,7 +80,6 @@ static void BuildReshapeNode(
|
|
|
|
|
input->get_element_type(), input_xshape, std::vector<std::string>{});
|
|
|
|
|
platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
platform::SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <bool is_v2>
|
|
|
|
|