|
|
|
@ -500,3 +500,61 @@ class Multinomial(PrimitiveWithInfer):
|
|
|
|
|
"dtype": mstype.int32,
|
|
|
|
|
"value": None}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
class UniformCandidateSampler(PrimitiveWithInfer):
|
|
|
|
|
r"""
|
|
|
|
|
Uniform candidate sampler.
|
|
|
|
|
|
|
|
|
|
This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
|
|
|
|
|
If unique=True, candidates are drawn without replacement, else unique=False with replacement.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
num_true (int): The number of target classes in each training example.
|
|
|
|
|
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
|
|
|
|
|
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
|
|
|
|
|
unique (bool): Whether all sampled classes in a batch are unique.
|
|
|
|
|
range_max (int): The number of possible classes, must be non-negative.
|
|
|
|
|
seed (int): Random seed, must be non-negative. Default: 0.
|
|
|
|
|
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true).
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
- **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
|
|
|
|
|
Shape: (num_sampled, ).
|
|
|
|
|
- **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
|
|
|
|
|
of true_classes. Shape: (batch_size, num_true).
|
|
|
|
|
- **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
|
|
|
|
|
each of sampled_candidates. Shape: (num_sampled, ).
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> sampler = P.UniformCandidateSampler(1, 3, False, 4)
|
|
|
|
|
>>> output1, output2, output3 = sampler(Tensor(np.array([[1],[3],[4],[6],[3]], dtype=np.int32)))
|
|
|
|
|
>>> print(output1, output2, output3)
|
|
|
|
|
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
|
|
|
|
|
"""Initialize UniformCandidateSampler"""
|
|
|
|
|
Validator.check_value_type("num_true", num_true, [int], self.name)
|
|
|
|
|
Validator.check_value_type("num_sampled", num_sampled, [int], self.name)
|
|
|
|
|
Validator.check_value_type("unique", unique, [bool], self.name)
|
|
|
|
|
Validator.check_value_type("range_max", range_max, [int], self.name)
|
|
|
|
|
Validator.check_value_type("seed", seed, [int], self.name)
|
|
|
|
|
Validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
|
|
|
|
|
Validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
|
|
|
|
|
Validator.check("value of range_max", range_max, '', 0, Rel.GT, self.name)
|
|
|
|
|
self.num_true = num_true
|
|
|
|
|
if unique:
|
|
|
|
|
Validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
|
|
|
|
|
Validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
|
|
|
|
|
self.num_sampled = num_sampled
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, true_classes_type):
|
|
|
|
|
return (true_classes_type, mstype.float32, mstype.float32)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, true_classes_shape):
|
|
|
|
|
Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
|
|
|
|
|
return ([self.num_sampled], true_classes_shape, [self.num_sampled])
|
|
|
|
|