|
|
|
@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
|
|
|
|
|
{proto::VarType::BOOL, ngraph::element::boolean},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
typedef enum { /* nGraph support state on ops */
|
|
|
|
|
FULL_TRAIN, /* Support full ops for train */
|
|
|
|
|
PARTIAL_TRAIN, /* Support partial ops for train */
|
|
|
|
|
FULL_TEST, /* Support full list of ops for test */
|
|
|
|
|
PARTIAL_TEST /* Support partial list of ops for test */
|
|
|
|
|
} op_state;
|
|
|
|
|
|
|
|
|
|
class NgraphOperator {
|
|
|
|
|
public:
|
|
|
|
|
explicit NgraphOperator(const Scope& scope, const platform::Place& place,
|
|
|
|
@ -44,33 +51,29 @@ class NgraphOperator {
|
|
|
|
|
const std::unordered_set<std::string>& persist,
|
|
|
|
|
const std::unordered_set<std::string>& fetches,
|
|
|
|
|
const std::unordered_set<std::string>& post_op_inputs,
|
|
|
|
|
int is_test_or_train)
|
|
|
|
|
: scope(scope),
|
|
|
|
|
place(place),
|
|
|
|
|
fused_ops(ops),
|
|
|
|
|
var_type_map(var_type_map),
|
|
|
|
|
persistables(persist),
|
|
|
|
|
fetches(fetches),
|
|
|
|
|
post_op_inputs(post_op_inputs),
|
|
|
|
|
is_test_or_train(is_test_or_train) {}
|
|
|
|
|
op_state ng_op_state)
|
|
|
|
|
: scope_(scope),
|
|
|
|
|
place_(place),
|
|
|
|
|
fused_ops_(ops),
|
|
|
|
|
var_type_map_(var_type_map),
|
|
|
|
|
persistables_(persist),
|
|
|
|
|
fetches_(fetches),
|
|
|
|
|
post_op_inputs_(post_op_inputs),
|
|
|
|
|
ng_op_state_(ng_op_state) {}
|
|
|
|
|
|
|
|
|
|
void Run(const Scope& scope, const platform::Place& place) const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
|
|
|
|
|
func_cache;
|
|
|
|
|
const Scope& scope;
|
|
|
|
|
const platform::Place& place;
|
|
|
|
|
std::vector<std::shared_ptr<OperatorBase>> fused_ops;
|
|
|
|
|
std::unordered_map<std::string, ngraph::element::Type> var_type_map;
|
|
|
|
|
std::unordered_set<std::string> persistables;
|
|
|
|
|
std::unordered_set<std::string> fetches;
|
|
|
|
|
std::unordered_set<std::string> post_op_inputs;
|
|
|
|
|
// 0 = default; 1 = (is_test && not is_complete)
|
|
|
|
|
// 2 = (is_test && is_complete)
|
|
|
|
|
// 3 = (is_training && not is_complete)
|
|
|
|
|
// 4 = (is_training && is_complete)
|
|
|
|
|
int is_test_or_train;
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
const platform::Place& place_;
|
|
|
|
|
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
|
|
|
|
|
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
|
|
|
|
|
std::unordered_set<std::string> persistables_;
|
|
|
|
|
std::unordered_set<std::string> fetches_;
|
|
|
|
|
std::unordered_set<std::string> post_op_inputs_;
|
|
|
|
|
op_state ng_op_state_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
|
|
|
|
@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
|
|
|
|
|
const ProgramDesc& prog, size_t block_id,
|
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
|
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
|
|
|
|
|
const std::string& type = "fused_op", const VariableNameMap& inputs = {},
|
|
|
|
|
const VariableNameMap& outputs = {}, const AttributeMap& attrs = {})
|
|
|
|
|
const std::string& type, const VariableNameMap& inputs,
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs), pdesc(prog), block(block_id) {
|
|
|
|
|
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = start;
|
|
|
|
|
it != end; ++it) {
|
|
|
|
|
fused_ops.push_back(std::move(*it));
|
|
|
|
|
fused_ops_.push_back(std::move(*it));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (std::vector<std::unique_ptr<OperatorBase>>::iterator it = end;
|
|
|
|
|
(*it)->Type() != kFetchOpType; ++it) {
|
|
|
|
|
for (auto& var_name_item : (*it)->Inputs()) {
|
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
|
post_op_inputs.insert(var_name);
|
|
|
|
|
post_op_inputs_.insert(var_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
|
|
|
|
|
is_complete = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
process();
|
|
|
|
|
Process();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FusedOperator::process() {
|
|
|
|
|
auto& bdesc = pdesc.Block(block);
|
|
|
|
|
void FusedOperator::Process() {
|
|
|
|
|
auto& bdesc = pdesc_.Block(block_);
|
|
|
|
|
for (auto& var : bdesc.AllVars()) {
|
|
|
|
|
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
|
|
|
|
|
var->GetType() == proto::VarType::LOD_TENSOR ||
|
|
|
|
@ -175,39 +178,40 @@ void FusedOperator::process() {
|
|
|
|
|
PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
|
|
|
|
|
var_name);
|
|
|
|
|
}
|
|
|
|
|
var_type_map[var_name] = pd2ng_type_map[pd_type];
|
|
|
|
|
var_type_map_[var_name] = pd2ng_type_map[pd_type];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (var->Persistable()) {
|
|
|
|
|
persistables.insert(var->Name());
|
|
|
|
|
persistables_.insert(var->Name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto* op : bdesc.AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
|
fetches.insert(fetch_target_name);
|
|
|
|
|
fetches_.insert(fetch_target_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FusedOperator::RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
int is_test_or_train = 1;
|
|
|
|
|
auto& bdesc = pdesc.Block(block);
|
|
|
|
|
op_state ng_op_state = PARTIAL_TEST;
|
|
|
|
|
auto& bdesc = pdesc_.Block(block_);
|
|
|
|
|
for (auto* op : bdesc.AllOps()) {
|
|
|
|
|
if (op->Type().find("_grad") != std::string::npos) {
|
|
|
|
|
is_test_or_train = 3;
|
|
|
|
|
ng_op_state = PARTIAL_TRAIN;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_complete) {
|
|
|
|
|
is_test_or_train = is_test_or_train == 1 ? 2 : 4;
|
|
|
|
|
if (is_full) {
|
|
|
|
|
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NgraphOperator ngraph_op(scope, place, fused_ops, var_type_map, persistables,
|
|
|
|
|
fetches, post_op_inputs, is_test_or_train);
|
|
|
|
|
NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
|
|
|
|
|
persistables_, fetches_, post_op_inputs_,
|
|
|
|
|
ng_op_state);
|
|
|
|
|
ngraph_op.Run(scope, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|