|
|
@ -24,8 +24,8 @@ namespace mindspore {
|
|
|
|
namespace kernel {
|
|
|
|
namespace kernel {
|
|
|
|
void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode,
|
|
|
|
void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode,
|
|
|
|
const std::vector<size_t> &src_shape, const std::vector<size_t> &kernel_size,
|
|
|
|
const std::vector<size_t> &src_shape, const std::vector<size_t> &kernel_size,
|
|
|
|
const std::vector<int> &stride, std::vector<int> *padding_l,
|
|
|
|
const std::vector<int> &stride, std::vector<int> *padding_l, std::vector<int> *padding_r,
|
|
|
|
std::vector<int> *padding_r) {
|
|
|
|
const std::vector<int> &dilation) {
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
if (src_shape.size() < 2) {
|
|
|
|
if (src_shape.size() < 2) {
|
|
|
|
MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!";
|
|
|
|
MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!";
|
|
|
@ -38,13 +38,9 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
|
|
|
|
if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) {
|
|
|
|
if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) {
|
|
|
|
for (size_t i = 0; i < weight_height.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < weight_height.size(); ++i) {
|
|
|
|
auto wh = weight_height[i];
|
|
|
|
auto wh = weight_height[i];
|
|
|
|
int re = wh % stride[i];
|
|
|
|
int out = (wh + stride[i] - 1) / stride[i];
|
|
|
|
int pad_along;
|
|
|
|
int effective_k = (SizeToInt(kernel_size[i]) - 1) * dilation[i] + 1;
|
|
|
|
if (re == 0) {
|
|
|
|
int pad_along = std::max(0, (out - 1) * stride[i] + effective_k - wh);
|
|
|
|
pad_along = std::max(SizeToInt(kernel_size[i]) - stride[i], 0);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
pad_along = std::max(SizeToInt(kernel_size[i]) - re, 0);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
int pad = pad_along / 2;
|
|
|
|
int pad = pad_along / 2;
|
|
|
|
padding_l->emplace_back(pad);
|
|
|
|
padding_l->emplace_back(pad);
|
|
|
|
padding_r->emplace_back(pad_along - pad);
|
|
|
|
padding_r->emplace_back(pad_along - pad);
|
|
|
|