|
|
|
@ -5017,7 +5017,7 @@ class Sort(PrimitiveWithInfer):
|
|
|
|
|
return x_dtype, mstype.tensor_type(mstype.int32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
|
class EmbeddingLookup(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Returns a slice of input tensor based on the specified indices.
|
|
|
|
|
|
|
|
|
@ -5063,28 +5063,13 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
|
|
|
|
|
outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def __infer__(self, params, indices, offset):
|
|
|
|
|
def __check__(self, params, indices, offset):
|
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
|
|
|
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
|
|
|
|
params_shp = params['shape']
|
|
|
|
|
if len(params_shp) > 2:
|
|
|
|
|
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
|
|
|
|
out_shape = indices['shape'] + params_shp[1:]
|
|
|
|
|
if 'max_shape' in indices:
|
|
|
|
|
out_max_shape = indices['max_shape'] + params_shp[1:]
|
|
|
|
|
else:
|
|
|
|
|
out_max_shape = out_shape
|
|
|
|
|
if 'min_shape' in indices:
|
|
|
|
|
out_min_shape = indices['min_shape'] + params_shp[1:]
|
|
|
|
|
else:
|
|
|
|
|
out_min_shape = out_shape
|
|
|
|
|
out = {'shape': out_shape,
|
|
|
|
|
'dtype': params['dtype'],
|
|
|
|
|
'value': None,
|
|
|
|
|
'max_shape': out_max_shape,
|
|
|
|
|
'min_shape': out_min_shape}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GatherD(PrimitiveWithInfer):
|
|
|
|
|