|
|
|
@ -446,10 +446,7 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
|
|
|
|
|
for (size_t i = 0; i < cnodes_size; ++i) {
|
|
|
|
|
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
|
|
|
|
|
auto load_node = cnodes[i]->input(1);
|
|
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
|
|
|
|
|
load_node = load_node->cast<CNodePtr>()->input(1);
|
|
|
|
|
}
|
|
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
|
|
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimLoad) || IsPrimitiveCNode(load_node, prim::kPrimCast)) {
|
|
|
|
|
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
|
|
|
|
|
if (param_set.find(param_node) != param_set.end()) {
|
|
|
|
|
sparse_gather_v2_with_cache.push_back(cnodes[i]);
|
|
|
|
|