|
|
@ -34,7 +34,15 @@ void BuildGatherNode(
|
|
|
|
ngb_node_map) {
|
|
|
|
ngb_node_map) {
|
|
|
|
auto x = platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
auto x = platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(x);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(x);
|
|
|
|
|
|
|
|
|
|
|
|
auto index = platform::GetInputNode(op, "Index", ngb_node_map);
|
|
|
|
auto index = platform::GetInputNode(op, "Index", ngb_node_map);
|
|
|
|
|
|
|
|
auto& index_shape = index->get_shape();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(index_shape.size() == 1 ||
|
|
|
|
|
|
|
|
(index_shape.size() == 2 && index_shape[1] == 1));
|
|
|
|
|
|
|
|
if (index_shape.size() == 2) {
|
|
|
|
|
|
|
|
index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]});
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto out = std::make_shared<ngraph::op::Gather>(x, index);
|
|
|
|
auto out = std::make_shared<ngraph::op::Gather>(x, index);
|
|
|
|
|
|
|
|
|
|
|
|
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
|
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
@ -47,7 +55,14 @@ void BuildGatherGradNode(
|
|
|
|
auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
|
|
|
|
auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(dout);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(dout);
|
|
|
|
auto x = platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
auto x = platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
|
|
|
|
|
|
|
auto index = platform::GetInputNode(op, "Index", ngb_node_map);
|
|
|
|
auto index = platform::GetInputNode(op, "Index", ngb_node_map);
|
|
|
|
|
|
|
|
auto& index_shape = index->get_shape();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(index_shape.size() == 1 ||
|
|
|
|
|
|
|
|
(index_shape.size() == 2 && index_shape[1] == 1));
|
|
|
|
|
|
|
|
if (index_shape.size() == 2) {
|
|
|
|
|
|
|
|
index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]});
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> x0 = paddle::platform::CreateConstant(
|
|
|
|
std::shared_ptr<ngraph::Node> x0 = paddle::platform::CreateConstant(
|
|
|
|
dout->get_element_type(), x->get_shape(), {0});
|
|
|
|
dout->get_element_type(), x->get_shape(), {0});
|
|
|
|