!4009 fix AbsGrad bug

Merge pull request !4009 from caojian05/ms_master_dev3
pull/4009/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit c552e8d9f4

@ -158,7 +158,7 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
}
}
@ -204,7 +204,7 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
case BROADCAST_TYPE_FLOORDIV:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
}
}

Loading…
Cancel
Save