Fixing type check mistakes of InplaceAdd, InplaceSub and InplaceUpdate vm ops

pull/2744/head
liuwenhao4 5 years ago
parent 65189e8ccc
commit 4090c1611b

@ -2818,33 +2818,30 @@ class InplaceUpdate(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, indices):
"""Init InplaceUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'v'], outputs=['y'])
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
self.indices = indices
validator.check_value_type("indices", indices, [int, tuple], self.name)
if isinstance(indices, int):
self.add_prim_attr('indices', (indices,))
self.indices = (indices,)
for item in self.indices:
validator.check_value_type("item of indices", item, [int], self.name)
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(
{
"x": x_dtype,
"v": v_dtype
}, valid_type, self.name)
validator.check_tensor_type_same(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
return x_shape

@ -926,32 +926,31 @@ class InplaceAdd(PrimitiveWithInfer):
"""init InplaceAdd"""
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
self.indices = indices
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
if isinstance(self.indices, int):
validator.check("size of indices", 1, "v's first dimension", v_shape[0],
Rel.EQ, self.name)
if self.indices < 0 or self.indices >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.')
else:
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
if len(x_shape) > 1:
validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
Rel.EQ, self.name)
return x_shape
validator.check_value_type('indices', indices, [tuple, int], self.name)
if isinstance(indices, int):
self.indices = (indices,)
for item in self.indices:
validator.check_value_type("item of indices", item, [int], self.name)
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_value_type('indices', self.indices, [tuple, int], self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
return x_shape
class InplaceSub(PrimitiveWithInfer):
"""
@ -985,32 +984,31 @@ class InplaceSub(PrimitiveWithInfer):
"""init InplaceSub"""
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
self.indices = indices
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
if isinstance(self.indices, int):
validator.check("size of indices", 1, "v's first dimension", v_shape[0],
Rel.EQ, self.name)
if self.indices < 0 or self.indices >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.')
else:
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
if len(x_shape) > 1:
validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
Rel.EQ, self.name)
return x_shape
validator.check_value_type('indices', indices, [tuple, int], self.name)
if isinstance(indices, int):
self.indices = (indices,)
for item in self.indices:
validator.check_value_type("item of indices", item, [int], self.name)
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_value_type('indices', self.indices, [tuple, int], self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
return x_shape
class Sub(_MathBinaryOp):
"""

Loading…
Cancel
Save