!14291 GPU Trt operator factory and register
From: @wilfchen Reviewed-by: @limingqi107,@cristoval Signed-off-by: @cristovalpull/14291/MERGE
commit
b6605f5939
@ -0,0 +1,61 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
|
||||||
|
|
||||||
|
#include <NvInfer.h>
|
||||||
|
|
||||||
|
namespace mindspore::opt {
|
||||||
|
// Tensor-RT layer inputs include weight or tensor.
|
||||||
|
// Tensor: Anf-graph inputs or feature map which values change during inference.
|
||||||
|
// Weight: Anf-graph inputs or value node which remain unchanged during inference.
|
||||||
|
class LayerInput {
|
||||||
|
public:
|
||||||
|
LayerInput() : type_(InputType::kUnknown), weight_(), tensor_(nullptr) {}
|
||||||
|
explicit LayerInput(nvinfer1::Weights &w) : type_(InputType::kWeight), weight_(w), tensor_(nullptr) {}
|
||||||
|
explicit LayerInput(nvinfer1::ITensor *t) : type_(InputType::kTensor), weight_(), tensor_(t) {}
|
||||||
|
|
||||||
|
bool IsTensor() const { return type_ == InputType::kTensor; }
|
||||||
|
bool IsWeight() const { return type_ == InputType::kWeight; }
|
||||||
|
|
||||||
|
nvinfer1::Weights *weight() {
|
||||||
|
if (!IsWeight()) {
|
||||||
|
MS_LOG(WARNING) << "weight not initialized.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &weight_;
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::ITensor *tensor() const {
|
||||||
|
if (!IsTensor()) {
|
||||||
|
MS_LOG(WARNING) << "tensor not initialized.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensor_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
enum class InputType : char { kUnknown = 0, kTensor, kWeight };
|
||||||
|
InputType type_;
|
||||||
|
// Keep the copy rather than point cause Weights created as a local variable.
|
||||||
|
nvinfer1::Weights weight_;
|
||||||
|
// Keep the point as ITensor created/held by nvinfer1::INetworkDefinition.
|
||||||
|
nvinfer1::ITensor *tensor_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore::opt
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
|
@ -0,0 +1,78 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "base/base.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class LayerInput;
|
||||||
|
class TrtConverterHelper;
|
||||||
|
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
|
||||||
|
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterHelper>)>;
|
||||||
|
|
||||||
|
class TrtOpFactory {
|
||||||
|
public:
|
||||||
|
static TrtOpFactory &GetInstance() {
|
||||||
|
static TrtOpFactory instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Register(const std::string &op_name, const ConvertFunc &func) {
|
||||||
|
if (op_convert_map_.count(op_name)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Operator: " << op_name << " re-registered.";
|
||||||
|
}
|
||||||
|
op_convert_map_.insert(std::make_pair(op_name, func));
|
||||||
|
}
|
||||||
|
|
||||||
|
ConvertFunc GetConvertFunc(const std::string &op_name) const {
|
||||||
|
auto iter = op_convert_map_.find(op_name);
|
||||||
|
if (iter == op_convert_map_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Operator: " << op_name << " not support.";
|
||||||
|
}
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TrtOpFactory() = default;
|
||||||
|
~TrtOpFactory() = default;
|
||||||
|
DISABLE_COPY_AND_ASSIGN(TrtOpFactory)
|
||||||
|
|
||||||
|
std::unordered_map<std::string, ConvertFunc> op_convert_map_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TrtOpRegister {
|
||||||
|
public:
|
||||||
|
TrtOpRegister(const std::string &op_name, ConvertFunc func) { TrtOpFactory::GetInstance().Register(op_name, func); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
|
||||||
|
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
|
||||||
|
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context); \
|
||||||
|
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
|
||||||
|
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context)
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
|
Loading…
Reference in new issue