From 975e65054affa9dfe74023025781f011b0528365 Mon Sep 17 00:00:00 2001 From: yanghaoran Date: Wed, 2 Dec 2020 15:11:21 +0800 Subject: [PATCH] update optiling headers --- .../fwkacllib/inc/register/op_tiling.h | 21 ++- .../inc/register/op_tiling_registry.h | 130 ++++++------------ 2 files changed, 63 insertions(+), 88 deletions(-) diff --git a/third_party/fwkacllib/inc/register/op_tiling.h b/third_party/fwkacllib/inc/register/op_tiling.h index e94ad556..f720afec 100644 --- a/third_party/fwkacllib/inc/register/op_tiling.h +++ b/third_party/fwkacllib/inc/register/op_tiling.h @@ -20,12 +20,29 @@ #include "graph/debug/ge_attr_define.h" #include "graph/node.h" #include "register/op_tiling_registry.h" +#include namespace optiling { +#define REGISTER_OP_TILING_FUNC(optype, opfunc) REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, __COUNTER__) +#define REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, counter) \ + REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) +#define REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) \ + static OpTilingInterf g_##optype##TilingInterf##counter(#optype, opfunc) + +using OpTilingFuncOld = + std::function; + +class FMK_FUNC_HOST_VISIBILITY OpTilingInterf { + public: + OpTilingInterf(std::string op_type, OpTilingFuncOld func); + ~OpTilingInterf() = default; + static std::string OpTilingUuid; +}; + extern "C" ge::graphStatus OpParaCalculate(const ge::Node &node, OpRunInfo &run_info); extern "C" ge::graphStatus OpAtomicCalculate(const ge::Node &node, OpRunInfo &run_info); -} +} // namespace optiling -#endif // INC_REGISTER_OP_TILING_H_ +#endif // INC_REGISTER_OP_TILING_H_ diff --git a/third_party/fwkacllib/inc/register/op_tiling_registry.h b/third_party/fwkacllib/inc/register/op_tiling_registry.h index dbc00fab..c6efabda 100644 --- a/third_party/fwkacllib/inc/register/op_tiling_registry.h +++ b/third_party/fwkacllib/inc/register/op_tiling_registry.h @@ -19,135 +19,93 @@ #include #include -#include #include #include #include #include "external/register/register_types.h" #include "external/graph/tensor.h" -#define REGISTER_OP_TILING_FUNC(optype, opfunc) \ - REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, __COUNTER__) +#define REGISTER_OP_TILING(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, __COUNTER__) -#define REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, counter) \ - REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) +#define REGISTER_OP_TILING_FUNC_NEW(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, __COUNTER__) -#define REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) \ - static OpTilingInterf g_##optype##TilingInterf##counter(#optype, opfunc) +#define REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, counter) REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) -#define REGISTER_OP_TILING_FUNC_NEW(optype, opfunc) \ - REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, __COUNTER__) - -#define REGISTER_OP_TILING(optype, opfunc) \ - REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, __COUNTER__) - -#define REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, counter) \ - REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) - -#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) \ - static OpTilingRegistryInterf g_##optype##TilingRegistryInterf##counter(#optype, opfunc) +#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) \ + static OpTilingRegistryInterf g_##optype##TilingRegistryInterf##counter(#optype, opfunc) namespace optiling { enum TensorArgType { - TA_NONE, - TA_SINGLE, - TA_LIST, + TA_NONE, + TA_SINGLE, + TA_LIST, }; using ByteBuffer = std::stringstream; struct TeOpTensor { - std::vector shape; - std::vector ori_shape; - std::string format; - std::string ori_format; - std::string dtype; - std::map attrs; + std::vector shape; + std::vector ori_shape; + std::string format; + std::string ori_format; + std::string dtype; + std::map attrs; }; - struct TeOpTensorArg { - TensorArgType arg_type; - std::vector tensor; + TensorArgType arg_type; + std::vector tensor; }; struct OpRunInfo { - uint32_t block_dim; - std::vector workspaces; - ByteBuffer tiling_data; - bool clear_atomic; + uint32_t block_dim; + std::vector workspaces; + ByteBuffer tiling_data; + bool clear_atomic; }; - using TeOpAttrArgs = std::vector; -using TeConstTensorData = std::tuple; +using TeConstTensorData = std::tuple; struct TeOpParas { - std::vector inputs; - std::vector outputs; - std::map const_inputs; - TeOpAttrArgs attrs; - std::string op_type; -}; - - -using OpTilingFunc = std::function; - -using OpTilingFuncPtr = bool(*)(const std::string&, const TeOpParas&, const nlohmann::json& , OpRunInfo&); - -class FMK_FUNC_HOST_VISIBILITY OpTilingInterf -{ -public: - OpTilingInterf(std::string op_type, OpTilingFunc func); - ~OpTilingInterf() = default; - static std::map &RegisteredOpInterf(); - static std::string OpTilingUuid; + std::vector inputs; + std::vector outputs; + std::map const_inputs; + TeOpAttrArgs attrs; + std::string op_type; }; struct OpCompileInfo { - std::string str; - std::string key; + std::string str; + std::string key; }; -using OpTilingFuncNew = std::function; +using OpTilingFunc = std::function; -using OpTilingFuncPtrNew = bool(*)(const TeOpParas&, const OpCompileInfo& , OpRunInfo&); +using OpTilingFuncPtr = bool (*)(const TeOpParas &, const OpCompileInfo &, OpRunInfo &); class FMK_FUNC_HOST_VISIBILITY OpTilingRegistryInterf { -public: - OpTilingRegistryInterf(std::string op_type, OpTilingFuncNew func); - ~OpTilingRegistryInterf() = default; - static std::map &RegisteredOpInterf(); + public: + OpTilingRegistryInterf(std::string op_type, OpTilingFunc func); + ~OpTilingRegistryInterf() = default; + static std::map &RegisteredOpInterf(); }; template -ByteBuffer& ByteBufferPut(ByteBuffer &buf, const T &value) -{ - buf.write(reinterpret_cast(&value), sizeof(value)); - buf.flush(); - return buf; +ByteBuffer &ByteBufferPut(ByteBuffer &buf, const T &value) { + buf.write(reinterpret_cast(&value), sizeof(value)); + buf.flush(); + return buf; } template -ByteBuffer& ByteBufferGet(ByteBuffer &buf, T &value) -{ - buf.read(reinterpret_cast(&value), sizeof(value)); - return buf; +ByteBuffer &ByteBufferGet(ByteBuffer &buf, T &value) { + buf.read(reinterpret_cast(&value), sizeof(value)); + return buf; } -inline size_t ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len) -{ - size_t nread = 0; - size_t rn = 0; - do { - rn = buf.readsome(dest + nread, dest_len - nread); - nread += rn; - } while (rn > 0 && dest_len > nread); - - return nread; -} -} +size_t ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len); +} // namespace optiling -#endif // INC_REGISTER_OP_TILING_REGISTRY_H_ +#endif // INC_REGISTER_OP_TILING_REGISTRY_H_