|
|
|
@ -926,32 +926,31 @@ class InplaceAdd(PrimitiveWithInfer):
|
|
|
|
|
"""init InplaceAdd"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
|
|
|
|
|
self.indices = indices
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, v_shape):
|
|
|
|
|
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
|
|
|
|
|
if isinstance(self.indices, int):
|
|
|
|
|
validator.check("size of indices", 1, "v's first dimension", v_shape[0],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
if self.indices < 0 or self.indices >= x_shape[0]:
|
|
|
|
|
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.')
|
|
|
|
|
else:
|
|
|
|
|
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
for i in self.indices:
|
|
|
|
|
if i < 0 or i >= x_shape[0]:
|
|
|
|
|
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
|
|
|
|
|
if len(x_shape) > 1:
|
|
|
|
|
validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
validator.check_value_type('indices', indices, [tuple, int], self.name)
|
|
|
|
|
if isinstance(indices, int):
|
|
|
|
|
self.indices = (indices,)
|
|
|
|
|
for item in self.indices:
|
|
|
|
|
validator.check_value_type("item of indices", item, [int], self.name)
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, v_dtype):
|
|
|
|
|
args = {'x': x_dtype, 'v': v_dtype}
|
|
|
|
|
valid_type = [mstype.int32, mstype.float16, mstype.float32]
|
|
|
|
|
validator.check_tensor_type_same(args, valid_type, self.name)
|
|
|
|
|
validator.check_value_type('indices', self.indices, [tuple, int], self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, v_shape):
|
|
|
|
|
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
|
|
|
|
|
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
for i in self.indices:
|
|
|
|
|
if i < 0 or i >= x_shape[0]:
|
|
|
|
|
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
|
|
|
|
|
x_rank = len(x_shape)
|
|
|
|
|
for idx in range(x_rank)[1:]:
|
|
|
|
|
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
|
|
|
|
|
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InplaceSub(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
@ -985,32 +984,31 @@ class InplaceSub(PrimitiveWithInfer):
|
|
|
|
|
"""init InplaceSub"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
|
|
|
|
|
self.indices = indices
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, v_shape):
|
|
|
|
|
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
|
|
|
|
|
if isinstance(self.indices, int):
|
|
|
|
|
validator.check("size of indices", 1, "v's first dimension", v_shape[0],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
if self.indices < 0 or self.indices >= x_shape[0]:
|
|
|
|
|
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.')
|
|
|
|
|
else:
|
|
|
|
|
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
for i in self.indices:
|
|
|
|
|
if i < 0 or i >= x_shape[0]:
|
|
|
|
|
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
|
|
|
|
|
if len(x_shape) > 1:
|
|
|
|
|
validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
return x_shape
|
|
|
|
|
validator.check_value_type('indices', indices, [tuple, int], self.name)
|
|
|
|
|
if isinstance(indices, int):
|
|
|
|
|
self.indices = (indices,)
|
|
|
|
|
for item in self.indices:
|
|
|
|
|
validator.check_value_type("item of indices", item, [int], self.name)
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, v_dtype):
|
|
|
|
|
args = {'x': x_dtype, 'v': v_dtype}
|
|
|
|
|
valid_type = [mstype.int32, mstype.float16, mstype.float32]
|
|
|
|
|
validator.check_tensor_type_same(args, valid_type, self.name)
|
|
|
|
|
validator.check_value_type('indices', self.indices, [tuple, int], self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, v_shape):
|
|
|
|
|
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
|
|
|
|
|
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
|
|
|
|
|
Rel.EQ, self.name)
|
|
|
|
|
for i in self.indices:
|
|
|
|
|
if i < 0 or i >= x_shape[0]:
|
|
|
|
|
raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
|
|
|
|
|
x_rank = len(x_shape)
|
|
|
|
|
for idx in range(x_rank)[1:]:
|
|
|
|
|
validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
|
|
|
|
|
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Sub(_MathBinaryOp):
|
|
|
|
|
"""
|
|
|
|
|