|
|
|
@ -280,6 +280,19 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|
|
|
|
ShapeVector x_min_shape = input_x->shape()->min_shape();
|
|
|
|
|
ShapeVector x_max_shape = input_x->shape()->max_shape();
|
|
|
|
|
(void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
|
|
|
|
|
for (size_t i = 0; i < x_shape.size(); ++i) {
|
|
|
|
|
if ((x_shape[i] < 0) && (x_shape[i] != Shape::SHP_ANY)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape element x_shape[" << i << "] must be positive integer, but got " << x_shape[i];
|
|
|
|
|
}
|
|
|
|
|
if (x_min_shape[i] < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Min Shape element x_min_shape[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< x_min_shape[i];
|
|
|
|
|
}
|
|
|
|
|
if (x_max_shape[i] < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Max Shape element x_max_shape[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< x_max_shape[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_w);
|
|
|
|
@ -288,13 +301,26 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|
|
|
|
ShapeVector w_min_shape = input_w->shape()->min_shape();
|
|
|
|
|
ShapeVector w_max_shape = input_w->shape()->max_shape();
|
|
|
|
|
(void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
|
|
|
|
|
for (size_t i = 0; i < w_shape.size(); ++i) {
|
|
|
|
|
if ((w_shape[i] < 0) && (w_shape[i] != Shape::SHP_ANY)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape element w_shape[" << i << "] must be positive integer, but got " << w_shape[i];
|
|
|
|
|
}
|
|
|
|
|
if (w_min_shape[i] < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Min Shape element w_min_shape[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< w_min_shape[i];
|
|
|
|
|
}
|
|
|
|
|
if (w_max_shape[i] < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Max Shape element w_max_shape[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< w_max_shape[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<std::string> available_data_format{"NCHW", "NHWC"};
|
|
|
|
|
int64_t n_axis = 0;
|
|
|
|
|
int64_t c_axis = 1;
|
|
|
|
|
int64_t h_axis = 2;
|
|
|
|
|
int64_t w_axis = 3;
|
|
|
|
|
auto data_format_ptr = primitive->GetAttr("format");
|
|
|
|
|
auto data_format_ptr = primitive->GetAttr("data_format");
|
|
|
|
|
std::string data_format = "NCHW";
|
|
|
|
|
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
|
|
|
|
|
data_format = data_format_ptr->cast<StringImmPtr>()->value();
|
|
|
|
@ -451,8 +477,27 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|
|
|
|
output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
|
|
|
|
|
output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < output_shape.size(); ++i) {
|
|
|
|
|
if ((output_shape[i] < 0) && (output_shape[i] != Shape::SHP_ANY)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape element output_shape[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< output_shape[i];
|
|
|
|
|
}
|
|
|
|
|
if (output_shape_min[i] < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Min Shape element output_shape_min[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< output_shape_min[i];
|
|
|
|
|
}
|
|
|
|
|
if (output_shape_max[i] < 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Max Shape element output_shape_max[" << i << "] must be positive integer, but got "
|
|
|
|
|
<< output_shape_max[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
|
|
|
|
|
if (input_x->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt8) {
|
|
|
|
|
auto output = std::make_shared<AbstractTensor>(kInt32, output_shape);
|
|
|
|
|
output->set_shape(output_shape_ptr);
|
|
|
|
|
return output;
|
|
|
|
|
}
|
|
|
|
|
return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|