fix argmaxwithvalue

pull/3536/head
fangzehua 5 years ago
parent e730224a2c
commit 32519ef570

@ -1185,7 +1185,8 @@ class ArgMaxWithValue(PrimitiveWithInfer):
"""init ArgMaxWithValue"""
self.axis = axis
self.keep_dims = keep_dims
_check_infer_attr_reduce(axis, keep_dims, self.name)
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
validator.check_value_type('axis', axis, [int], self.name)
def infer_shape(self, x_shape):
axis = self.axis

Loading…
Cancel
Save