|
|
@ -35,13 +35,13 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
|
|
|
|
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
|
|
|
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
|
|
|
|
|
|
|
|
|
|
|
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
|
|
|
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
|
|
|
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false);
|
|
|
|
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name());
|
|
|
|
if (strides[0] != strides[1]) {
|
|
|
|
if (strides[0] != strides[1]) {
|
|
|
|
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
|
|
|
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
|
|
|
<< ", width " << strides[1];
|
|
|
|
<< ", width " << strides[1];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
this->set_stride(strides);
|
|
|
|
this->set_stride(strides);
|
|
|
|
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false);
|
|
|
|
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name());
|
|
|
|
if (dilations[0] != dilations[1]) {
|
|
|
|
if (dilations[0] != dilations[1]) {
|
|
|
|
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
|
|
|
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
|
|
|
<< ", width " << dilations[1];
|
|
|
|
<< ", width " << dilations[1];
|
|
|
@ -57,7 +57,7 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
|
|
|
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
|
|
|
|
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name()));
|
|
|
|
|
|
|
|
|
|
|
|
this->set_out_channel(
|
|
|
|
this->set_out_channel(
|
|
|
|
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
|
|
|
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
|
|
|