You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/operators/ngraph/ngraph_engine.cc

492 lines
17 KiB

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
namespace paddle {
namespace operators {
// Converts a Paddle DDim into an nGraph Shape.
// Every zero-sized dimension is mapped to 1; all other extents are copied
// unchanged. Extents are read into an int64_t rather than an int so that
// very large dimensions are not silently truncated on the way through.
static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
  ngraph::Shape sp;
  sp.reserve(dims.size());
  for (int i = 0; i < dims.size(); ++i) {
    const int64_t k = dims[i];
    sp.push_back(k == 0 ? 1 : k);
  }
  return sp;
}
// Translation table from Paddle variable data types to the nGraph element
// types used when creating parameters and runtime tensors. Data types that
// are not listed here are rejected when an engine is prepared.
static std::map<framework::proto::VarType::Type, ngraph::element::Type>
    pd2ng_type_map{
        {framework::proto::VarType::BOOL, ngraph::element::boolean},
        {framework::proto::VarType::INT32, ngraph::element::i32},
        {framework::proto::VarType::INT64, ngraph::element::i64},
        {framework::proto::VarType::FP32, ngraph::element::f32},
        {framework::proto::VarType::FP64, ngraph::element::f64},
    };
// Process-wide cache of built nGraph functions, keyed by the string each
// engine assembles from its input shapes, op interval and serialized block.
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
    NgraphEngine::func_cache_{};
// One CPU backend shared by every NgraphEngine; it is created during static
// initialization of this translation unit.
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_(
    ngraph::runtime::Backend::create("CPU"));
// Finds the runs of consecutive nGraph-supported ops in `block`.
//
// Only ops strictly after the first contiguous run of feed ops and strictly
// before the first fetch op that follows them are considered. Each returned
// element is a two-element vector {begin, end} giving the half-open op-index
// range [begin, end) of one maximal run of ops whose type is a key of
// NgraphBridge::NG_NODE_MAP. The result is empty when there is no feed op, no
// fetch op after the feeds, or no supported op between them.
static std::vector<std::vector<int>> NgraphOpIntervals(
    framework::BlockDesc* block) {
  std::vector<std::vector<int>> intervals;
  const auto ops = block->AllOps();
  const int num_ops = ops.size();

  auto type_at = [&ops](int i) { return ops.at(i)->Type(); };
  auto supported = [](const std::string& type) {
    return NgraphBridge::NG_NODE_MAP.find(type) !=
           NgraphBridge::NG_NODE_MAP.end();
  };

  // Advance to the first feed op; without one there is nothing to fuse.
  int first = 0;
  while (first < num_ops && type_at(first) != framework::kFeedOpType) ++first;
  if (first == num_ops) return intervals;
  // Step over the whole run of feed ops.
  while (first < num_ops && type_at(first) == framework::kFeedOpType) ++first;

  // Locate the first fetch op after the feeds.
  int last = first;
  while (last < num_ops && type_at(last) != framework::kFetchOpType) ++last;
  if (last == num_ops || first >= last) return intervals;

  // Walk [first, last) and record every maximal run of supported ops.
  int cursor = first;
  while (cursor < last) {
    if (!supported(type_at(cursor))) {
      ++cursor;
      continue;
    }
    const int run_begin = cursor;
    while (cursor < last && supported(type_at(cursor))) ++cursor;
    intervals.push_back({run_begin, cursor});
  }
  return intervals;
}
static void SubstituteNgraphOp(framework::BlockDesc* block,
std::string block_str,
std::vector<int> interval) {
framework::ProgramDesc program;
block->RemoveOp(interval.at(0), interval.at(1));
auto* ng_op = block->InsertOp(interval.at(0));
ng_op->SetType("ngraph_engine");
ng_op->SetAttr("interval", interval);
ng_op->SetAttr("graph", block_str);
}
// TODO(baojun-nervana): Move EnableNgraph to compile time per PR #15089
// Rewrites every block of `program` in place, replacing each run of
// nGraph-supported ops between the feed and fetch ops with one
// "ngraph_engine" op. Built without PADDLE_WITH_NGRAPH, it only logs a
// warning and leaves the program untouched.
void NgraphEngine::EnableNgraph(const framework::ProgramDesc& program) {
#ifdef PADDLE_WITH_NGRAPH
  VLOG(4) << "use_ngraph=True";
  // TODO(baojun-nervana): Remove the const_cast
  auto& mutable_program = const_cast<framework::ProgramDesc&>(program);
  for (size_t block_id = 0; block_id < mutable_program.Size(); ++block_id) {
    auto* block = mutable_program.MutableBlock(block_id);
    // Serialize before substituting anything: the intervals refer to indices
    // of this original op list.
    const std::string block_str = block->Proto()->SerializeAsString();
    const auto intervals = NgraphOpIntervals(block);
    // Substitute from the back so earlier intervals keep their indices.
    std::for_each(intervals.rbegin(), intervals.rend(),
                  [&](const std::vector<int>& interval) {
                    SubstituteNgraphOp(block, block_str, interval);
                  });
  }
#else
  LOG(WARNING)
      << "'NGRAPH' is not supported, Please re-compile with WITH_NGRAPH option";
#endif
}
// Builds an engine for the ops in the half-open range
// [interval[0], interval[1]) of the serialized block `serialized_graph`:
// parses the block, collects variable/op information (Prepare), determines
// the engine's inputs and outputs (BuildNgIO) and obtains the nGraph function
// (GetNgFunction), possibly from the process-wide cache.
NgraphEngine::NgraphEngine(const framework::Scope& scope,
                           const platform::Place& place,
                           const std::string& serialized_graph,
                           const std::vector<int>& interval)
    : scope_(scope), place_(place) {
  var_in_node_map_ = std::make_shared<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
  var_node_map_ = std::make_shared<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
  // Delimit the interval bounds and the graph: plain digit concatenation lets
  // different (shape prefix, interval) pairs of the same block produce the
  // same cache key, e.g. "35" + {1, 2} and "3" + {5, 12} both give "3512".
  func_cache_key_ = std::to_string(interval.at(0)) + "_" +
                    std::to_string(interval.at(1)) + "_" + serialized_graph;
  framework::proto::BlockDesc bdesc;
  const bool parsed = bdesc.ParseFromString(serialized_graph);
  PADDLE_ENFORCE(parsed,
                 "Failed to parse the serialized block of ngraph_engine op");
  framework::BlockDesc block(nullptr, &bdesc);
  Prepare(block, interval);
  BuildNgIO();
  GetNgFunction();
}
// Collects, from the original (pre-substitution) block:
//  * var_type_map_  - nGraph element type of every SELECTED_ROWS / LOD_TENSOR /
//                     LOD_TENSOR_ARRAY variable (throws on unmapped types);
//  * persistables_  - names of persistable variables of those kinds;
//  * fused_ops_     - operators created for ops in [interval[0], interval[1]);
//  * post_op_inputs_ - inputs of the ops after the interval, up to the first
//                     fetch op;
//  * fetches_       - the "X" input of each op in that run of fetch ops;
//  * ng_op_state_   - FULL_* when the interval starts right after a feed op
//                     and ends right at a fetch op, PARTIAL_* otherwise;
//                     *_TRAIN when any op type in the block contains "_grad",
//                     *_TEST otherwise.
void NgraphEngine::Prepare(const framework::BlockDesc& block,
                           const std::vector<int>& interval) {
  for (auto& var : block.AllVars()) {
    const auto var_type = var->GetType();
    if (!(var_type == framework::proto::VarType::SELECTED_ROWS ||
          var_type == framework::proto::VarType::LOD_TENSOR ||
          var_type == framework::proto::VarType::LOD_TENSOR_ARRAY)) {
      continue;
    }
    const std::string var_name = var->Name();
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    if (var_name != framework::kFeedOpType &&
        var_name != framework::kFetchOpType) {
      // Single lookup; the iterator is reused instead of a second operator[].
      const auto type_it = pd2ng_type_map.find(var->GetDataType());
      if (type_it == pd2ng_type_map.end()) {
        PADDLE_THROW("Data type of var %s not found in pd2ng_type_map",
                     var_name);
      }
      var_type_map_[var_name] = type_it->second;
    }
    if (var->Persistable()) {
      persistables_.insert(var_name);
    }
  }

  const auto ops_desc = block.AllOps();
  const int num_ops = static_cast<int>(ops_desc.size());
  int idx = interval.at(0);
  // Instantiate the ops that are fused into the nGraph function.
  for (; idx < interval.at(1); ++idx) {
    fused_ops_.push_back(framework::OpRegistry::CreateOp(*ops_desc.at(idx)));
  }
  // Inputs of the ops that run after the fused interval (before the fetches).
  // Bounded by num_ops so a block without a trailing fetch op does not run
  // off the end of ops_desc.
  while (idx < num_ops && ops_desc.at(idx)->Type() != framework::kFetchOpType) {
    for (auto& var_name_item : ops_desc.at(idx)->Inputs()) {
      for (auto& var_name : var_name_item.second) {
        post_op_inputs_.insert(var_name);
      }
    }
    ++idx;
  }
  // Targets of the consecutive fetch ops.
  while (idx < num_ops && ops_desc.at(idx)->Type() == framework::kFetchOpType) {
    fetches_.insert(ops_desc.at(idx)->Input("X")[0]);
    ++idx;
  }

  if (ops_desc.at(interval.at(0) - 1)->Type() == framework::kFeedOpType &&
      ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
    ng_op_state_ = OpState::FULL;
  }
  const bool has_grad_op =
      std::any_of(ops_desc.begin(), ops_desc.end(),
                  [](const framework::OpDesc* op_desc) {
                    return op_desc->Type().find("_grad") != std::string::npos;
                  });
  const bool is_full = ng_op_state_ == OpState::FULL;
  if (has_grad_op) {
    ng_op_state_ = is_full ? OpState::FULL_TRAIN : OpState::PARTIAL_TRAIN;
  } else {
    ng_op_state_ = is_full ? OpState::FULL_TEST : OpState::PARTIAL_TEST;
  }
}
// Runs runtime shape inference for `op` against the engine's scope. Then, for
// every input of `op` that exists in the scope as a LoDTensor, is listed in
// var_in_ and has no node yet, it creates an nGraph Parameter shaped from the
// tensor's current dims. That Parameter is registered in both var_node_map_
// and var_in_node_map_.
void NgraphEngine::GetNgInputShape(
    std::shared_ptr<framework::OperatorBase> op) {
  framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
  op->RuntimeInferShape(scope_, place_, ctx);
  for (auto& slot : op->Inputs()) {
    for (auto& name : slot.second) {
      auto* var = scope_.FindVar(name);
      if (var == nullptr || !var->IsType<framework::LoDTensor>()) continue;
      auto* tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
      const auto shape = Ddim2Shape(tensor->dims());
      const bool is_engine_input =
          std::find(var_in_.begin(), var_in_.end(), name) != var_in_.end();
      if (!is_engine_input) continue;
      if (var_node_map_->find(name) != var_node_map_->end()) continue;
      const auto element_type = var_type_map_.at(name);
      auto param =
          std::make_shared<ngraph::op::Parameter>(element_type, shape, true);
      (*var_node_map_)[name] = param;
      (*var_in_node_map_)[name] = param;
    }
  }
}
void NgraphEngine::BuildNgNodes() {
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
for (auto& var_name : var_name_item.second) {
if (var_node_map_->find(var_name) == var_node_map_->end()) {
auto* var = scope_.FindVar(var_name);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto& ddim = tensor_pd->dims();
auto ng_shape = Ddim2Shape(ddim);
auto ng_type = var_type_map_.at(var_name);
auto prm = std::make_shared<ngraph::op::Parameter>(ng_type,
ng_shape, true);
(*var_node_map_)[var_name] = prm;
}
}
}
}
}
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
}
void NgraphEngine::BuildNgIO() {
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
inputs.insert(var_name);
const bool is_output = outputs.find(var_name) != outputs.end();
if (!is_output &&
std::find(var_in_.begin(), var_in_.end(), var_name) ==
var_in_.end()) {
// fill var_in here to keep lhs and rhs order
var_in_.push_back(var_name);
}
}
}
if (op->Type() != "fill_constant") {
GetNgInputShape(op);
}
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
outputs.insert(var_name);
}
}
}
// var_out.clear();
for (auto& op : fused_ops_) {
for (auto& var_name_item : op->Outputs()) {
PADDLE_ENFORCE_LE(var_name_item.second.size(), 1,
"op %s has more than 1 output - Not handling yet",
op->Type());
for (auto& var_name : var_name_item.second) {
switch (ng_op_state_) {
case OpState::PARTIAL_TEST:
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
}
break;
case OpState::FULL_TEST:
if (fetches_.find(var_name) != fetches_.end()) {
var_out_.push_back(var_name);
}
break;
case OpState::PARTIAL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
}
break;
case OpState::FULL_TRAIN:
if (fetches_.find(var_name) != fetches_.end() ||
persistables_.find(var_name) != persistables_.end()) {
var_out_.push_back(var_name);
}
break;
default:
var_out_.push_back(var_name);
}
}
}
}
}
void NgraphEngine::BuildNgFunction() {
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) {
func_outputs.push_back(var_node_map_->at(vo));
}
for (auto& vi : var_in_) {
std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi));
func_inputs.push_back(prm);
}
ngraph_function_ =
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
// Sets ngraph_function_, reusing a cached Function when one was already built
// for the same input shapes and the same interval/graph key.
//
// The input shapes of var_in_ are prepended to func_cache_key_. Each dimension
// is terminated by ',' and each input by ';', so different shape lists such
// as {2, 3} vs {23}, or (1, 23) vs (12, 3), cannot produce the same key. With
// undelimited digit concatenation they collided, and a Function built for
// other shapes was reused.
void NgraphEngine::GetNgFunction() {
  std::string input_shape_str;
  for (const auto& var_name : var_in_) {
    const auto& shape = var_node_map_->at(var_name)->get_shape();
    for (size_t dim : shape) {
      input_shape_str += std::to_string(dim);
      input_shape_str += ',';
    }
    input_shape_str += ';';
  }
  func_cache_key_ = input_shape_str + func_cache_key_;

  const auto cached = func_cache_.find(func_cache_key_);
  if (cached != func_cache_.end()) {
    ngraph_function_ = cached->second;
  } else {
    BuildNgFunction();
    func_cache_[func_cache_key_] = ngraph_function_;
  }
}
// Executes the engine's nGraph function on the tensors found in `scope`.
//
// Every var_in_ / var_out_ variable must exist in `scope` as a LoDTensor with
// a supported element type; otherwise PADDLE_THROW is raised. Each input's
// dims must match the shape its Parameter was built with. Backend tensors are
// created from the Paddle tensors' data pointers, presumably so the backend
// reads inputs from and writes outputs to those buffers in place (verify
// against the nGraph Backend::create_tensor documentation). In test mode,
// persistable inputs are marked not stale.
void NgraphEngine::Run(const framework::Scope& scope,
                       const platform::Place& place) const {
  // Loop-invariant: whether this engine runs in inference (test) mode.
  const bool is_test = ng_op_state_ == OpState::PARTIAL_TEST ||
                       ng_op_state_ == OpState::FULL_TEST;
  std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
  std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
  t_in.reserve(var_in_.size());
  t_out.reserve(var_out_.size());

  for (const auto& vi : var_in_) {
    auto sp = var_node_map_->at(vi)->get_shape();
    std::shared_ptr<ngraph::runtime::Tensor> ti;
    auto* var = scope.FindVar(vi);
    if (var && var->IsType<framework::LoDTensor>()) {
      auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
      PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
                     "Ensure ngraph tensor layout align with paddle tensor");
      auto ng_type = var_type_map_.at(vi);
      if (ng_type == ngraph::element::f32) {
        auto pd_arr = tensor_pd->mutable_data<float>(place);
        ti = backend_->create_tensor(ngraph::element::f32, sp, pd_arr);
      } else if (ng_type == ngraph::element::i32) {
        const int* arr = tensor_pd->data<int>();
        ti = backend_->create_tensor(ngraph::element::i32, sp,
                                     const_cast<int*>(arr));
      } else if (ng_type == ngraph::element::i64) {
        auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
        ti = backend_->create_tensor(ngraph::element::i64, sp, pd_arr);
      } else if (ng_type == ngraph::element::f64) {
        auto pd_arr = tensor_pd->mutable_data<double>(place);
        ti = backend_->create_tensor(ngraph::element::f64, sp, pd_arr);
      } else if (ng_type == ngraph::element::boolean) {
        auto pd_arr = tensor_pd->mutable_data<bool>(place);
        ti = backend_->create_tensor(ngraph::element::boolean, sp, pd_arr);
      } else {
        PADDLE_THROW("Data type not handling for var %s", vi);
      }
    } else {
      PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
    }
    const bool is_persistable = persistables_.find(vi) != persistables_.end();
    if (is_test && is_persistable) {
      ti->set_stale(false);
    }
    t_in.push_back(ti);
  }

  for (const auto& vo : var_out_) {
    auto* var = scope.FindVar(vo);
    std::shared_ptr<ngraph::runtime::Tensor> to;
    if (var && var->IsType<framework::LoDTensor>()) {
      auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
      ngraph::Shape sp = Ddim2Shape(tensor_pd->dims());
      auto ng_type = var_type_map_.at(vo);
      if (ng_type == ngraph::element::f32) {
        auto pd_arr = tensor_pd->mutable_data<float>(place);
        to = backend_->create_tensor(ng_type, sp, pd_arr);
      } else if (ng_type == ngraph::element::i64) {
        auto pd_arr = tensor_pd->mutable_data<int64_t>(place);
        to = backend_->create_tensor(ng_type, sp, pd_arr);
      } else if (ng_type == ngraph::element::i32) {
        auto pd_arr = tensor_pd->mutable_data<int>(place);
        to = backend_->create_tensor(ng_type, sp, pd_arr);
      } else if (ng_type == ngraph::element::f64) {
        auto pd_arr = tensor_pd->mutable_data<double>(place);
        to = backend_->create_tensor(ng_type, sp, pd_arr);
      } else if (ng_type == ngraph::element::boolean) {
        auto pd_arr = tensor_pd->mutable_data<bool>(place);
        to = backend_->create_tensor(ng_type, sp, pd_arr);
      } else {
        PADDLE_THROW("Data type not handled in for var %s", vo);
      }
      t_out.push_back(to);
    } else {
      PADDLE_THROW("Cannot find var or tensor with var name %s", vo);
    }
  }
  backend_->call(backend_->compile(ngraph_function_), t_out, t_in);
}  // NgraphEngine::Run
} // namespace operators
} // namespace paddle