diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h
index a1fc9ea4aa..fe45b0ebc4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h
@@ -30,7 +30,8 @@ namespace kernel {
 template <typename T, typename S>
 class UniformSamplerGpuKernel : public GpuKernel {
  public:
-  UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {}
+  UniformSamplerGpuKernel()
+      : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {}
   ~UniformSamplerGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -43,6 +44,16 @@ class UniformSamplerGpuKernel : public GpuKernel {
     T *sampled_candidates = GetDeviceAddress<T>(outputs, 0);
     S *true_expected_count = GetDeviceAddress<S>(outputs, 1);
     S *sampled_expected_count = GetDeviceAddress<S>(outputs, 2);
+    if (remove_accidental_hits_) {
+      T *input = GetDeviceAddress<T>(inputs, 0);
+      array_input_ = std::vector<T>(input_size_, 0);
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&array_input_[0], input, input_size_ * sizeof(T),
+                                                 cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync true_classes failed");
+      for (const auto item : array_input_) {
+        set_input_.insert(item);
+      }
+    }
     int counter = Sampling();
     float prob = Probability();
     size_t sampled_candidates_size = num_sampled_ * sizeof(T);
@@ -72,6 +83,7 @@ class UniformSamplerGpuKernel : public GpuKernel {
     unique_ = GetAttr<bool>(kernel_node, "unique");
     range_max_ = GetAttr<int>(kernel_node, "range_max");
     int seed = GetAttr<int>(kernel_node, "seed");
+    remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits");
     if (seed == 0) seed = time(NULL);
     generator_.seed(seed);
     auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
@@ -80,6 +92,9 @@ class UniformSamplerGpuKernel : public GpuKernel {
       return false;
     }
     input_size_ = input_shape[0] * input_shape[1];
+    if (num_sampled_ * num_true_ + static_cast<int>(input_size_) > range_max_ * num_true_) {
+      remove_accidental_hits_ = false;
+    }
     InitSizeLists();
     return true;
   }
@@ -105,7 +120,8 @@ class UniformSamplerGpuKernel : public GpuKernel {
     while (picked < num_sampled_) {
       tmp = distribution(generator_);
       counter++;
-      if (set_container.find(tmp) == set_container.end()) {
+      if ((set_container.find(tmp) == set_container.end()) &&
+          ((!remove_accidental_hits_) || set_input_.find(tmp) == set_input_.end())) {
         set_container.insert(tmp);
         sampled_candidates_.push_back(tmp);
         picked++;
@@ -133,6 +149,9 @@ class UniformSamplerGpuKernel : public GpuKernel {
   bool unique_;
   int range_max_;
   size_t input_size_;
+  bool remove_accidental_hits_;
+  std::vector<T> array_input_;
+  std::set<T> set_input_;
   std::default_random_engine generator_;
   std::vector<T> sampled_candidates_;
   std::vector<size_t> input_size_list_;
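The kernel's `Sampling()` is host-side rejection sampling: every draw increments `counter`, duplicates are rejected when `unique_` is set, and with `remove_accidental_hits_` any candidate found in `set_input_` (the true classes copied back from the device above) is rejected as well. The guard added to `Init` falls back to plain sampling when excluding all true classes could leave fewer than `num_sampled_` admissible values in `[0, range_max_)`, which would otherwise make the loop spin forever. A minimal Python sketch of that loop, for reference only — the RNG and the expected-count formula of `Probability()` are assumptions here, not taken from the kernel source:

```python
import random

def sample_uniform(num_sampled, range_max, true_classes, unique=True,
                   remove_accidental_hits=False, seed=1):
    """Hypothetical mirror of UniformSamplerGpuKernel::Sampling()."""
    rng = random.Random(seed)
    hits = set(true_classes)            # plays the role of set_input_
    picked, seen, counter = [], set(), 0
    while len(picked) < num_sampled:
        tmp = rng.randrange(range_max)  # uniform draw over [0, range_max)
        counter += 1                    # every draw counts, even rejected ones
        if unique and tmp in seen:
            continue                    # reject duplicates
        if remove_accidental_hits and tmp in hits:
            continue                    # reject accidental hits on true classes
        seen.add(tmp)
        picked.append(tmp)
    return picked, counter              # counter feeds the expected-count estimate
```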
diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py
index 1327943ff2..873aa0d7f6 100644
--- a/mindspore/nn/loss/__init__.py
+++ b/mindspore/nn/loss/__init__.py
@@ -20,8 +20,9 @@ It shows how well the model works on a dataset and the optimization target which
 """
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
-    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss
+    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
+    SampledSoftmaxLoss
 
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss',
-           'CosineEmbeddingLoss']
+           'CosineEmbeddingLoss', 'SampledSoftmaxLoss']
diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index bc6c995f86..4f02e4ab4b 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -263,6 +263,186 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         return self.get_loss(x)
 
 
+class SampledSoftmaxLoss(_Loss):
+    r"""
+    Computes the sampled softmax training loss.
+
+    Args:
+        num_sampled (int): The number of classes to randomly sample per batch.
+        num_classes (int): The number of possible classes.
+        num_true (int): The number of target classes per training example. Default: 1.
+        sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`,
+            `sampled_expected_count`) returned by a candidate sampler. Default: None,
+            in which case `UniformSampler` is applied.
+        remove_accidental_hits (bool): Whether to remove "accidental hits"
+            where a sampled class equals one of the target classes. Default: True.
+        seed (int): Random seed for candidate sampling. Default: 0.
+        reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
+            If "none", do not perform reduction. Default: "none".
+
+    Inputs:
+        - **weights** (Tensor) - Tensor of shape (C, dim). The class weights.
+        - **biases** (Tensor) - Tensor of shape (C). The class biases.
+        - **labels** (Tensor) - Tensor of shape (N, num_true), type `int64`. The
+          target classes.
+        - **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of
+          the input network.
+
+    Outputs:
+        Tensor, a tensor of shape (N) with the per-example sampled softmax losses.
+
+    """
+
+    def __init__(self, num_sampled, num_classes, num_true=1,
+                 sampled_values=None, remove_accidental_hits=True, seed=0,
+                 reduction='none'):
+        super(SampledSoftmaxLoss, self).__init__(reduction)
+        self.num_sampled = num_sampled
+        self.num_classes = num_classes
+        self.num_true = num_true
+        self.sampled_values = sampled_values
+        self.remove_accidental_hits = remove_accidental_hits
+        self.seed = seed
+        self.sampler = P.UniformSampler(
+            num_true,
+            num_sampled,
+            True,  # unique
+            num_classes,
+            seed,
+            remove_accidental_hits)
+        self.cast = P.Cast()
+        self.reshape = P.Reshape()
+        self.shape = P.Shape()
+        self.exp = P.Exp()
+        self.log = P.Log()
+        self.slice_op = P.Slice()
+        self.matmul = P.MatMul(False, True)
+        self.gather_v2 = P.GatherV2()
+        self.reduce_max_true = P.ReduceMax(True)
+        self.reduce_sum = P.ReduceSum()
+        self.reduce_sum_true = P.ReduceSum(True)
+        self.concat_dim0 = P.Concat(0)
+        self.concat_dim1 = P.Concat(1)
+        self.ones_like = P.OnesLike()
+        self.zeros_like = P.ZerosLike()
+        self.mul = P.Mul()
+        self.expand_dims = P.ExpandDims()
+
+    def construct(self, weights, biases, labels, inputs):
+        logits, labels = self._compute_sampled_logits(
+            weights=weights,
+            biases=biases,
+            labels=labels,
+            inputs=inputs,
+            num_true=self.num_true,
+            sampled_values=self.sampled_values,
+            subtract_log_q=True)
+
+        x = self._softmax_cross_entropy(logits, labels)
+        return self.get_loss(x)
+
+    def _softmax_cross_entropy(self, logits, targets):
+        stable_exp_logits = self.exp(logits - self.reduce_max_true(logits, 1))
+        pred = stable_exp_logits / self.reduce_sum_true(stable_exp_logits, 1)
+        return -self.reduce_sum(targets * self.log(pred + 1.0e-20), 1)
+
+    def _compute_sampled_logits(self, weights,
+                                biases,
+                                labels,
+                                inputs,
+                                num_true=1,
+                                sampled_values=None,
+                                subtract_log_q=True):
+        """Helper function for SampledSoftmaxLoss.
+
+        Computes sampled output training logits and labels suitable for
+        sampled softmax training.
+
+        Note: In the case where num_true > 1, we assign to each target class
+        the target probability 1 / num_true so that the target probabilities
+        sum to 1 per-example.
+
+        Args:
+            weights (Tensor): Tensor of shape `[num_classes, dim]`.
+            biases (Tensor): Tensor of shape `[num_classes]`.
+            labels (Tensor): Tensor of shape `[batch_size, num_true]`. The target classes.
+            inputs (Tensor): Tensor of shape `[batch_size, dim]`. The forward
+                activations of the input network.
+            num_true (int): The number of target classes per training example.
+            sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
+                `sampled_expected_count`) returned by a `UniformSampler` function.
+            subtract_log_q: A `bool`. Whether to subtract the log expected count of
+                the labels in the sample to get the logits of the true labels.
+                Default: True.
+
+        Returns:
+            out_logits: `Tensor` object with shape
+                `[batch_size, num_true + num_sampled]`.
+            out_labels: A Tensor object with the same shape as `out_logits`.
+        """
+        if labels.dtype != mstype.int32:
+            labels = self.cast(labels, mstype.int32)
+        labels = self.reshape(labels, (-1, num_true))
+        labels_flat = self.reshape(labels, (-1,))
+
+        # Sample the negative labels.
+        #   sampled shape: [num_sampled] tensor
+        #   true_expected_count shape = [batch_size, 1] tensor
+        #   sampled_expected_count shape = [num_sampled] tensor
+        if sampled_values is None:
+            sampled_values = self.sampler(labels)
+
+        (sampled, true_expected_count, sampled_expected_count) = sampled_values
+
+        if sampled.dtype != mstype.int32:
+            sampled = self.cast(sampled, mstype.int32)
+        all_ids = self.concat_dim0((labels_flat, sampled))
+        all_w = self.gather_v2(weights, all_ids, 0)
+
+        n_true = self.shape(labels_flat)[0]
+        n_sampled = self.shape(sampled)[0]
+        n_dim = self.shape(all_w)[1]
+
+        # true_w shape is [batch_size * num_true, dim]
+        true_w = self.slice_op(all_w, [0, 0], [n_true, n_dim])
+        sampled_w = self.slice_op(all_w, [n_true, 0], [n_sampled, n_dim])
+        sampled_logits = self.matmul(inputs, sampled_w)
+
+        all_b = self.gather_v2(biases, all_ids, 0)
+        true_b = self.slice_op(all_b, [0], [n_true])
+        sampled_b = self.slice_op(all_b, [n_true], [n_sampled])
+
+        # inputs shape is [batch_size, dim]
+        # true_w shape is [batch_size * num_true, dim]
+        # row_wise_dots is [batch_size, num_true, dim]
+        new_true_w_shape = (-1, num_true, n_dim)
+        row_wise_dots = self.mul(self.expand_dims(inputs, 1),
+                                 self.reshape(true_w, new_true_w_shape))
+
+        # We want the row-wise dot plus biases which yields a
+        # [batch_size, num_true] tensor of true_logits.
+        dots_as_matrix = self.reshape(row_wise_dots, (-1, n_dim))
+        true_logits = self.reshape(self.reduce_sum(dots_as_matrix, 1), (-1, num_true))
+        true_b = self.reshape(true_b, (-1, num_true))
+        true_logits += true_b
+        sampled_logits += sampled_b
+
+        if subtract_log_q:
+            # Subtract log of Q(l), prior probability that l appears in sampled.
+            true_logits -= self.log(true_expected_count)
+            sampled_logits -= self.log(sampled_expected_count)
+
+        # Construct output logits and labels. The true labels/logits start at col 0.
+        out_logits = self.concat_dim1((true_logits, sampled_logits))
+
+        # true_logits is a float tensor, ones_like(true_logits) is a float
+        # tensor of ones. We then divide by num_true to ensure the per-example
+        # labels sum to 1.0, i.e. form a proper probability distribution.
+        out_labels = self.concat_dim1((
+            self.ones_like(true_logits) / num_true,
+            self.zeros_like(sampled_logits)
+        ))
+        return out_logits, out_labels
+
+
 class BCELoss(_Loss):
     r"""
     BCELoss creates a criterion to measure the Binary Cross Entropy between the true labels and predicted labels.
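`_compute_sampled_logits` gathers the weight rows and biases of the true and sampled classes, builds `[batch_size, num_true + num_sampled]` logits, and (with `subtract_log_q=True`) subtracts `log Q(l)` so the logits are corrected for the sampling distribution. A usage sketch of the new cell, with illustrative random data shaped as the docstring requires (`weights (C, dim)`, `biases (C,)`, `labels (N,)`, `inputs (N, dim)`):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

num_classes, dim, batch_size = 7, 10, 3
loss = nn.SampledSoftmaxLoss(num_sampled=4, num_classes=num_classes,
                             num_true=1, seed=1)

weights = Tensor(np.random.randn(num_classes, dim).astype(np.float32))
biases = Tensor(np.random.randn(num_classes).astype(np.float32))
labels = Tensor(np.array([0, 1, 2], dtype=np.int32))  # one true class per example
inputs = Tensor(np.random.randn(batch_size, dim).astype(np.float32))

per_example_loss = loss(weights, biases, labels, inputs)  # shape (N,)
```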
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 3c0d460afa..4bf5663c0f 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -5831,6 +5831,7 @@ class UniformSampler(PrimitiveWithInfer):
         unique (bool): Whether all sampled classes in a batch are unique.
         range_max (int): The number of possible classes.
         seed (int): Random seed, must be non-negative. Default: 0.
+        remove_accidental_hits (bool): Whether to remove "accidental hits", i.e. sampled classes that equal one of the target classes. Default: False.
 
     Inputs:
         true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true).
@@ -5850,13 +5851,14 @@ class UniformSampler(PrimitiveWithInfer):
         [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
     """
 
     @prim_attr_register
-    def __init__(self, num_true, num_sampled, unique, range_max, seed=0):
+    def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
         """Initialize UniformSampler"""
         validator.check_value_type("num_true", num_true, [int], self.name)
         validator.check_value_type("num_sampled", num_sampled, [int], self.name)
         validator.check_value_type("unique", unique, [bool], self.name)
         validator.check_value_type("range_max", range_max, [int], self.name)
         validator.check_value_type("seed", seed, [int], self.name)
+        validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
         validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
         if unique:
             validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
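The extended primitive can also be called directly. This sketch mirrors the new `test_uniform_sampler_op.py` cases further below: with `unique=True`, `range_max=4` and true class `1` excluded, only values from `{0, 2, 3}` can appear in `sampled`:

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

class SamplerNet(nn.Cell):
    def __init__(self):
        super(SamplerNet, self).__init__()
        self.sampler = P.UniformSampler(num_true=1, num_sampled=3, unique=True,
                                        range_max=4, seed=1,
                                        remove_accidental_hits=True)

    def construct(self, true_classes):
        return self.sampler(true_classes)

# true class 1 can never be returned in sampled when removal is enabled
sampled, true_expected, sampled_expected = SamplerNet()(Tensor(np.array([[1]], dtype=np.int32)))
```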
diff --git a/tests/st/ops/gpu/test_sampled_softmax_loss_op.py b/tests/st/ops/gpu/test_sampled_softmax_loss_op.py
new file mode 100644
index 0000000000..7c0f709a97
--- /dev/null
+++ b/tests/st/ops/gpu/test_sampled_softmax_loss_op.py
@@ -0,0 +1,137 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+
+def generate_test_data(num_classes, batch_size, sampled):
+    dim = 10
+    weights_s = np.linspace(start=1, stop=num_classes * dim, num=num_classes * dim)
+    weights_s = np.reshape(weights_s, (num_classes, dim)).astype(np.float32) / 100.0
+    biases_s = np.linspace(start=1, stop=num_classes, num=num_classes)
+    biases_s = np.reshape(biases_s, (num_classes,)).astype(np.float32) / 100.0
+    hidden_acts_s = np.linspace(start=1, stop=batch_size * dim, num=batch_size * dim)
+    hidden_acts_s = np.reshape(
+        hidden_acts_s, (batch_size, dim)).astype(np.float32) / 100.0
+
+    true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32)
+    sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32)
+    sampled_values = (Tensor(sampled), Tensor(true_exp), Tensor(sampled_exp))
+    return weights_s, biases_s, hidden_acts_s, sampled_values
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sampled_softmax_loss_assigned_sampler():
+    np.random.seed(0)
+    num_classes = 7
+    batch_size = 3
+    labels = [0, 1, 2]
+    (weights, biases, hidden_acts, sampled_vals) = generate_test_data(
+        num_classes=num_classes,
+        batch_size=batch_size,
+        sampled=[4, 0, 2, 3])
+
+    def case_not_remove_accidental_hits():
+        loss = nn.SampledSoftmaxLoss(
+            num_sampled=4,
+            num_classes=num_classes,
+            num_true=1,
+            sampled_values=sampled_vals,
+            remove_accidental_hits=False)
+
+        got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases),
+                                        Tensor(labels), Tensor(hidden_acts))
+        exp_sampled_softmax_loss = np.array(
+            [1.7318448, 1.8015041, 1.7211525]).astype(np.float32)
+        assert np.allclose(got_sampled_softmax_loss.asnumpy(),
+                           exp_sampled_softmax_loss)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    case_not_remove_accidental_hits()
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    case_not_remove_accidental_hits()
+
+    (weights, biases, hidden_acts, sampled_vals) = generate_test_data(
+        num_classes=num_classes,
+        batch_size=batch_size,
+        sampled=[4, 5, 6, 3])
+
+    def case_remove_accidental_hits():
+        loss = nn.SampledSoftmaxLoss(
+            num_sampled=4,
+            num_classes=num_classes,
+            num_true=1,
+            sampled_values=sampled_vals,
+            remove_accidental_hits=True)
+
+        got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases),
+                                        Tensor(labels), Tensor(hidden_acts))
+        exp_sampled_softmax_loss = np.array(
+            [1.85211, 2.10999, 2.20862]).astype(np.float32)
+        assert np.allclose(got_sampled_softmax_loss.asnumpy(),
+                           exp_sampled_softmax_loss)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    case_remove_accidental_hits()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    case_remove_accidental_hits()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sampled_softmax_loss_none_sampler():
+    np.random.seed(0)
+    num_classes = 7
+    batch_size = 3
+    labels = [0, 1, 2]
+    (weights, biases, hidden_acts, _) = generate_test_data(
+        num_classes=num_classes,
+        batch_size=batch_size,
+        sampled=[4, 0, 2, 3])
+
+    def case_no_sampler():
+        loss = nn.SampledSoftmaxLoss(
+            num_sampled=4,
+            num_classes=num_classes,
+            num_true=1,
+            sampled_values=None,
+            seed=1,
+            remove_accidental_hits=False)
+
+        got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases),
+                                        Tensor(labels), Tensor(hidden_acts))
+        exp_sampled_softmax_loss = np.array(
+            [1.7345718, 1.820291, 1.7704818]).astype(np.float32)
+        assert np.allclose(got_sampled_softmax_loss.asnumpy(),
+                           exp_sampled_softmax_loss)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    case_no_sampler()
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    case_no_sampler()
+
+
+if __name__ == "__main__":
+    test_sampled_softmax_loss_assigned_sampler()
+    test_sampled_softmax_loss_none_sampler()
diff --git a/tests/st/ops/gpu/test_uniform_sampler_op.py b/tests/st/ops/gpu/test_uniform_sampler_op.py
index f0625d8742..a3650a27a8 100644
--- a/tests/st/ops/gpu/test_uniform_sampler_op.py
+++ b/tests/st/ops/gpu/test_uniform_sampler_op.py
@@ -35,6 +35,25 @@ def uniform_sampler(x, num_true, num_sampled, unique, range_max):
     out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
     return out1.shape, out2.shape, out3.shape
 
+
+class UniformSamplerHitNet(nn.Cell):
+    def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits):
+        super(UniformSamplerHitNet, self).__init__()
+        self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max, seed=seed,
+                                        remove_accidental_hits=remove_accidental_hits)
+
+    def construct(self, x):
+        return self.sampler(x)
+
+
+def uniform_sampler_hit(x, num_true, num_sampled, unique, range_max, seed,
+                        remove_accidental_hits):
+    uniform_sampler_net = UniformSamplerHitNet(num_true, num_sampled, unique, range_max,
+                                               seed, remove_accidental_hits)
+    out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
+    return out1, out2, out3
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -114,3 +133,23 @@ def test_uniform_sampler_large_random():
     np.testing.assert_array_equal(ms1, expected_1)
     np.testing.assert_array_equal(ms2, expected_2)
     np.testing.assert_array_equal(ms3, expected_3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_uniform_sampler_unique_1_true_hit():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False)
+    expected_1 = np.array([0, 3, 1])
+    np.testing.assert_array_equal(ms1.asnumpy(), expected_1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_uniform_sampler_unique_1_true_no_hit():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True)
+    expected_1 = np.array([0, 3, 2])
+    np.testing.assert_array_equal(ms1.asnumpy(), expected_1)
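A note on the last two expectations: with `seed=1` the sampler first accepts `0` and `3`; without removal the next accepted draw is the accidental hit `1` (the true class), while with `remove_accidental_hits=True` that draw is rejected and the next admissible value `2` takes its slot, turning `[0, 3, 1]` into `[0, 3, 2]`. The exact vectors depend on the host `std::default_random_engine`, so a seed-independent variant of the same check would be:

```python
import numpy as np

def assert_no_accidental_hits(sampled, true_classes):
    """Seed-independent property: no sampled candidate equals a target class."""
    assert not set(sampled.asnumpy().tolist()) & set(np.ravel(true_classes).tolist())

# e.g. ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True)
#      assert_no_accidental_hits(ms1, np.array([[1]]))
```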