cpu conv2d support tuple pad

pull/7332/head
baihuawei 4 years ago
parent aa605e23d5
commit a534e0a320

@ -34,6 +34,7 @@ const char STRIDE[] = "stride";
const char STRIDES[] = "strides";
const char DILATION[] = "dilation";
const char PAD[] = "pad";
const char PAD_LIST[] = "pad_list";
const char PAD_MODE[] = "pad_mode";
const char PADDING[] = "padding";
const char PAD_MODE_LOWER_SAME[] = "same";

@ -52,11 +52,11 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
padding_r->emplace_back(0);
padding_r->emplace_back(0);
} else {
int pad = AnfAlgo::GetNodeAttr<int>(kernel_node, PAD);
padding_l->emplace_back(pad);
padding_l->emplace_back(pad);
padding_r->emplace_back(pad);
padding_r->emplace_back(pad);
std::vector<int> pad = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, PAD_LIST);
padding_l->emplace_back(pad[0]);
padding_l->emplace_back(pad[1]);
padding_r->emplace_back(pad[2]);
padding_r->emplace_back(pad[3]);
}
}

@ -37,7 +37,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> label_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (label_shape.size() > 1) {
MS_LOG(EXCEPTION) << "label shape should be 1D";
MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1";
}
dnnl::memory::dims mem_dims;
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());

Loading…
Cancel
Save