|
|
@ -25,49 +25,7 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
|
|
|
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
|
|
|
if (dst_shape.size() == 0) {
|
|
|
|
need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
|
|
|
|
dst_shape.emplace_back(1);
|
|
|
|
|
|
|
|
src0_shape.emplace_back(1);
|
|
|
|
|
|
|
|
src1_shape.emplace_back(1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t src0_length = 1;
|
|
|
|
|
|
|
|
size_t src1_length = 1;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < src0_shape.size(); ++i) {
|
|
|
|
|
|
|
|
src0_length = src0_length * src0_shape[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (size_t i = 0; i < src1_shape.size(); ++i) {
|
|
|
|
|
|
|
|
src1_length = src1_length * src1_shape[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (src1_shape.size() != src0_shape.size()) {
|
|
|
|
|
|
|
|
if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
|
|
|
|
|
|
|
|
need_swap_ = true;
|
|
|
|
|
|
|
|
for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
|
|
|
|
|
|
|
|
src0_shape.emplace_back(1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
|
|
|
|
|
|
|
|
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
|
|
|
|
|
|
|
|
src1_shape.emplace_back(1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
bool visit_src0 = false;
|
|
|
|
|
|
|
|
bool visit_src1 = false;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < src0_shape.size(); ++i) {
|
|
|
|
|
|
|
|
if (src0_shape[i] != src1_shape[i]) {
|
|
|
|
|
|
|
|
if (src0_shape[i] == 1 && !visit_src1) {
|
|
|
|
|
|
|
|
need_swap_ = true;
|
|
|
|
|
|
|
|
visit_src0 = true;
|
|
|
|
|
|
|
|
} else if (src1_shape[i] == 1 && !visit_src0) {
|
|
|
|
|
|
|
|
need_swap_ = false;
|
|
|
|
|
|
|
|
visit_src1 = true;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
dnnl::memory::desc src0_desc;
|
|
|
|
dnnl::memory::desc src0_desc;
|
|
|
|
dnnl::memory::desc src1_desc;
|
|
|
|
dnnl::memory::desc src1_desc;
|
|
|
|
if (need_swap_) {
|
|
|
|
if (need_swap_) {
|
|
|
|