!14584 gpu trt converter

From: @wilfchen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by:
pull/14584/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 675661726b

@ -68,7 +68,7 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {
return true;
}
TrtKernel::ReleaseResource() {
void TrtKernel::ReleaseResource() {
// Make sure destroy trt object before TrtLoader destruct.
context_.reset();
engine_.reset();

@ -50,6 +50,33 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u_input) {
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdam and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
const auto &n = node->cast<CNodePtr>()->input(2);
MS_EXCEPTION_IF_NULL(n);
const auto &fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam);
break;
}
}
return adam;
}
} // namespace
const BaseRef AdamFusion::DefinePattern() const {
@ -118,51 +145,19 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
// Fused into a FusedAdam operator.
auto prim = std::make_shared<Primitive>(kFusedAdamName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
beta1_input,
one_sub_beta1_input,
beta2_input,
one_sub_beta2_input,
eps_input,
lr_input,
param,
m_input,
v_input,
gradient_input};
auto prim_value = NewValueNode(prim);
std::vector<AnfNodePtr> inputs = {
prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
m_input, v_input, gradient_input};
auto adam = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
adam->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(adam);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdam and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
auto n = node->cast<CNodePtr>()->input(2);
auto fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam);
break;
}
}
return adam;
return RelpaceOutputEdge(node, adam, u_input);
}
} // namespace opt
} // namespace mindspore

@ -50,6 +50,34 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, AnfNodePtr u_input) {
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdamWeightDecay and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
const auto &n = node->cast<CNodePtr>()->input(2);
MS_EXCEPTION_IF_NULL(n);
const auto &fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
break;
}
}
return adam_weight_decay;
}
} // namespace
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
@ -122,18 +150,10 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
// Fused into a FusedAdamWeightDecay operator.
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
beta1_input,
one_sub_beta1_input,
beta2_input,
one_sub_beta2_input,
eps_input,
lr_input,
param,
m_input,
v_input,
gradient_input,
weight_decay_input};
auto prim_value = NewValueNode(prim);
std::vector<AnfNodePtr> inputs = {
prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
m_input, v_input, gradient_input, weight_decay_input};
auto adam_weight_decay = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam_weight_decay);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
@ -143,31 +163,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdamWeightDecay and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
auto n = node->cast<CNodePtr>()->input(2);
auto fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
break;
}
}
return adam_weight_decay;
return ReplaceOutputEdge(node, adam_weight_decay, u_input);
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
#include <unordered_map>
#include <vector>
#include <string>
#include <memory>
#include <NvInfer.h>
#include "base/base.h"
#include "ir/anf.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/trt_pass/layer_input.h"
namespace mindspore {
namespace opt {
class TrtConverterContext : public std::enable_shared_from_this<TrtConverterContext> {
public:
explicit TrtConverterContext(FuncGraphPtr fg)
: func_graph_(fg),
batch_size_(1),
workspace_size_(4UL << 30),
builder_(nullptr),
network_(nullptr),
config_(nullptr),
engine_(nullptr) {}
~TrtConverterContext() = default;
bool Init();
// Parser KernelGraph to trt graph
bool Parser();
// Serialize trt models.
bool Serialize(std::string *model);
// Get trt graph inputs without weights. The inputs keep same order as binding name.
std::vector<AnfNodePtr> GetGraphInputs();
// Get trt graph outputs. All outputs are flatten to vector with concret shape.
std::vector<session::KernelWithIndex> GetGraphOutputs();
// Store trt layer outputs to the cache.
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &inputs);
// Get trt layer inputs from the cache.
bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs);
// Create and keep temporary weight, as constant folding demanding new weight excluded in graph,
// which should release until building finish.
std::shared_ptr<tensor::Tensor> CreateTempWeight(const TypeId &type, const std::vector<size_t> &shape);
std::shared_ptr<nvinfer1::INetworkDefinition> network() const { return network_; }
private:
bool InitInputTable();
bool InitValueNodeTable();
FuncGraphPtr func_graph_;
uint32_t batch_size_;
size_t workspace_size_;
std::shared_ptr<nvinfer1::IBuilder> builder_;
std::shared_ptr<nvinfer1::INetworkDefinition> network_;
std::shared_ptr<nvinfer1::IBuilderConfig> config_;
std::shared_ptr<nvinfer1::ICudaEngine> engine_;
// Cache (AnfNode + output_index : ILayer output).
std::unordered_map<AnfNodePtr, std::unordered_map<size_t, LayerInput>> output_map_;
std::vector<std::shared_ptr<tensor::Tensor>> temp_weights_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_HELPER_H_

@ -29,9 +29,9 @@
namespace mindspore {
namespace opt {
class LayerInput;
class TrtConverterHelper;
class TrtConverterContext;
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterHelper>)>;
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
class TrtOpFactory {
public:
@ -69,10 +69,10 @@ class TrtOpRegister {
};
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context); \
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context)
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context); \
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context)
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_

Loading…
Cancel
Save