!7659 Add stop gradient into Categorical Sampling, fix log_prob calculation

Merge pull request !7659 from XunDeng/pp_issue_branch
pull/7659/MERGE
Authored by mindspore-ci-bot, committed by Gitee 4 years ago
commit 3fcc9683f9

@@ -16,6 +16,7 @@
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
+from mindspore.ops.functional import stop_gradient
 from mindspore._checkparam import Validator
 import mindspore.nn as nn
 from mindspore.common import dtype as mstype
@@ -146,6 +147,8 @@ class Categorical(Distribution):
         self.shape = P.Shape()
         self.softmax = P.Softmax()
         self.squeeze = P.Squeeze()
+        self.squeeze_first_axis = P.Squeeze(0)
+        self.squeeze_last_axis = P.Squeeze(-1)
         self.square = P.Square()
         self.transpose = P.Transpose()
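The two dedicated operators exist because the default P.Squeeze() removes every axis of size 1, so a batch that happens to contain a single element loses its batch axis as well; P.Squeeze(-1) and P.Squeeze(0) drop only the intended axis. A minimal illustration using NumPy, whose np.squeeze has the same semantics (hypothetical shapes, not code from the patch):

import numpy as np

mask = np.zeros((1, 1), dtype=bool)      # shape (num_labels, 1) with one label
print(np.squeeze(mask).shape)            # () -- the label axis is dropped too
print(np.squeeze(mask, axis=-1).shape)   # (1,) -- only the trailing axis is dropped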
@@ -270,7 +273,7 @@
         # flatten value to shape (number of labels, 1)
         # clip value to be in range from 0 to num_classes -1 and cast into int32
         value = self.reshape(value, (-1, 1))
-        out_of_bound = self.squeeze(self.logicor(\
+        out_of_bound = self.squeeze_last_axis(self.logicor(\
             self.less(value, 0.0), self.less(num_classes-1, value)))
         value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
         value_clipped = self.cast(value_clipped, self.index_type)
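The log_prob fix above follows a flag-then-clip pattern: out-of-range labels are recorded before clipping, so the clipped index stays valid for the probability lookup while the flag can mask those entries afterwards (the masking step is outside this hunk). A NumPy sketch with hypothetical values:

import numpy as np

num_classes = 3
value = np.array([[-1.0], [0.0], [2.0], [5.0]])   # flattened to shape (n, 1)
out_of_bound = np.squeeze((value < 0) | (value > num_classes - 1), axis=-1)
value_clipped = np.clip(value, 0, num_classes - 1).astype(np.int32)
print(out_of_bound)   # [ True False False  True]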
@@ -325,7 +328,7 @@
         # drop one dimension to match cdf
         # clip value to be in range from 0 to num_classes -1 and cast into int32
-        less_than_zero = self.squeeze(self.less(value, 0.0))
+        less_than_zero = self.squeeze_last_axis(self.less(value, 0.0))
         value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
         value_clipped = self.cast(value_clipped, self.index_type)
@@ -375,5 +378,6 @@ class Categorical(Distribution):
         samples = self.squeeze(self.transpose(samples, (1, 0)))
         samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
         if drop_dim:
-            return self.squeeze(samples)
+            return self.squeeze_first_axis(samples)
+        samples = stop_gradient(samples)
         return samples
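As for the stop_gradient call: categorical samples are discrete indices, a piecewise-constant function of the probabilities, so no meaningful gradient can flow back through them; detaching the samples keeps spurious gradients away from the distribution parameters. A minimal sketch of the effect (DetachedSample is a hypothetical stand-in for the tail of the sampling pipeline, not code from this patch, and assumes PyNative mode):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient

class DetachedSample(nn.Cell):
    # toy stand-in for Categorical sampling: the result is cut out of
    # the autodiff graph before being returned
    def construct(self, probs):
        samples = probs * 2.0   # placeholder for the sampling computation
        return stop_gradient(samples)

net = DetachedSample()
probs = Tensor(np.array([0.3, 0.7], np.float32))
grads = C.GradOperation()(net)(probs)
print(grads)   # all zeros: nothing flows back through stop_gradient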
