|
|
@ -23,6 +23,7 @@
|
|
|
|
#include <set>
|
|
|
|
#include <set>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include "utils/check_convert_utils.h"
|
|
|
|
#include "utils/check_convert_utils.h"
|
|
|
|
|
|
|
|
#include "abstract/primitive_infer_map.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace {
|
|
|
|
namespace {
|
|
|
@ -36,6 +37,84 @@ constexpr auto kGroup = "group";
|
|
|
|
constexpr auto kOutputChannel = "output channel";
|
|
|
|
constexpr auto kOutputChannel = "output channel";
|
|
|
|
constexpr auto kPadList = "pad_list";
|
|
|
|
constexpr auto kPadList = "pad_list";
|
|
|
|
constexpr auto kConv2DName = "Conv2D";
|
|
|
|
constexpr auto kConv2DName = "Conv2D";
|
|
|
|
|
|
|
|
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
|
|
|
auto conv_prim = std::dynamic_pointer_cast<Conv2d>(primitive);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conv_prim);
|
|
|
|
|
|
|
|
auto prim_name = conv_prim->name();
|
|
|
|
|
|
|
|
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
|
|
|
|
|
|
|
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
|
|
|
|
|
|
|
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
|
|
|
|
|
|
|
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
|
|
|
|
|
|
|
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
|
|
|
|
|
|
|
|
w_shape[1], conv_prim->name());
|
|
|
|
|
|
|
|
auto out_channel = conv_prim->GetOutputChannel();
|
|
|
|
|
|
|
|
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
|
|
|
|
|
|
|
|
std::vector<int> temp_w;
|
|
|
|
|
|
|
|
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
|
|
|
|
|
|
|
CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
|
|
|
|
|
|
|
|
conv_prim->name());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_size_h = w_shape[2];
|
|
|
|
|
|
|
|
auto kernel_size_w = w_shape[3];
|
|
|
|
|
|
|
|
auto stride = conv_prim->GetStride();
|
|
|
|
|
|
|
|
auto dilation = conv_prim->GetDilation();
|
|
|
|
|
|
|
|
auto stride_h = stride[2];
|
|
|
|
|
|
|
|
auto stride_w = stride[3];
|
|
|
|
|
|
|
|
auto dilation_h = dilation[2];
|
|
|
|
|
|
|
|
auto dilation_w = dilation[3];
|
|
|
|
|
|
|
|
int h_out = -1;
|
|
|
|
|
|
|
|
int w_out = -1;
|
|
|
|
|
|
|
|
std::vector<int> pad_list(4, 0);
|
|
|
|
|
|
|
|
auto pad_mode = conv_prim->GetPadMode();
|
|
|
|
|
|
|
|
if (pad_mode == "valid") {
|
|
|
|
|
|
|
|
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
|
|
|
|
|
|
|
|
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
|
|
|
|
|
|
|
|
} else if (pad_mode == "same") {
|
|
|
|
|
|
|
|
h_out = ceil(x_shape[2] / stride_h);
|
|
|
|
|
|
|
|
w_out = ceil(x_shape[3] / stride_w);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
|
|
|
|
|
|
|
|
pad_list.emplace_back(floor(pad_needed_h / 2));
|
|
|
|
|
|
|
|
pad_list.emplace_back(pad_needed_h / 2);
|
|
|
|
|
|
|
|
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
|
|
|
|
|
|
|
|
auto pad_left = floor(pad_needed_w / 2);
|
|
|
|
|
|
|
|
pad_list.emplace_back(pad_left);
|
|
|
|
|
|
|
|
pad_list.emplace_back(pad_needed_h - pad_left);
|
|
|
|
|
|
|
|
} else if (pad_mode == "pad") {
|
|
|
|
|
|
|
|
std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list));
|
|
|
|
|
|
|
|
auto pad_top = conv_prim->GetPad()[0];
|
|
|
|
|
|
|
|
auto pad_bottom = conv_prim->GetPad()[1];
|
|
|
|
|
|
|
|
auto pad_right = conv_prim->GetPad()[2];
|
|
|
|
|
|
|
|
auto pad_left = conv_prim->GetPad()[3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
|
|
|
|
|
|
|
|
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
|
|
|
|
|
|
|
|
h_out = floor(h_out);
|
|
|
|
|
|
|
|
w_out = floor(w_out);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
conv_prim->SetPadList(pad_list);
|
|
|
|
|
|
|
|
std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
|
|
|
|
|
|
|
return std::make_shared<abstract::Shape>(out_shape);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
|
|
|
|
|
|
|
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
|
|
|
|
|
|
|
|
for (const auto &item : input_args) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(item);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
|
|
|
|
|
|
|
|
std::map<std::string, TypePtr> types;
|
|
|
|
|
|
|
|
types.emplace("x", input_args[0]->BuildType());
|
|
|
|
|
|
|
|
types.emplace("w", input_args[1]->BuildType());
|
|
|
|
|
|
|
|
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
|
|
|
|
|
|
|
if (infer_type == kNumberTypeInt8) {
|
|
|
|
|
|
|
|
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return TypeIdToType(infer_type);
|
|
|
|
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
} // namespace
|
|
|
|
Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
|
|
|
|
Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
|
|
|
|
|
|
|
|
|
|
|
@ -105,4 +184,11 @@ void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
|
|
|
|
void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
|
|
|
|
void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
|
|
|
|
void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
|
|
|
|
void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
|
|
|
|
void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
|
|
|
void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
|
|
|
const std::vector<AbstractBasePtr> &input_args) {
|
|
|
|
|
|
|
|
return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
|
|
|
|
|
|
|
|
Conv2dInferShape(primitive, input_args)->shape());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|
|
|
|