|
|
|
@ -284,7 +284,7 @@ class SampledSoftmaxLoss(_Loss):
|
|
|
|
|
where a sampled class equals one of the target classes. Default is True.
|
|
|
|
|
seed (int): Random seed for candidate sampling. Default: 0
|
|
|
|
|
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
|
|
|
|
|
If "none", do not perform reduction. Default: "None".
|
|
|
|
|
If "none", do not perform reduction. Default: "none".
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **weights** (Tensor) - Tensor of shape (C, dim).
|
|
|
|
@ -311,7 +311,22 @@ class SampledSoftmaxLoss(_Loss):
|
|
|
|
|
def __init__(self, num_sampled, num_classes, num_true=1,
|
|
|
|
|
sampled_values=None, remove_accidental_hits=True, seed=0,
|
|
|
|
|
reduction='none'):
|
|
|
|
|
super(SampledSoftmaxLoss, self).__init__()
|
|
|
|
|
super(SampledSoftmaxLoss, self).__init__(reduction)
|
|
|
|
|
|
|
|
|
|
if num_true < 1:
|
|
|
|
|
raise ValueError(f"num_true {num_true} is less than 1.")
|
|
|
|
|
if seed < 0:
|
|
|
|
|
raise ValueError(f"seed {seed} is less than 0.")
|
|
|
|
|
if num_sampled > num_classes:
|
|
|
|
|
raise ValueError(f"num_sampled {num_sampled} is great than num_classes {num_classes}.")
|
|
|
|
|
if num_true > num_classes:
|
|
|
|
|
raise ValueError(f"num_true {num_true} is great than num_classes {num_classes}.")
|
|
|
|
|
if sampled_values is not None:
|
|
|
|
|
if not isinstance(sampled_values, (list, tuple)):
|
|
|
|
|
raise TypeError(f"sampled_values {sampled_values} is not a list.")
|
|
|
|
|
if len(sampled_values) != 3:
|
|
|
|
|
raise ValueError(f"sampled_values size {len(sampled_values)} is not 3.")
|
|
|
|
|
|
|
|
|
|
self.num_sampled = num_sampled
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.num_true = num_true
|
|
|
|
|