|
|
|
@ -26,14 +26,15 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<ngraph::Node> GetNode(
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string name,
|
|
|
|
|
const VariableNameMap& var_map,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto& var_names = var_map.at(prm);
|
|
|
|
|
auto& var_names = var_map.at(name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var_names.size(), 1,
|
|
|
|
|
"op %s prm %s expects one associated var", op->Type(), prm);
|
|
|
|
|
"op %s name %s expects one associated var", op->Type(),
|
|
|
|
|
name);
|
|
|
|
|
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
|
|
|
|
|
return (*ngb_node_map)[var_names[0]];
|
|
|
|
|
} else {
|
|
|
|
@ -42,42 +43,42 @@ static std::shared_ptr<ngraph::Node> GetNode(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<ngraph::Node> GetInputNode(
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string name,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
return GetNode(op, prm, op->Inputs(), ngb_node_map);
|
|
|
|
|
return GetNode(op, name, op->Inputs(), ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<ngraph::Node> GetOutputNode(
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string name,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
return GetNode(op, prm, op->Outputs(), ngb_node_map);
|
|
|
|
|
return GetNode(op, name, op->Outputs(), ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void SetOutputNode(
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string prm,
|
|
|
|
|
const std::shared_ptr<OperatorBase>& op, const std::string name,
|
|
|
|
|
std::shared_ptr<ngraph::Node> node,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto& var_names = op->Outputs().at(prm);
|
|
|
|
|
auto& var_names = op->Outputs().at(name);
|
|
|
|
|
if (var_names.size() == 1) {
|
|
|
|
|
(*ngb_node_map)[var_names[0]] = node;
|
|
|
|
|
} else if (var_names.size() == 0) {
|
|
|
|
|
(*ngb_node_map)[""] = node;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("prm %s has more than 1 var_names.", prm);
|
|
|
|
|
PADDLE_THROW("name %s has more than 1 var_names.", name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool HasOutput(const std::shared_ptr<OperatorBase>& op,
|
|
|
|
|
const std::string prm) {
|
|
|
|
|
const std::string name) {
|
|
|
|
|
auto& outputs = op->Outputs();
|
|
|
|
|
if (outputs.find(prm) == outputs.end()) return false;
|
|
|
|
|
return outputs.at(prm).size() > 0;
|
|
|
|
|
if (outputs.find(name) == outputs.end()) return false;
|
|
|
|
|
return outputs.at(name).size() > 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|