|
|
|
@ -57,6 +57,7 @@ typedef enum { /* nGraph support state on ops */
|
|
|
|
|
PARTIAL_TEST /* Support partial list of ops for test */
|
|
|
|
|
} op_state;
|
|
|
|
|
|
|
|
|
|
// perform graph build through bridge and execute computation
|
|
|
|
|
class NgraphOperator {
|
|
|
|
|
public:
|
|
|
|
|
explicit NgraphOperator(const Scope& scope, const platform::Place& place,
|
|
|
|
@ -100,33 +101,33 @@ class NgraphOperator {
|
|
|
|
|
std::unordered_set<std::string> post_op_inputs_;
|
|
|
|
|
op_state ng_op_state_;
|
|
|
|
|
|
|
|
|
|
// ngraph backend eg. CPU
|
|
|
|
|
static std::shared_ptr<ngraph::runtime::Backend> backend_;
|
|
|
|
|
|
|
|
|
|
// ngraph function to call and execute
|
|
|
|
|
std::shared_ptr<ngraph::Function> ngraph_function_;
|
|
|
|
|
// var_name of inputs
|
|
|
|
|
std::vector<std::string> var_in_;
|
|
|
|
|
// var_name of outputs from fetch in order
|
|
|
|
|
std::vector<std::string> var_out_;
|
|
|
|
|
|
|
|
|
|
// map input vars to nodes
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
var_in_node_map_;
|
|
|
|
|
|
|
|
|
|
// map each var name with a ngraph node
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
var_node_map_;
|
|
|
|
|
|
|
|
|
|
// cache key to check if function is cached
|
|
|
|
|
std::shared_ptr<std::string> GetCacheKey();
|
|
|
|
|
|
|
|
|
|
// get ngraph input and define ngraph input parameters
|
|
|
|
|
void GetNgInputShape(std::shared_ptr<OperatorBase> op);
|
|
|
|
|
|
|
|
|
|
// Call ngraph bridge to map ops
|
|
|
|
|
void BuildNgNode();
|
|
|
|
|
|
|
|
|
|
// get the ngraph input and output var list
|
|
|
|
|
void BuildNgIO();
|
|
|
|
|
|
|
|
|
|
// build ngraph function call
|
|
|
|
|
void BuildNgFunction();
|
|
|
|
|
|
|
|
|
|
// Check cache for ngraph function or otherwise build the function
|
|
|
|
|
void GetNgFunction();
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|