|
|
@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License. */
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_NGRAPH
|
|
|
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
#include <algorithm>
|
|
|
@ -58,16 +57,16 @@ typedef enum { /* nGraph support state on ops */
|
|
|
|
} op_state;
|
|
|
|
} op_state;
|
|
|
|
|
|
|
|
|
|
|
|
// perform graph build through bridge and execute computation
|
|
|
|
// perform graph build through bridge and execute computation
|
|
|
|
class NgraphOperator {
|
|
|
|
class NgraphEngine {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit NgraphOperator(const Scope& scope, const platform::Place& place,
|
|
|
|
explicit NgraphEngine(const Scope& scope, const platform::Place& place,
|
|
|
|
const std::vector<std::shared_ptr<OperatorBase>>& ops,
|
|
|
|
const std::vector<std::shared_ptr<OperatorBase>>& ops,
|
|
|
|
const std::unordered_map<
|
|
|
|
const std::unordered_map<
|
|
|
|
std::string, ngraph::element::Type>& var_type_map,
|
|
|
|
std::string, ngraph::element::Type>& var_type_map,
|
|
|
|
const std::unordered_set<std::string>& persist,
|
|
|
|
const std::unordered_set<std::string>& persist,
|
|
|
|
const std::unordered_set<std::string>& fetches,
|
|
|
|
const std::unordered_set<std::string>& fetches,
|
|
|
|
const std::unordered_set<std::string>& post_op_inputs,
|
|
|
|
const std::unordered_set<std::string>& post_op_inputs,
|
|
|
|
op_state ng_op_state)
|
|
|
|
op_state ng_op_state)
|
|
|
|
: scope_(scope),
|
|
|
|
: scope_(scope),
|
|
|
|
place_(place),
|
|
|
|
place_(place),
|
|
|
|
fused_ops_(ops),
|
|
|
|
fused_ops_(ops),
|
|
|
@ -132,7 +131,7 @@ class NgraphOperator {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
|
|
|
|
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
|
|
|
|
FusedOperator::FusedOpIntervals(
|
|
|
|
NgraphOperator::NgraphOpIntervals(
|
|
|
|
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
|
|
|
|
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops) {
|
|
|
|
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
|
|
|
|
std::vector<std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
|
|
|
|
intervals;
|
|
|
|
intervals;
|
|
|
@ -185,7 +184,7 @@ FusedOperator::FusedOpIntervals(
|
|
|
|
return intervals;
|
|
|
|
return intervals;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
FusedOperator::FusedOperator(
|
|
|
|
NgraphOperator::NgraphOperator(
|
|
|
|
const ProgramDesc& prog, size_t block_id,
|
|
|
|
const ProgramDesc& prog, size_t block_id,
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
|
|
|
@ -215,7 +214,7 @@ FusedOperator::FusedOperator(
|
|
|
|
Process();
|
|
|
|
Process();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void FusedOperator::Process() {
|
|
|
|
void NgraphOperator::Process() {
|
|
|
|
auto& bdesc = pdesc_.Block(block_);
|
|
|
|
auto& bdesc = pdesc_.Block(block_);
|
|
|
|
for (auto& var : bdesc.AllVars()) {
|
|
|
|
for (auto& var : bdesc.AllVars()) {
|
|
|
|
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
|
|
|
|
if (!(var->GetType() == proto::VarType::SELECTED_ROWS ||
|
|
|
@ -251,8 +250,8 @@ void FusedOperator::Process() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void FusedOperator::RunImpl(const Scope& scope,
|
|
|
|
void NgraphOperator::RunImpl(const Scope& scope,
|
|
|
|
const platform::Place& place) const {
|
|
|
|
const platform::Place& place) const {
|
|
|
|
op_state ng_op_state = PARTIAL_TEST;
|
|
|
|
op_state ng_op_state = PARTIAL_TEST;
|
|
|
|
auto& bdesc = pdesc_.Block(block_);
|
|
|
|
auto& bdesc = pdesc_.Block(block_);
|
|
|
|
for (auto* op : bdesc.AllOps()) {
|
|
|
|
for (auto* op : bdesc.AllOps()) {
|
|
|
@ -266,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope,
|
|
|
|
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
|
|
|
|
ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_,
|
|
|
|
NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_,
|
|
|
|
persistables_, fetches_, post_op_inputs_,
|
|
|
|
persistables_, fetches_, post_op_inputs_,
|
|
|
|
ng_op_state);
|
|
|
|
ng_op_state);
|
|
|
|
ngraph_op.Run(scope, place);
|
|
|
|
ngraph_engine.Run(scope, place);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
|
|
|
|
NgraphOperator::func_cache_ = {};
|
|
|
|
NgraphEngine::func_cache_ = {};
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ =
|
|
|
|
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
|
|
|
|
ngraph::runtime::Backend::create("CPU");
|
|
|
|
ngraph::runtime::Backend::create("CPU");
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
|
|
|
|
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
|
|
|
|
op->RuntimeInferShape(scope_, place_);
|
|
|
|
op->RuntimeInferShape(scope_, place_);
|
|
|
|
for (auto& var_name_item : op->Inputs()) {
|
|
|
|
for (auto& var_name_item : op->Inputs()) {
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
@ -301,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::BuildNgNodes() {
|
|
|
|
void NgraphEngine::BuildNgNodes() {
|
|
|
|
for (auto& var_name : var_out_) {
|
|
|
|
for (auto& var_name : var_out_) {
|
|
|
|
if (var_node_map_->find(var_name) == var_node_map_->end()) {
|
|
|
|
if (var_node_map_->find(var_name) == var_node_map_->end()) {
|
|
|
|
auto* var = scope_.FindVar(var_name);
|
|
|
|
auto* var = scope_.FindVar(var_name);
|
|
|
@ -323,7 +322,7 @@ void NgraphOperator::BuildNgNodes() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::BuildNgIO() {
|
|
|
|
void NgraphEngine::BuildNgIO() {
|
|
|
|
std::unordered_set<std::string> inputs;
|
|
|
|
std::unordered_set<std::string> inputs;
|
|
|
|
std::unordered_set<std::string> outputs;
|
|
|
|
std::unordered_set<std::string> outputs;
|
|
|
|
|
|
|
|
|
|
|
@ -395,7 +394,7 @@ void NgraphOperator::BuildNgIO() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::BuildNgFunction() {
|
|
|
|
void NgraphEngine::BuildNgFunction() {
|
|
|
|
BuildNgNodes();
|
|
|
|
BuildNgNodes();
|
|
|
|
ngraph_function_ = nullptr;
|
|
|
|
ngraph_function_ = nullptr;
|
|
|
|
ngraph::NodeVector func_outputs;
|
|
|
|
ngraph::NodeVector func_outputs;
|
|
|
@ -416,7 +415,7 @@ void NgraphOperator::BuildNgFunction() {
|
|
|
|
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
|
|
|
|
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
|
|
|
|
std::shared_ptr<std::string> NgraphEngine::GetCacheKey() {
|
|
|
|
auto cache_key = std::make_shared<std::string>("");
|
|
|
|
auto cache_key = std::make_shared<std::string>("");
|
|
|
|
*cache_key += std::to_string(fused_ops_.size());
|
|
|
|
*cache_key += std::to_string(fused_ops_.size());
|
|
|
|
for (auto& op : fused_ops_) {
|
|
|
|
for (auto& op : fused_ops_) {
|
|
|
@ -444,7 +443,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
|
|
|
|
return cache_key;
|
|
|
|
return cache_key;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::GetNgFunction() {
|
|
|
|
void NgraphEngine::GetNgFunction() {
|
|
|
|
bool cache_on = true;
|
|
|
|
bool cache_on = true;
|
|
|
|
if (cache_on) {
|
|
|
|
if (cache_on) {
|
|
|
|
std::string cache_key_val = *GetCacheKey();
|
|
|
|
std::string cache_key_val = *GetCacheKey();
|
|
|
@ -459,8 +458,7 @@ void NgraphOperator::GetNgFunction() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphOperator::Run(const Scope& scope,
|
|
|
|
void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
|
|
|
|
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in;
|
|
|
|
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
|
|
|
|
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out;
|
|
|
|
|
|
|
|
|
|
|
@ -545,7 +543,6 @@ void NgraphOperator::Run(const Scope& scope,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
backend_->call(ngraph_function_, t_out, t_in);
|
|
|
|
backend_->call(ngraph_function_, t_out, t_in);
|
|
|
|
} // NgraphOperator::RunImpl
|
|
|
|
} // NgraphEngine::RunImpl
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
#endif
|
|
|
|
|
|
|
|