check equalcount input shape same

pull/1571/head
VectorSL 5 years ago
parent 83b47c370d
commit 4a8356b19a

@ -1415,7 +1415,7 @@ class EqualCount(PrimitiveWithInfer):
"""
Computes the number of the same elements of two tensors.
The two input tensors should have same data type.
The two input tensors should have same data type and shape.
Inputs:
- **input_x** (Tensor) - The first input tensor.
@ -1438,6 +1438,7 @@ class EqualCount(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape):
validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
output_shape = (1,)
return output_shape

Loading…
Cancel
Save