do not merge tensor move to one in cse pass

pull/10276/head
LianLiguang 4 years ago
parent ffe61081d3
commit 414f38df8d

@ -37,6 +37,12 @@ bool HasSideEffectAttr(const AnfNodePtr &node) {
bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
if (main->isa<CNode>()) {
auto main_name = AnfAlgo::GetCNodeName(main);
if (main_name == prim::kPrimTensorMove->name() || main_name == prim::kPrimMemCpyAsync->name()) {
return false;
}
}
auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (main_kernel_info == nullptr && node_kernel_info == nullptr) {

@ -200,6 +200,7 @@ inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAda
inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

Loading…
Cancel
Save