|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.ops.functional import stop_gradient
|
|
|
|
|
from mindspore._checkparam import Validator
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
@ -146,6 +147,8 @@ class Categorical(Distribution):
|
|
|
|
|
self.shape = P.Shape()
|
|
|
|
|
self.softmax = P.Softmax()
|
|
|
|
|
self.squeeze = P.Squeeze()
|
|
|
|
|
self.squeeze_first_axis = P.Squeeze(0)
|
|
|
|
|
self.squeeze_last_axis = P.Squeeze(-1)
|
|
|
|
|
self.square = P.Square()
|
|
|
|
|
self.transpose = P.Transpose()
|
|
|
|
|
|
|
|
|
@ -270,7 +273,7 @@ class Categorical(Distribution):
|
|
|
|
|
# flatten value to shape (number of labels, 1)
|
|
|
|
|
# clip value to be in the range from 0 to num_classes - 1 and cast into int32
|
|
|
|
|
value = self.reshape(value, (-1, 1))
|
|
|
|
|
out_of_bound = self.squeeze(self.logicor(\
|
|
|
|
|
out_of_bound = self.squeeze_last_axis(self.logicor(\
|
|
|
|
|
self.less(value, 0.0), self.less(num_classes-1, value)))
|
|
|
|
|
value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
|
|
|
|
|
value_clipped = self.cast(value_clipped, self.index_type)
|
|
|
|
@ -325,7 +328,7 @@ class Categorical(Distribution):
|
|
|
|
|
|
|
|
|
|
# drop one dimension to match cdf
|
|
|
|
|
# clip value to be in the range from 0 to num_classes - 1 and cast into int32
|
|
|
|
|
less_than_zero = self.squeeze(self.less(value, 0.0))
|
|
|
|
|
less_than_zero = self.squeeze_last_axis(self.less(value, 0.0))
|
|
|
|
|
value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
|
|
|
|
|
value_clipped = self.cast(value_clipped, self.index_type)
|
|
|
|
|
|
|
|
|
@ -375,5 +378,6 @@ class Categorical(Distribution):
|
|
|
|
|
samples = self.squeeze(self.transpose(samples, (1, 0)))
|
|
|
|
|
samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
|
|
|
|
|
if drop_dim:
|
|
|
|
|
return self.squeeze(samples)
|
|
|
|
|
return self.squeeze_first_axis(samples)
|
|
|
|
|
samples = stop_gradient(samples)
|
|
|
|
|
return samples
|
|
|
|
|