Adding ngraph_engine_op (#14948)
* enable ngraph_engine_op test=develop * merge develop test=develop * avoid const_cast test=develop * rm ngraph_operator test=develop * Added TODO to move EnableNgraph test=develop * Add TODO to remove const_cast test=developinference-pre-release-gpu
parent
7166b52a6e
commit
efce25673c
@ -1,64 +0,0 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/attribute.h"
|
||||
#include "paddle/fluid/framework/op_info.h"
|
||||
#include "paddle/fluid/framework/op_kernel_type.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/variant.h"
|
||||
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class NgraphOperator : public OperatorBase {
|
||||
public:
|
||||
static std::vector<
|
||||
std::vector<std::vector<std::unique_ptr<OperatorBase>>::iterator>>
|
||||
NgraphOpIntervals(
|
||||
std::vector<std::unique_ptr<paddle::framework::OperatorBase>>* ops);
|
||||
|
||||
explicit NgraphOperator(
|
||||
const ProgramDesc& prog, size_t block_id,
|
||||
std::vector<std::unique_ptr<OperatorBase>>::iterator start,
|
||||
std::vector<std::unique_ptr<OperatorBase>>::iterator end,
|
||||
const std::string& type = "fused_op", const VariableNameMap& inputs = {},
|
||||
const VariableNameMap& outputs = {}, const AttributeMap& attrs = {});
|
||||
|
||||
void RunImpl(const Scope& scope, const platform::Place& place) const final;
|
||||
|
||||
private:
|
||||
const ProgramDesc pdesc_;
|
||||
size_t block_;
|
||||
std::vector<std::shared_ptr<OperatorBase>> fused_ops_;
|
||||
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
|
||||
std::unordered_set<std::string> persistables_;
|
||||
std::unordered_set<std::string> fetches_;
|
||||
std::unordered_set<std::string> post_op_inputs_;
|
||||
bool is_full_ = false;
|
||||
|
||||
void Process();
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,4 @@
|
||||
if(WITH_NGRAPH)
|
||||
cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto)
|
||||
op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context)
|
||||
endif()
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,93 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
enum class OpState { /* nGraph support state on ops */
|
||||
FULL_TRAIN, /* Support full ops for train */
|
||||
PARTIAL_TRAIN, /* Support partial ops for train */
|
||||
FULL_TEST, /* Support full list of ops for test */
|
||||
PARTIAL_TEST, /* Support partial list of ops for test */
|
||||
FULL, /* All ops supported from feed to fetch */
|
||||
UNKNOWN /* Output all for debug purpose */
|
||||
};
|
||||
|
||||
// perform graph build through bridge and execute computation
|
||||
class NgraphEngine {
|
||||
public:
|
||||
explicit NgraphEngine(const framework::Scope& scope,
|
||||
const platform::Place& place,
|
||||
const std::string& serialized_graph,
|
||||
const std::vector<int>& interval);
|
||||
|
||||
void Run(const framework::Scope& scope, const platform::Place& place) const;
|
||||
|
||||
static void EnableNgraph(const framework::ProgramDesc& program);
|
||||
|
||||
private:
|
||||
static std::unordered_map<std::string, std::shared_ptr<ngraph::Function>>
|
||||
func_cache_;
|
||||
const framework::Scope& scope_;
|
||||
const platform::Place& place_;
|
||||
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
|
||||
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
|
||||
std::unordered_set<std::string> persistables_;
|
||||
std::unordered_set<std::string> fetches_;
|
||||
std::unordered_set<std::string> post_op_inputs_;
|
||||
OpState ng_op_state_ = OpState::UNKNOWN;
|
||||
std::string func_cache_key_;
|
||||
|
||||
// ngraph backend eg. CPU
|
||||
static std::shared_ptr<ngraph::runtime::Backend> backend_;
|
||||
// ngraph function to call and execute
|
||||
std::shared_ptr<ngraph::Function> ngraph_function_;
|
||||
// var_name of inputs
|
||||
std::vector<std::string> var_in_;
|
||||
// var_name of outputs from fetch in order
|
||||
std::vector<std::string> var_out_;
|
||||
// map input vars to nodes
|
||||
std::shared_ptr<
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
||||
var_in_node_map_;
|
||||
// map each var name with a ngraph node
|
||||
std::shared_ptr<
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
||||
var_node_map_;
|
||||
// prepare info for nraph engine
|
||||
void Prepare(const framework::BlockDesc& block,
|
||||
const std::vector<int>& interval);
|
||||
// get ngraph input and define ngraph input parameters
|
||||
void GetNgInputShape(std::shared_ptr<framework::OperatorBase> op);
|
||||
// Call ngraph bridge to map ops
|
||||
void BuildNgNodes();
|
||||
// get the ngraph input and output var list
|
||||
void BuildNgIO();
|
||||
// build ngraph function call
|
||||
void BuildNgFunction();
|
||||
// Check cache for ngraph function or otherwise build the function
|
||||
void GetNgFunction();
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/block_desc.h"
|
||||
#include "paddle/fluid/framework/op_desc.h"
|
||||
#include "paddle/fluid/framework/op_info.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/ngraph/ngraph_engine_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Xs", "A list of inputs.").AsDispensable();
|
||||
AddOutput("Ys", "A list of outputs").AsDispensable();
|
||||
AddAttr<std::string>("graph", "the graph.");
|
||||
AddAttr<std::vector<int>>("interval", "op interval supported by ngraph");
|
||||
AddComment("ngraph engine operator.");
|
||||
}
|
||||
};
|
||||
|
||||
class NgraphEngineInferVarType : public framework::VarTypeInference {
|
||||
public:
|
||||
void operator()(const framework::OpDesc &op_desc,
|
||||
framework::BlockDesc *block) const override {}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(ngraph_engine, ops::NgraphEngineOp, ops::NgraphEngineOpMaker,
|
||||
ops::NgraphEngineOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
ngraph_engine,
|
||||
ops::NgraphEngineKernel<paddle::platform::CPUDeviceContext, float>);
|
@ -0,0 +1,58 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class NgraphEngineOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {}
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
framework::OpKernelType kt = framework::OpKernelType(
|
||||
framework::proto::VarType::FP32, ctx.GetPlace());
|
||||
return kt;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class NgraphEngineKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto& scope = ctx.scope();
|
||||
auto place = ctx.GetPlace();
|
||||
std::string serialized_graph = ctx.Attr<std::string>("graph");
|
||||
auto interval = ctx.Attr<std::vector<int>>("interval");
|
||||
|
||||
NgraphEngine ngraph_engine(scope, place, serialized_graph, interval);
|
||||
ngraph_engine.Run(scope, place);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue