|
|
|
@ -122,7 +122,7 @@ class NgraphOperator {
|
|
|
|
|
// get ngraph input and define ngraph input parameters
|
|
|
|
|
void GetNgInputShape(std::shared_ptr<OperatorBase> op);
|
|
|
|
|
// Call ngraph bridge to map ops
|
|
|
|
|
void BuildNgNode();
|
|
|
|
|
void BuildNgNodes();
|
|
|
|
|
// get the ngraph input and output var list
|
|
|
|
|
void BuildNgIO();
|
|
|
|
|
// build ngraph function call
|
|
|
|
@ -301,7 +301,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::BuildNgNode() {
|
|
|
|
|
void NgraphOperator::BuildNgNodes() {
|
|
|
|
|
for (auto& var_name : var_out_) {
|
|
|
|
|
if (var_node_map_->find(var_name) == var_node_map_->end()) {
|
|
|
|
|
auto* var = scope_.FindVar(var_name);
|
|
|
|
@ -319,7 +319,7 @@ void NgraphOperator::BuildNgNode() {
|
|
|
|
|
|
|
|
|
|
paddle::framework::NgraphBridge ngb(var_node_map_);
|
|
|
|
|
for (auto& op : fused_ops_) {
|
|
|
|
|
ngb.BuildNgGraph(op);
|
|
|
|
|
ngb.BuildNgNode(op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -396,7 +396,7 @@ void NgraphOperator::BuildNgIO() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::BuildNgFunction() {
|
|
|
|
|
BuildNgNode();
|
|
|
|
|
BuildNgNodes();
|
|
|
|
|
ngraph_function_ = nullptr;
|
|
|
|
|
ngraph::NodeVector func_outputs;
|
|
|
|
|
ngraph::op::ParameterVector func_inputs;
|
|
|
|
|