|
|
|
@ -842,7 +842,7 @@ class Conv2D(PrimitiveWithInfer):
|
|
|
|
|
self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
|
|
|
|
|
self.add_prim_attr('offset_a', 0)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape):
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
|
|
|
|
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
|
|
|
|
@ -887,7 +887,7 @@ class Conv2D(PrimitiveWithInfer):
|
|
|
|
|
out_shape = [x_shape[0], out_channel, h_out, w_out]
|
|
|
|
|
return out_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, w_dtype):
|
|
|
|
|
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
|
|
|
|
|
args = {'x': x_dtype, 'w': w_dtype}
|
|
|
|
|
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
|
|
|
|
validator.check_tensor_type_same(args, valid_types, self.name)
|
|
|
|
@ -968,7 +968,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
|
|
|
|
|
self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
|
|
|
|
|
self.add_prim_attr('offset_a', 0)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape):
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
|
|
|
|
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
|
|
|
|
|
validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
|
|
|
|
@ -1011,7 +1011,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
|
|
|
|
|
out_shape = [x_shape[0], out_channel, h_out, w_out]
|
|
|
|
|
return out_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, w_dtype):
|
|
|
|
|
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
|
|
|
|
|
args = {'x': x_dtype, 'w': w_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
|
|
|
|
if x_dtype.element_type() == mstype.int8:
|
|
|
|
|