|
|
|
@ -24,6 +24,10 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
|
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
|
if (src1_shape.size() == 0 && src0_shape.size() == 0) {
|
|
|
|
|
src0_shape.insert(src0_shape.begin(), 1);
|
|
|
|
|
src1_shape.insert(src1_shape.begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs "
|
|
|
|
|
<< src1_shape.size();
|
|
|
|
|