@ -44,28 +44,29 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
int64_t num_zeros = 0 ;
int64_t num_zeros = 0 ;
for ( int64_t j = 0 ; j < num_categories ; j + + ) {
for ( int64_t j = 0 ; j < num_categories ; j + + ) {
prob_value = in_data [ i * num_categories + j ] ;
prob_value = in_data [ i * num_categories + j ] ;
PADDLE_ENFORCE_GE (
PADDLE_ENFORCE_GE ( prob_value , 0.0 ,
prob_value , 0.0 ,
platform : : errors : : InvalidArgument (
platform : : errors : : OutOfRange (
" The input of multinomial distribution "
" The input of multinomial distribution should be >= 0 " ) ) ;
" should be >= 0, but got %f. " ,
PADDLE_ENFORCE_EQ ( ( std : : isinf ( static_cast < double > ( prob_value ) ) | |
prob_value ) ) ;
std : : isnan ( static_cast < double > ( prob_value ) ) ) ,
false , platform : : errors : : OutOfRange (
" The input of multinomial distribution "
" shoud not be infinity or NaN " ) ) ;
probs_sum + = prob_value ;
probs_sum + = prob_value ;
if ( prob_value = = 0 ) {
if ( prob_value = = 0 ) {
num_zeros + = 1 ;
num_zeros + = 1 ;
}
}
cumulative_probs [ j ] = probs_sum ;
cumulative_probs [ j ] = probs_sum ;
}
}
PADDLE_ENFORCE_GT ( probs_sum , 0.0 , platform : : errors : : OutOfRange (
PADDLE_ENFORCE_GT ( probs_sum , 0.0 ,
" The sum of input should not be 0 " ) ) ;
platform : : errors : : InvalidArgument (
" The sum of one multinomial distribution "
" probability should be > 0, but got %f. " ,
probs_sum ) ) ;
PADDLE_ENFORCE_EQ (
PADDLE_ENFORCE_EQ (
( replacement | | ( num_categories - num_zeros > = num_samples ) ) , true ,
( replacement | | ( num_categories - num_zeros > = num_samples ) ) , true ,
platform : : errors : : OutOfRange ( " When replacement is False, number of "
platform : : errors : : InvalidArgument (
" samples should be less than non-zero "
" When replacement is False, number of "
" categories " ) ) ;
" samples should be less than non-zero "
" categories. " ) ) ;
for ( int64_t j = 0 ; j < num_categories ; j + + ) {
for ( int64_t j = 0 ; j < num_categories ; j + + ) {
cumulative_probs [ j ] / = probs_sum ;
cumulative_probs [ j ] / = probs_sum ;