|
|
@ -15,17 +15,24 @@
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
#include "nnacl/infer/infer_register.h"
|
|
|
|
#include "nnacl/infer/infer_register.h"
|
|
|
|
|
|
|
|
|
|
|
|
InferShape g_infer_func[PrimType_MAX];
|
|
|
|
InferShape *g_infer_func;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__attribute__((constructor(101))) void InitInferFuncBuf() {
|
|
|
|
|
|
|
|
g_infer_func = malloc(PrimType_MAX * sizeof(InferShape));
|
|
|
|
|
|
|
|
if (g_infer_func != NULL) {
|
|
|
|
|
|
|
|
memset(g_infer_func, 0, PrimType_MAX * sizeof(InferShape));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
InferShape GetInferFunc(int prim_type) {
|
|
|
|
InferShape GetInferFunc(int prim_type) {
|
|
|
|
if (prim_type < PrimType_MAX) {
|
|
|
|
if (g_infer_func != NULL && prim_type < PrimType_MAX) {
|
|
|
|
return g_infer_func[prim_type];
|
|
|
|
return g_infer_func[prim_type];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return NULL;
|
|
|
|
return NULL;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void RegInfer(int prim_type, InferShape func) {
|
|
|
|
void RegInfer(int prim_type, InferShape func) {
|
|
|
|
if (prim_type < PrimType_MAX) {
|
|
|
|
if (g_infer_func != NULL && prim_type < PrimType_MAX) {
|
|
|
|
g_infer_func[prim_type] = func;
|
|
|
|
g_infer_func[prim_type] = func;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|