|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
@ -76,7 +77,12 @@ static void BuildReshapeNode(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_v2) {
|
|
|
|
|
platform::SetOutputNode(op, "XShape", input, ngb_node_map);
|
|
|
|
|
ngraph::Shape input_xshape(input_shape.size() + 1);
|
|
|
|
|
input_xshape[0] = 0;
|
|
|
|
|
std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1);
|
|
|
|
|
auto xshape_node = std::make_shared<ngraph::op::Constant>(
|
|
|
|
|
input->get_element_type(), input_xshape, std::vector<std::string>{});
|
|
|
|
|
platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
platform::SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
@ -88,13 +94,17 @@ void BuildReshapeGradNode(
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
|
|
|
|
|
std::shared_ptr<ngraph::Node> input;
|
|
|
|
|
ngraph::Shape out_shape;
|
|
|
|
|
if (is_v2) {
|
|
|
|
|
input = paddle::platform::GetInputNode(op, "XShape", ngb_node_map);
|
|
|
|
|
auto& xshape =
|
|
|
|
|
platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape();
|
|
|
|
|
out_shape.resize(xshape.size() - 1);
|
|
|
|
|
std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin());
|
|
|
|
|
} else {
|
|
|
|
|
input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
out_shape = input->get_shape();
|
|
|
|
|
}
|
|
|
|
|
auto dx = platform::NgReshaper(dout, input->get_shape());
|
|
|
|
|
auto dx = platform::NgReshaper(dout, out_shape);
|
|
|
|
|
paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ngraphs
|
|
|
|
|