@ -807,7 +807,11 @@ class GatherV2(PrimitiveWithCheck):
def __check__ ( self , params , indices , axis ) :
validator . check_subclass ( " params " , params [ ' dtype ' ] , mstype . tensor , self . name )
validator . check_tensor_dtype_valid ( " indices " , indices [ ' dtype ' ] , mstype . int_type , self . name )
validator . check_subclass ( " axis " , axis [ ' dtype ' ] , [ mstype . int_ ] , self . name )
validator . check_subclass ( " axis " , axis [ ' dtype ' ] , [ mstype . number ] , self . name )
axis_v = axis [ ' value ' ]
validator . check_value_type ( ' axis ' , axis_v , [ int ] , self . name )
rank = len ( params [ ' shape ' ] )
validator . check_int_range ( axis_v , - rank , rank , Rel . INC_LEFT , " axis " , self . name )
class SparseGatherV2 ( GatherV2 ) :
@ -975,6 +979,12 @@ class Split(PrimitiveWithCheck):
x_shape = list ( x [ ' shape ' ] )
dim = len ( x_shape )
validator . check_int_range ( self . axis , - dim , dim , Rel . INC_LEFT , ' axis value ' , self . name )
if - 1 not in x_shape :
# only validate when shape fully known
output_valid_check = x_shape [ self . axis ] % self . output_num
if output_valid_check != 0 :
raise ValueError ( f " x_shape[ { self . axis } ] { x_shape [ self . axis ] } must be divide exactly by "
f " output_num { self . output_num } " )
class Rank ( PrimitiveWithInfer ) :
@ -1945,18 +1955,21 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
self . add_prim_attr ( " dynamic_shape_depends " , [ 2 ] )
def __check__ ( self , x , segment_ids , num_segments ) :
x_shape = x [ ' shape ' ]
segment_ids_shape = segment_ids [ ' shape ' ]
valid_type = [ mstype . float16 , mstype . float32 , mstype . int32 ]
validator . check_tensor_dtype_valid ( " x " , x [ ' dtype ' ] , valid_type , self . name )
validator . check_tensor_dtype_valid ( " segment_ids " , segment_ids [ ' dtype ' ] , [ mstype . int32 ] , self . name )
validator . check_equal_int ( len ( segment_ids_shape ) , 1 , " rank of segment_ids_shape " , self . name )
num_segments_type = num_segments [ ' dtype ' ]
validator . check_subclass ( " num_segments " , num_segments_type , [ mstype . tensor , mstype . number ] , self . name )
if isinstance ( num_segments_type , type ( mstype . tensor ) ) :
validator . check_tensor_dtype_valid ( " num_segments " , num_segments_type , [ mstype . int32 , mstype . int64 ] ,
self . name )
else :
validator . check_value_type ( ' num_segments ' , num_segments [ ' value ' ] , [ int ] , self . name )
validator . check_subclass ( " num_segments " , num_segments_type , [ mstype . number ] , self . name )
if ( not - 1 in x_shape and not - 1 in segment_ids_shape ) :
# only validate when both shapes fully known
validator . check ( f ' first shape of input_x ' , x_shape [ 0 ] ,
' length of segments_id ' , segment_ids_shape [ 0 ] , Rel . EQ , self . name )
num_segments_v = num_segments [ ' value ' ]
validator . check_value_type ( ' num_segments ' , num_segments_v , [ int ] , self . name )
validator . check_positive_int ( num_segments_v , " num_segments " , self . name )
class UnsortedSegmentMax ( PrimitiveWithCheck ) :
@ -1998,6 +2011,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
self . add_prim_attr ( " dynamic_shape_depends " , [ 2 ] )
def __check__ ( self , x , segment_ids , num_segments ) :
x_shape = x [ ' shape ' ]
segment_ids_shape = segment_ids [ ' shape ' ]
valid_type = [ mstype . float16 , mstype . float32 , mstype . int32 ]
validator . check_tensor_dtype_valid ( " x " , x [ ' dtype ' ] , valid_type , self . name )
@ -2005,12 +2019,14 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
[ mstype . int32 , mstype . int64 ] , self . name )
validator . check_equal_int ( len ( segment_ids_shape ) , 1 , " rank of segment_ids_shape " , self . name )
num_segments_type = num_segments [ ' dtype ' ]
validator . check_subclass ( " num_segments " , num_segments_type , [ mstype . tensor , mstype . number ] , self . name )
if isinstance ( num_segments_type , type ( mstype . tensor ) ) :
validator . check_tensor_dtype_valid ( " num_segments " , num_segments_type , [ mstype . int32 , mstype . int64 ] ,
self . name )
else :
validator . check_value_type ( ' num_segments ' , num_segments [ ' value ' ] , [ int ] , self . name )
validator . check_subclass ( " num_segments " , num_segments_type , [ mstype . number ] , self . name )
if ( not - 1 in x_shape and not - 1 in segment_ids_shape ) :
# only validate when both shapes fully known
validator . check ( f ' first shape of input_x ' , x_shape [ 0 ] ,
' length of segments_id ' , segment_ids_shape [ 0 ] , Rel . EQ , self . name )
num_segments_v = num_segments [ ' value ' ]
validator . check_value_type ( ' num_segments ' , num_segments_v , [ int ] , self . name )
validator . check_positive_int ( num_segments_v , " num_segments " , self . name )
class UnsortedSegmentProd ( PrimitiveWithInfer ) :