Fix the condition that decides whether atomic add is suitable on Ascend

pull/14649/head
looop5 4 years ago
parent ce248c37e0
commit 60474eaf20

@ -180,23 +180,11 @@ bool AtomicAddCheckerGPU::SuitableForAtomicAdd(const AnfNodePtr &node) {
} }
bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) { bool AtomicAddCheckerAscend::SuitableForAtomicAdd(const AnfNodePtr &node) {
auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex); auto dst_shape_vec = AnfAlgo::GetOutputDeviceShape(node, 0);
auto src_shape_vec = GetShape(input);
std::set<int64_t> axis_set = GetUniqReduceAxes(node);
// case 1: all reduce // all reduce
if (src_shape_vec.size() == axis_set.size()) { // non-reduce axes with dimension 1
return true; return std::all_of(dst_shape_vec.cbegin(), dst_shape_vec.cend(), [](const size_t &dim) { return dim == 1; });
}
// case 2: non-reduce axes with dimension 1
for (size_t i = 0; i < src_shape_vec.size(); ++i) {
if (axis_set.find(i) == axis_set.end()) {
if (src_shape_vec[i] != 1) {
return false;
}
}
}
return true;
} }
void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) { void AtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {

Loading…
Cancel
Save