!9180 fix dnnl binary broadcast

From: @zhao_ting_v
Reviewed-by: 
Signed-off-by:
pull/9180/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1ffecf1874

@ -181,7 +181,7 @@ void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, cons
T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(bool)) : 1;
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;

@ -67,6 +67,52 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
}
}
bool MKLCPUKernel::BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
std::vector<size_t> *dst_shape) {
MS_EXCEPTION_IF_NULL(src0_shape);
MS_EXCEPTION_IF_NULL(src1_shape);
MS_EXCEPTION_IF_NULL(dst_shape);
bool need_swap = false;
if (dst_shape->size() == 0) {
dst_shape->emplace_back(1);
src0_shape->emplace_back(1);
src1_shape->emplace_back(1);
}
MS_LOG(DEBUG) << "Binary broadcast in: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
if (src0_shape->size() != dst_shape->size()) {
need_swap = true;
for (size_t i = src0_shape->size(); i < dst_shape->size(); ++i) {
src0_shape->insert(src0_shape->begin(), 1);
}
} else if (src1_shape->size() != dst_shape->size()) {
for (size_t i = src1_shape->size(); i < dst_shape->size(); ++i) {
src1_shape->insert(src1_shape->begin(), 1);
}
}
if (src0_shape->size() == src1_shape->size()) {
bool visit_src0 = false;
bool visit_src1 = false;
for (size_t i = 0; i < src0_shape->size(); ++i) {
if (src0_shape->at(i) != src1_shape->at(i)) {
if (src0_shape->at(i) == 1 && !visit_src1) {
need_swap = true;
visit_src0 = true;
} else if (src1_shape->at(i) == 1 && !visit_src0) {
need_swap = false;
visit_src1 = true;
} else {
MS_LOG(EXCEPTION) << "Invalid broadcast! " << *src0_shape << " vs " << *src1_shape;
}
}
}
} else {
MS_LOG(EXCEPTION) << "Invalid broadcast! src0: " << *src0_shape << " src1: " << *src1_shape
<< " dst: " << *dst_shape;
}
MS_LOG(DEBUG) << "Binary broadcast out: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
return need_swap;
}
dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const {
dnnl::memory::format_tag mem_tag;
auto dim_size = dims.size();

@ -32,6 +32,8 @@ class MKLCPUKernel : public CPUKernel {
~MKLCPUKernel() override = default;
protected:
bool BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
std::vector<size_t> *dst_shape);
void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector<size_t> &src_shape,
const std::vector<size_t> &kernel_size, int stride, std::vector<int> *padding_l,
std::vector<int> *padding_r);

@ -25,49 +25,7 @@ void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (dst_shape.size() == 0) {
dst_shape.emplace_back(1);
src0_shape.emplace_back(1);
src1_shape.emplace_back(1);
}
size_t src0_length = 1;
size_t src1_length = 1;
for (size_t i = 0; i < src0_shape.size(); ++i) {
src0_length = src0_length * src0_shape[i];
}
for (size_t i = 0; i < src1_shape.size(); ++i) {
src1_length = src1_length * src1_shape[i];
}
if (src1_shape.size() != src0_shape.size()) {
if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
need_swap_ = true;
for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
src0_shape.emplace_back(1);
}
} else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
src1_shape.emplace_back(1);
}
} else {
MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
}
} else {
bool visit_src0 = false;
bool visit_src1 = false;
for (size_t i = 0; i < src0_shape.size(); ++i) {
if (src0_shape[i] != src1_shape[i]) {
if (src0_shape[i] == 1 && !visit_src1) {
need_swap_ = true;
visit_src0 = true;
} else if (src1_shape[i] == 1 && !visit_src0) {
need_swap_ = false;
visit_src1 = true;
} else {
MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
}
}
}
}
need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
dnnl::memory::desc src0_desc;
dnnl::memory::desc src1_desc;
if (need_swap_) {

@ -25,49 +25,7 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (dst_shape.size() == 0) {
dst_shape.emplace_back(1);
src0_shape.emplace_back(1);
src1_shape.emplace_back(1);
}
size_t src0_length = 1;
size_t src1_length = 1;
for (size_t i = 0; i < src0_shape.size(); ++i) {
src0_length = src0_length * src0_shape[i];
}
for (size_t i = 0; i < src1_shape.size(); ++i) {
src1_length = src1_length * src1_shape[i];
}
if (src1_shape.size() != src0_shape.size()) {
if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
need_swap_ = true;
for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
src0_shape.emplace_back(1);
}
} else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
src1_shape.emplace_back(1);
}
} else {
MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
}
} else {
bool visit_src0 = false;
bool visit_src1 = false;
for (size_t i = 0; i < src0_shape.size(); ++i) {
if (src0_shape[i] != src1_shape[i]) {
if (src0_shape[i] == 1 && !visit_src1) {
need_swap_ = true;
visit_src0 = true;
} else if (src1_shape[i] == 1 && !visit_src0) {
need_swap_ = false;
visit_src1 = true;
} else {
MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
}
}
}
}
need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
dnnl::memory::desc src0_desc;
dnnl::memory::desc src1_desc;
if (need_swap_) {

@ -47,6 +47,8 @@ def test_mul():
y2 = Tensor(2, mstype.float32)
x3 = Tensor(2, mstype.float32)
y3 = Tensor(2, mstype.float32)
x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32))
y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32))
mul = Net()
out = mul(x0, y0).asnumpy()
exp = x0.asnumpy() * y0.asnumpy()
@ -75,3 +77,10 @@ def test_mul():
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out = mul(x4, y4).asnumpy()
exp = x4.asnumpy() * y4.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape

@ -45,6 +45,8 @@ def test_tensor_add():
y2 = Tensor(2, mstype.float32)
x3 = Tensor(2, mstype.float32)
y3 = Tensor(2, mstype.float32)
x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32))
y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32))
add = TensorAdd()
out = add(x0, y0).asnumpy()
exp = x0.asnumpy() + y0.asnumpy()
@ -73,3 +75,10 @@ def test_tensor_add():
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out = add(x4, y4).asnumpy()
exp = x4.asnumpy() + y4.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape

Loading…
Cancel
Save