|
|
|
@ -25,14 +25,27 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
|
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
|
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
|
|
|
|
if (dst_shape.size() == 0) {
|
|
|
|
|
dst_shape.emplace_back(1);
|
|
|
|
|
src0_shape.emplace_back(1);
|
|
|
|
|
src1_shape.emplace_back(1);
|
|
|
|
|
}
|
|
|
|
|
size_t src0_length = 1;
|
|
|
|
|
size_t src1_length = 1;
|
|
|
|
|
for (size_t i = 0; i < src0_shape.size(); ++i) {
|
|
|
|
|
src0_length = src0_length * src0_shape[i];
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < src1_shape.size(); ++i) {
|
|
|
|
|
src1_length = src1_length * src1_shape[i];
|
|
|
|
|
}
|
|
|
|
|
if (src1_shape.size() != src0_shape.size()) {
|
|
|
|
|
if (src0_shape.size() == 0) {
|
|
|
|
|
if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
|
|
|
|
|
need_swap_ = true;
|
|
|
|
|
for (size_t i = 0; i < src1_shape.size(); ++i) {
|
|
|
|
|
for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
|
|
|
|
|
src0_shape.emplace_back(1);
|
|
|
|
|
}
|
|
|
|
|
} else if (src1_shape.size() == 0) {
|
|
|
|
|
for (size_t i = 0; i < src0_shape.size(); ++i) {
|
|
|
|
|
} else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
|
|
|
|
|
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
|
|
|
|
|
src1_shape.emplace_back(1);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|