add eltwise mkldnn operator scalar support

pull/9372/head
zhaoting 4 years ago
parent 102b74682b
commit 42633cf9ff

@ -51,6 +51,9 @@ dnnl::eltwise_forward::desc EltWiseCPUKernel::GetForwardEltwiseDesc(const CNodeP
void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) { void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (src_shape.size() == 0) {
src_shape.insert(src_shape.begin(), 1);
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
auto desc = GetForwardEltwiseDesc(kernel_node, src_desc); auto desc = GetForwardEltwiseDesc(kernel_node, src_desc);

Loading…
Cancel
Save