|
|
@ -15,23 +15,105 @@ limitations under the License. */
|
|
|
|
#ifdef PADDLE_WITH_NGRAPH
|
|
|
|
#ifdef PADDLE_WITH_NGRAPH
|
|
|
|
#include <algorithm>
|
|
|
|
#include <algorithm>
|
|
|
|
#include <functional>
|
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ngraph_bridge.h"
|
|
|
|
#include "paddle/fluid/framework/ngraph_bridge.h"
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include "ngraph/ngraph.hpp"
|
|
|
|
#include "ngraph/ngraph.hpp"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace framework {
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<ngraph::Node> GetNode(
|
|
|
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
|
|
|
const VariableNameMap& var_map,
|
|
|
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
|
|
|
auto& var_names = var_map.at(prm);
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(var_names.size(), 1,
|
|
|
|
|
|
|
|
"op %s prm %s expects one associated var", op->Type(), prm);
|
|
|
|
|
|
|
|
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
|
|
|
|
|
|
|
|
return (*ngb_node_map)[var_names[0]];
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<ngraph::Node> GetInputNode(
|
|
|
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
|
|
|
return GetNode(op, prm, op->Inputs(), ngb_node_map);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<ngraph::Node> GetOutputNode(
|
|
|
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
|
|
|
return GetNode(op, prm, op->Outputs(), ngb_node_map);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static void SetOutputNode(
|
|
|
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> node,
|
|
|
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
|
|
|
auto& var_names = op->Outputs().at(prm);
|
|
|
|
|
|
|
|
if (var_names.size() == 1) {
|
|
|
|
|
|
|
|
(*ngb_node_map)[var_names[0]] = node;
|
|
|
|
|
|
|
|
} else if (var_names.size() == 0) {
|
|
|
|
|
|
|
|
(*ngb_node_map)[""] = node;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
PADDLE_THROW("prm %s has more than 1 var_names.", prm);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static bool HasOutput(const std::shared_ptr<OperatorBase>& op,
|
|
|
|
|
|
|
|
const std::string prm) {
|
|
|
|
|
|
|
|
auto& outputs = op->Outputs();
|
|
|
|
|
|
|
|
if (outputs.find(prm) == outputs.end()) return false;
|
|
|
|
|
|
|
|
return outputs.at(prm).size() > 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
static void BuildBinaryNode(
|
|
|
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op,
|
|
|
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
|
|
|
auto x = GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
|
|
|
auto y = GetInputNode(op, "Y", ngb_node_map);
|
|
|
|
|
|
|
|
auto out = std::make_shared<T>(x, y);
|
|
|
|
|
|
|
|
SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
static void BuildUnaryNode(
|
|
|
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op,
|
|
|
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
|
|
|
auto input = GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
|
|
|
auto out = std::make_shared<T>(input);
|
|
|
|
|
|
|
|
SetOutputNode(op, "Out", out, ngb_node_map);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::map<std::string,
|
|
|
|
std::map<std::string,
|
|
|
|
std::function<void(const std::shared_ptr<OperatorBase>&,
|
|
|
|
std::function<void(const std::shared_ptr<OperatorBase>&,
|
|
|
|
std::shared_ptr<std::unordered_map<
|
|
|
|
std::shared_ptr<std::unordered_map<
|
|
|
|
std::string, std::shared_ptr<ngraph::Node>>>)>>
|
|
|
|
std::string, std::shared_ptr<ngraph::Node>>>)>>
|
|
|
|
NgraphBridge::NG_NODE_MAP = {};
|
|
|
|
NgraphBridge::NG_NODE_MAP = {{"relu", BuildUnaryNode<ngraph::op::Relu>},
|
|
|
|
|
|
|
|
{"tanh", BuildUnaryNode<ngraph::op::Tanh>}};
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphBridge::build_graph(const std::shared_ptr<OperatorBase>& op) {
|
|
|
|
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
|
|
|
|
auto& op_type = op->Type();
|
|
|
|
auto& op_type = op->Type();
|
|
|
|
NG_NODE_MAP[op_type](op, ngb_node_map);
|
|
|
|
NG_NODE_MAP[op_type](op, ngb_node_map_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|