!14658 tensor-rt op converter

From: @wilfchen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
pull/14658/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f3f43d4cc7

@ -42,7 +42,7 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
// Skip control node
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
return GetRealOutputRecursive(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
}
// Bypass TupleGetItem
@ -57,11 +57,11 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
auto make_tuple = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
return GetRealOutputRecursive(real_input, 0, inputs);
return GetRealOutputRecursively(real_input, 0, inputs);
}
// Skip TupleGetItem.
return GetRealOutputRecursive(input, index, inputs);
return GetRealOutputRecursively(input, index, inputs);
}
// Flatten MakeTuple inputs.
@ -71,7 +71,7 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
GetRealOutputRecursive(input_node, 0, inputs);
GetRealOutputRecursively(input_node, 0, inputs);
}
return;
}

@ -67,12 +67,6 @@ class TrtOpRegister {
public:
TrtOpRegister(const std::string &op_name, ConvertFunc func) { TrtOpFactory::GetInstance().Register(op_name, func); }
};
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context); \
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context)
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_

Loading…
Cancel
Save