|
|
|
@ -28,29 +28,8 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace ops {
|
|
|
|
|
namespace {
|
|
|
|
|
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
auto prim_name = primitive->name();
|
|
|
|
|
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
|
|
|
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
|
|
|
|
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
|
|
|
|
|
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
|
|
|
|
if (format == NHWC) {
|
|
|
|
|
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
|
|
|
|
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
|
|
|
|
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
|
|
|
|
|
"w_shape[1]", w_shape[1], prim_name);
|
|
|
|
|
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
|
|
|
|
|
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
|
|
|
|
|
std::vector<int64_t> temp_w;
|
|
|
|
|
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
|
|
|
|
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
|
|
|
|
|
"w_shape[2:4]", temp_w, prim_name);
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape,
|
|
|
|
|
const std::vector<int64_t> &x_shape, const int64_t &out_channel) {
|
|
|
|
|
auto kernel_size_h = w_shape[2];
|
|
|
|
|
auto kernel_size_w = w_shape[3];
|
|
|
|
|
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
|
|
|
|
@ -92,13 +71,36 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|
|
|
|
h_out = floor(h_out);
|
|
|
|
|
w_out = floor(w_out);
|
|
|
|
|
}
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
|
|
|
|
|
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name)));
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
|
|
|
|
|
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
|
|
|
|
|
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
|
|
|
|
return out_shape;
|
|
|
|
|
}
|
|
|
|
|
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
auto prim_name = primitive->name();
|
|
|
|
|
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
|
|
|
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
|
|
|
|
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
|
|
|
|
|
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
|
|
|
|
if (format == NHWC) {
|
|
|
|
|
out_shape = {x_shape[0], h_out, w_out, out_channel};
|
|
|
|
|
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
|
|
|
|
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
|
|
|
|
|
}
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
|
|
|
|
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
|
|
|
|
|
"w_shape[1]", w_shape[1], prim_name);
|
|
|
|
|
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
|
|
|
|
|
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
|
|
|
|
|
std::vector<int64_t> temp_w;
|
|
|
|
|
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
|
|
|
|
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
|
|
|
|
|
"w_shape[2:4]", temp_w, prim_name);
|
|
|
|
|
auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel);
|
|
|
|
|
if (format == NHWC) {
|
|
|
|
|
out_shape = {out_shape[0], out_shape[3], out_shape[1], out_shape[2]};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::make_shared<abstract::Shape>(out_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|