|
|
@ -390,8 +390,8 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|
|
|
this->ConvInferShape(input_h, input_w, &output_h, &output_w);
|
|
|
|
this->ConvInferShape(input_h, input_w, &output_h, &output_w);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> out_shape{input_tensor->shape()};
|
|
|
|
std::vector<int> out_shape{input_tensor->shape()};
|
|
|
|
out_shape.at(1) = output_h;
|
|
|
|
out_shape.at(1) = output_h > 0 ? output_h : 1;
|
|
|
|
out_shape.at(2) = output_w;
|
|
|
|
out_shape.at(2) = output_w > 0 ? output_w : 1;
|
|
|
|
out_shape.at(3) = weight_tensor->shape()[0];
|
|
|
|
out_shape.at(3) = weight_tensor->shape()[0];
|
|
|
|
out_tensor->set_shape(out_shape);
|
|
|
|
out_tensor->set_shape(out_shape);
|
|
|
|
|
|
|
|
|
|
|
|