|
|
|
@ -47,16 +47,27 @@ void BuildLookupTableNode(
|
|
|
|
|
if (is_sparse) {
|
|
|
|
|
PADDLE_THROW("Sparsity is not yet supported in nGraph lookup_table op.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ng_w_mask = ng_w;
|
|
|
|
|
if (padding_idx != kNoPadding) {
|
|
|
|
|
PADDLE_THROW("Padding is not yet supported in nGraph lookup_table op.");
|
|
|
|
|
auto w_shape = ng_w->get_shape();
|
|
|
|
|
|
|
|
|
|
std::vector<int> maskV(w_shape[0], 1);
|
|
|
|
|
maskV[padding_idx] = 0;
|
|
|
|
|
auto maskV_node = std::make_shared<ngraph::op::Constant>(
|
|
|
|
|
ng_w->get_element_type(), ngraph::Shape{w_shape[0]}, maskV);
|
|
|
|
|
ngraph::AxisSet axis_set;
|
|
|
|
|
for (unsigned int i = 1; i < w_shape.size(); ++i) axis_set.insert(i);
|
|
|
|
|
auto maskV_bd =
|
|
|
|
|
std::make_shared<ngraph::op::Broadcast>(maskV_node, w_shape, axis_set);
|
|
|
|
|
ng_w_mask = std::make_shared<ngraph::op::Multiply>(ng_w, maskV_bd);
|
|
|
|
|
}
|
|
|
|
|
auto shape = ng_ids->get_shape();
|
|
|
|
|
if (shape.back() == 1) {
|
|
|
|
|
shape.pop_back();
|
|
|
|
|
ng_ids = platform::NgReshaper(ng_ids, shape);
|
|
|
|
|
}
|
|
|
|
|
auto ng_lookup = std::make_shared<ngraph::op::Gather>(ng_w, ng_ids);
|
|
|
|
|
|
|
|
|
|
auto ng_lookup = std::make_shared<ngraph::op::Gather>(ng_w_mask, ng_ids);
|
|
|
|
|
platform::SetOutputNode(op, "Out", ng_lookup, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -67,8 +78,6 @@ void BuildLookupTableGradNode(
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
const bool is_sparse = op_attrs.Get<bool>("is_sparse");
|
|
|
|
|
const int64_t padding_idx = op_attrs.Get<int64_t>("padding_idx");
|
|
|
|
|
|
|
|
|
|
auto ng_ids = paddle::platform::GetInputNode(op, "Ids", ngb_node_map);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ng_ids);
|
|
|
|
|
|
|
|
|
@ -81,9 +90,6 @@ void BuildLookupTableGradNode(
|
|
|
|
|
PADDLE_THROW("Sparsity is not yet supported in nGraph lookup_table op.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (padding_idx != kNoPadding) {
|
|
|
|
|
PADDLE_THROW("Padding is not yet supported in nGraph lookup_table op.");
|
|
|
|
|
}
|
|
|
|
|
auto shape = ng_ids->get_shape();
|
|
|
|
|
if (shape.back() == 1) {
|
|
|
|
|
shape.pop_back();
|
|
|
|
|