|
|
|
@ -67,6 +67,52 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MKLCPUKernel::BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
|
|
|
|
|
std::vector<size_t> *dst_shape) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(src0_shape);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(src1_shape);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dst_shape);
|
|
|
|
|
bool need_swap = false;
|
|
|
|
|
if (dst_shape->size() == 0) {
|
|
|
|
|
dst_shape->emplace_back(1);
|
|
|
|
|
src0_shape->emplace_back(1);
|
|
|
|
|
src1_shape->emplace_back(1);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Binary broadcast in: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
|
|
|
|
|
if (src0_shape->size() != dst_shape->size()) {
|
|
|
|
|
need_swap = true;
|
|
|
|
|
for (size_t i = src0_shape->size(); i < dst_shape->size(); ++i) {
|
|
|
|
|
src0_shape->insert(src0_shape->begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
} else if (src1_shape->size() != dst_shape->size()) {
|
|
|
|
|
for (size_t i = src1_shape->size(); i < dst_shape->size(); ++i) {
|
|
|
|
|
src1_shape->insert(src1_shape->begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (src0_shape->size() == src1_shape->size()) {
|
|
|
|
|
bool visit_src0 = false;
|
|
|
|
|
bool visit_src1 = false;
|
|
|
|
|
for (size_t i = 0; i < src0_shape->size(); ++i) {
|
|
|
|
|
if (src0_shape->at(i) != src1_shape->at(i)) {
|
|
|
|
|
if (src0_shape->at(i) == 1 && !visit_src1) {
|
|
|
|
|
need_swap = true;
|
|
|
|
|
visit_src0 = true;
|
|
|
|
|
} else if (src1_shape->at(i) == 1 && !visit_src0) {
|
|
|
|
|
need_swap = false;
|
|
|
|
|
visit_src1 = true;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid broadcast! " << *src0_shape << " vs " << *src1_shape;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid broadcast! src0: " << *src0_shape << " src1: " << *src1_shape
|
|
|
|
|
<< " dst: " << *dst_shape;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Binary broadcast out: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
|
|
|
|
|
return need_swap;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const {
|
|
|
|
|
dnnl::memory::format_tag mem_tag;
|
|
|
|
|
auto dim_size = dims.size();
|
|
|
|
|