|
|
|
@ -2167,7 +2167,6 @@ class Concat(PrimitiveWithInfer):
|
|
|
|
|
x_shp = input_x['shape']
|
|
|
|
|
x_type = input_x['dtype']
|
|
|
|
|
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
|
|
|
|
|
self.add_prim_attr('T', x_type[0].element_type())
|
|
|
|
|
self.add_prim_attr('inputNums', len(x_shp))
|
|
|
|
|
ret_shp = x_shp[0].copy()
|
|
|
|
|
value = None
|
|
|
|
@ -2616,7 +2615,6 @@ class Select(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, cond_type, x_type, y_type):
|
|
|
|
|
self.add_prim_attr('T', x_type)
|
|
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
|
|
|
|
|
validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
|
|
|
|
|