Merge pull request #15609 from xuezhong/add_sample_logits_op
add sample_logits and sampled_softmax_with_cross_entropy op
revert-15774-anakin_subgraph_engine
commit
1dad36f6aa
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template class SampleWithProb<platform::CPUDeviceContext, float>;
|
||||
template class SampleWithProb<platform::CPUDeviceContext, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,161 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <thrust/random.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
#include "paddle/fluid/operators/math/sampler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
// Device-side probability adjustment.
//
// prob        : probability of a value for a single draw.
// num_samples : number of samples that were kept.
// num_tries   : number of draws that were made to obtain them.
//
// When both counts are equal the result is prob * num_samples. Otherwise the
// result is -expm1(num_tries * log1p(-prob)), i.e. 1 - (1 - prob)^num_tries
// evaluated with expm1/log1p.
template <typename T>
__device__ T gpu_adjust_prob(const T prob, const int num_samples,
                             const int num_tries) {
  if (num_samples != num_tries) {
    return -expm1(num_tries * log1p(-prob));
  }
  return prob * num_samples;
}
|
||||
|
||||
// Device-side helper for the log-uniform distribution. The class holds no
// data members; every quantity it needs is passed in by the caller. Both
// member functions are defined out of line in this file.
class GPULogUniformSampler {
 public:
  // Maps a uniform random number `random` to an integer sample, reduced
  // modulo `range`. `log_range` is supplied by the caller.
  __device__ int64_t Sample(float random, const int range,
                            const float log_range) const;

  // Returns log((value + 2) / (value + 1)) / log_range.
  __device__ float Probability(int64_t value, const float log_range) const;
};
|
||||
|
||||
// Draws a log-uniform sample from a uniform random number by inverse
// transform sampling: exp(random * log_range) truncated to an integer, minus
// one.
__device__ int64_t GPULogUniformSampler::Sample(float random, const int range,
                                                const float log_range) const {
  const auto scaled = exp(random * log_range);
  const int64_t candidate = static_cast<int64_t>(scaled) - 1;
  // Floating-point round-off may push the candidate outside the intended
  // range, so it is folded back with a modulo.
  return candidate % range;
}
|
||||
|
||||
// Probability of `value` under the density f(x) = 1 / ((x + 1) * log_range):
// the integral of f from value to value + 1, which evaluates to
// log((value + 2) / (value + 1)) / log_range.
__device__ float GPULogUniformSampler::Probability(
    int64_t value, const float log_range) const {
  const auto ratio = (value + 2.0) / (value + 1.0);
  return log(ratio) / log_range;
}
|
||||
|
||||
// Fills the sample and probability matrices, one element per loop iteration.
//
// Both output buffers are indexed as a flat array of `n` elements with
// num_true + num_samples elements per row.
//   * Columns [0, num_true) of every row receive the corresponding entry of
//     `label_data` (num_true entries per row).
//   * The remaining columns of every row receive the value stored in the same
//     column of row 0, so row 0's sampled columns must already be populated in
//     `samples_data` before the kernel runs.
//   * probabilities_data[idx] is the log-uniform probability of
//     samples_data[idx], adjusted by gpu_adjust_prob(num_samples, num_tries).
//
// The loop strides by blockDim.x * gridDim.x, so any 1-D launch configuration
// covers all `n` elements. `range` is accepted for interface compatibility and
// is not read.
template <typename T>
__global__ void SamplingCondidate(
    const size_t n, const int num_tries, const int range, const float log_range,
    const int num_true, const std::size_t num_samples,
    const int64_t* label_data, int64_t* samples_data, T* probabilities_data) {
  const size_t num_sampled_classes = num_true + num_samples;
  // Index arithmetic is done in size_t so that it matches the type of `n` and
  // cannot wrap for large element counts.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  GPULogUniformSampler sampler;

  for (size_t idx = first; idx < n; idx += stride) {
    const size_t col_idx = idx % num_sampled_classes;
    const size_t row_idx = idx / num_sampled_classes;
    if (col_idx < static_cast<size_t>(num_true)) {
      samples_data[idx] = label_data[row_idx * num_true + col_idx];
    } else if (row_idx != 0) {
      // Row 0 already holds the shared samples; skipping its self-assignment
      // keeps those slots read-only while other threads read them.
      samples_data[idx] = samples_data[col_idx];
    }
    const T prob = sampler.Probability(samples_data[idx], log_range);
    probabilities_data[idx] = gpu_adjust_prob(prob, num_samples, num_tries);
  }
}
|
||||
|
||||
// Draws values from `sampler` until `num_samples` distinct values have been
// collected, writing them to samples_data[0 .. num_samples) in the order they
// were first seen. These are not necessarily all negative samples.
//
// Returns the total number of draws, including rejected duplicates.
// The loop only ends once enough distinct values were seen, so it does not
// terminate if the sampler cannot produce `num_samples` distinct values.
template <typename T>
int UniqSampler(const Sampler& sampler, const std::size_t num_samples,
                int64_t* samples_data) {
  std::unordered_set<int64_t> seen;
  int draws = 0;
  int filled = 0;
  while (filled < num_samples) {
    ++draws;
    const auto candidate = sampler.Sample();
    if (seen.insert(candidate).second) {
      samples_data[filled] = candidate;
      ++filled;
    }
  }
  return draws;
}
|
||||
|
||||
// Generates samples and their adjusted probabilities on the GPU.
//
// context     : CUDA device context; its stream is used for the kernel launch.
// seed        : seed forwarded to the host-side LogUniformSampler.
// dict_size   : range of the sampler.
// uniq        : accepted for interface compatibility; not read here (samples
//               are always drawn without repetition).
// num_samples : number of shared samples to draw.
// L           : int64 label tensor; dims()[0] is the batch size and dims()[1]
//               the number of true labels per row.
// S, P        : output tensors for samples and probabilities. They are
//               accessed through data<>(), so this function does not allocate
//               them. NOTE(review): presumably the caller allocates both as
//               [batch_size, num_true + num_samples] on the GPU — confirm.
//
// The unique samples are drawn once on the host, copied into the sampled
// columns of row 0 of S, and then broadcast to all rows by the kernel.
template <typename T>
void GPUSampleWithProb<T>::operator()(
    const platform::CUDADeviceContext& context, const int seed,
    const int dict_size, const bool uniq, const std::size_t num_samples,
    const Tensor* L, Tensor* S, Tensor* P) {
  const auto lbl_dim = L->dims();
  const int batch_size = lbl_dim[0];
  const int num_true = lbl_dim[1];
  const int num_sampled_classes = num_true + num_samples;

  // Element count in 64-bit arithmetic to avoid overflow for large shapes.
  const size_t size = static_cast<size_t>(batch_size) * num_sampled_classes;
  if (size == 0) {
    // Nothing to fill. A grid of zero blocks is an invalid launch, and the
    // host-to-device copy below would target an empty buffer.
    return;
  }

  // The sampler reduces every draw modulo dict_size, so it can never yield
  // more than dict_size distinct values; asking for more would loop forever.
  PADDLE_ENFORCE(
      static_cast<int64_t>(num_samples) <= static_cast<int64_t>(dict_size),
      "num_samples should not be greater than dict_size when sampling "
      "unique values.");

  const int64_t* label_data = L->data<int64_t>();
  int64_t* samples_data = S->data<int64_t>();
  T* probabilities_data = P->data<T>();

  // Host staging buffer for the unique samples.
  int s_size = num_samples;
  framework::DDim s_dim{s_size};
  Tensor s;
  int64_t* s_data = s.mutable_data<int64_t>(s_dim, platform::CPUPlace());

  math::LogUniformSampler sampler(dict_size, seed);

  int range = dict_size;
  float log_range = log(range + 1);

  int num_tries = UniqSampler<T>(sampler, num_samples, s_data);
  VLOG(1) << "num_tries: " << num_tries;
  // Place the shared samples right after the true labels of row 0.
  PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data,
                            sizeof(int64_t) * num_samples,
                            cudaMemcpyHostToDevice));

  int threads = 512;
  int grid = static_cast<int>((size + threads - 1) / threads);
  SamplingCondidate<T><<<grid, threads, 0, context.stream()>>>(
      size, num_tries, range, log_range, num_true, num_samples, label_data,
      samples_data, probabilities_data);
  // Kernel launches do not return a status; surface launch errors here.
  PADDLE_ENFORCE(cudaGetLastError());
}
|
||||
|
||||
// Explicit instantiations of the GPU sampling functor for float and double.
template class GPUSampleWithProb<double>;
template class GPUSampleWithProb<float>;
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,118 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/operators/math/sampler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
/* Adjusts a single-draw probability `prob` for unique sampling.
 *
 * prob        : probability of a value for one draw.
 * num_samples : number of samples that were kept.
 * num_tries   : number of draws that were made to obtain them.
 *
 * When num_samples == num_tries the result is prob * num_samples. Otherwise
 * the result is -expm1(num_tries * log1p(-prob)), i.e.
 * 1 - (1 - prob)^num_tries evaluated with expm1/log1p. */
template <typename T>
static T adjust_prob(const T prob, const int num_samples, const int num_tries) {
  if (num_samples != num_tries) {
    return -expm1(num_tries * log1p(-prob));
  }
  return prob * num_samples;
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SampleWithProb {
|
||||
public:
|
||||
void operator()(const DeviceContext& context, const Sampler& sampler,
|
||||
const std::size_t num_samples, const Tensor* L, Tensor* S,
|
||||
Tensor* P) {
|
||||
// UNDERSTAND: dimension issues
|
||||
const auto lbl_dim = L->dims();
|
||||
const int batch_size = lbl_dim[0];
|
||||
const int num_true = lbl_dim[1];
|
||||
const int num_sampled_classes = num_true + num_samples;
|
||||
framework::DDim ret_dim{batch_size, num_sampled_classes};
|
||||
|
||||
// UNDERSTAND: raw data view
|
||||
const int64_t* label_data = L->data<int64_t>();
|
||||
int64_t* samples_data =
|
||||
S->mutable_data<int64_t>(ret_dim, context.GetPlace());
|
||||
T* probabilities_data = P->mutable_data<T>(ret_dim, context.GetPlace());
|
||||
|
||||
// temp sets for unique sampling
|
||||
std::unordered_set<int64_t> tmp_samples;
|
||||
int j = 0; // column index
|
||||
// add true labels, not that efficient
|
||||
while (j < num_true) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
auto samples_index = i * num_sampled_classes + j;
|
||||
auto v = label_data[i * num_true + j];
|
||||
samples_data[samples_index] = v;
|
||||
probabilities_data[samples_index] = sampler.Probability(v);
|
||||
}
|
||||
++j;
|
||||
}
|
||||
|
||||
// sample num_samles unique samples for an example, note that they are not
|
||||
// all negative samples
|
||||
tmp_samples.clear();
|
||||
int num_tries = 0;
|
||||
while (j < num_sampled_classes) {
|
||||
++num_tries;
|
||||
auto v = sampler.Sample();
|
||||
auto insert_ok = tmp_samples.insert(v).second;
|
||||
if (!insert_ok) {
|
||||
continue;
|
||||
}
|
||||
auto p = sampler.Probability(v);
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
auto samples_index = i * num_sampled_classes + j;
|
||||
samples_data[samples_index] = v;
|
||||
probabilities_data[samples_index] = p;
|
||||
}
|
||||
++j;
|
||||
}
|
||||
|
||||
// compute Q(y|x), because of unique sampling, probabilities need to be
|
||||
// adjusted
|
||||
for (int k = 0; k < num_sampled_classes; ++k) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
auto samples_index = i * num_sampled_classes + k;
|
||||
probabilities_data[samples_index] = adjust_prob(
|
||||
probabilities_data[samples_index], num_samples, num_tries);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
// GPU counterpart of SampleWithProb, only declared when Paddle is built with
// CUDA. The call operator is declared here and defined elsewhere.
//
// Unlike SampleWithProb it takes the sampler configuration (seed, dict_size,
// uniq) instead of a Sampler object.
template <typename T>
class GPUSampleWithProb {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const int seed,
                  const int dict_size,
                  const bool uniq,
                  const std::size_t num_samples,
                  const Tensor* L,
                  Tensor* S,
                  Tensor* P);
};
#endif  // PADDLE_WITH_CUDA
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,225 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/sample_logits_op.h"
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Declares the inputs, outputs, attributes and documentation of the
// sample_logits operator.
//
// The description strings are split over adjacent literals; every literal
// that is followed by another one ends with a space so the concatenated text
// keeps its word boundaries.
class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.");
    AddInput("Labels",
             "(Tensor) The ground truth which is a 2-D tensor. Labels is a "
             "Tensor<int64> with shape [N x NT], where NT is the number of "
             "true labels for each example.");
    AddInput("CustomizedSamples",
             "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
             "NT + S], "
             "where N is the batch size, NT is the number of true labels "
             "and S is the number of negative sample for each example. "
             "The first NT elements of each row should be the same with true "
             "labels, "
             "followed by S custom negative samples. This tensor "
             "is only used when use_customized_samples is true.")
        .AsDispensable();
    AddInput(
        "CustomizedProbabilities",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]. "
        "The tensor has the same shape with CustomizedSamples, "
        "and each element represents probability of element in "
        "CustomizedSamples. "
        "This "
        "tensor is only used when use_customized_samples is true.")
        .AsDispensable();
    AddOutput("Samples",
              "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
              "NT + S]. "
              "The outputs value of sampler, including NT true labels and S "
              "negative samples "
              "for each example. This will be used in "
              "backward calculation.")
        .AsIntermediate();
    AddOutput(
        "Probabilities",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]. "
        "The probabilities of sampled positive and negative labels.")
        .AsIntermediate();
    AddOutput("SampledLogits",
              "(Tensor, default: Tensor<float>), A 2-D tensor with shape "
              "[N, NT + S]. The outputs value of sampled logits, which will be "
              "used in backward propagation.")
        .AsIntermediate();
    AddOutput(
        "SampledLabels",
        "(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled labels "
        "with shape [N, NT]. The tensor contains hard labels as input to "
        "softmax op, that is 0, 1, ..., NT-1 because of the first NT elements "
        "of Samples are positive labels.");
    AddAttr<bool>(
        "use_customized_samples",
        "An indicator whether to use customized samples with probabilities, if "
        "True, "
        "the operator will use customized samples and customized "
        "probabilities; "
        "otherwise, the operator will generate them by itself.")
        .SetDefault(false);
    AddAttr<bool>(
        "uniq",
        "An indicator whether to sample non-repetitive negative labels, if "
        "True "
        "the operator will sample negative labels without replacement. "
        "Otherwise, the operator will sample negative labels with "
        "replacement.")
        .SetDefault(true);
    AddAttr<bool>(
        "remove_accidental_hits",
        "An indicator whether to remove accidental hits when samples hits true "
        "labels, the removal is implemented by subtracting the corresponding "
        "logits by float_max to suppress their softmax to be zero.")
        .SetDefault(true);
    AddAttr<int>("num_samples", "The number of negative samples.");
    AddAttr<int>("seed", "Random seed for generating samples").SetDefault(0);

    AddComment(R"DOC(
Computes sampled output training logits and labels suitable for implementing
sampled softmax.

)DOC");
  }
};
|
||||
|
||||
class SampleLogitsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Logits"),
|
||||
"Input(Logits) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
||||
"Input(Labels) should be not null.");
|
||||
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Samples"),
|
||||
"Output(Samples) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Probabilities"),
|
||||
"Output(Probabilities) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"),
|
||||
"Output(SampledLogits) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"),
|
||||
"Output(SampledLabels) should be not null.");
|
||||
|
||||
auto logits_dims = ctx->GetInputDim("Logits");
|
||||
auto labels_dims = ctx->GetInputDim("Labels");
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
logits_dims.size(), 2UL,
|
||||
"The logits of softmax_with_cross_entropy should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
|
||||
"The labels should be a 2-D tensor.");
|
||||
|
||||
const int num_samples = ctx->Attrs().Get<int>("num_samples");
|
||||
const int num_sampled_classes = labels_dims[1] + num_samples;
|
||||
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
|
||||
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
|
||||
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
|
||||
ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
|
||||
framework::OpKernelType kt =
|
||||
framework::OpKernelType(data_type, ctx.device_context());
|
||||
return kt;
|
||||
}
|
||||
};
|
||||
|
||||
// UNDERSTAND: InferShape for Grad
|
||||
class SampleLogitsOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Logits"),
|
||||
"Input(Logits) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
||||
"Input(Labels) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Samples"),
|
||||
"Input(Samples) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
|
||||
"Input(SampledLogits) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")),
|
||||
"Input(SampledLogits@Grad) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
|
||||
"Output(Logits@Grad) should be not null.");
|
||||
|
||||
auto logit_dims = ctx->GetInputDim("Logits");
|
||||
auto label_dims = ctx->GetInputDim("Labels");
|
||||
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
|
||||
"The label should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL,
|
||||
"The logits should be a 2-D tensor.");
|
||||
|
||||
ctx->SetOutputDim(framework::GradVarName("Logits"),
|
||||
ctx->GetInputDim("Logits"));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = framework::GetDataTypeOfVar(
|
||||
ctx.InputVar(framework::GradVarName("SampledLogits")));
|
||||
framework::OpKernelType kt =
|
||||
framework::OpKernelType(data_type, ctx.device_context());
|
||||
return kt;
|
||||
}
|
||||
};
|
||||
|
||||
// UNDERSTAND: what's the rule for making a GradMaker TODO
|
||||
class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||
auto* grad_op = new framework::OpDesc();
|
||||
grad_op->SetType("sample_logits_grad");
|
||||
grad_op->SetInput("Logits", Input("Logits"));
|
||||
grad_op->SetInput("Labels", Input("Labels"));
|
||||
grad_op->SetInput("Samples", Output("Samples"));
|
||||
grad_op->SetInput("SampledLogits", Output("SampledLogits"));
|
||||
grad_op->SetInput(framework::GradVarName("SampledLogits"),
|
||||
OutputGrad("SampledLogits"));
|
||||
grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
|
||||
grad_op->SetAttrMap(Attrs());
|
||||
return std::unique_ptr<framework::OpDesc>(grad_op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(sample_logits, ops::SampleLogitsOp, ops::SampleLogitsOpMaker,
|
||||
ops::SampleLogitsGradMaker);
|
||||
REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad);
|
||||
REGISTER_OP_CPU_KERNEL(sample_logits, ops::SampleLogitsKernel<float>,
|
||||
ops::SampleLogitsKernel<double>);
|
||||
REGISTER_OP_CPU_KERNEL(sample_logits_grad, ops::SampleLogitsGradKernel<float>,
|
||||
ops::SampleLogitsGradKernel<double>);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,245 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
#include "paddle/fluid/operators/math/softmax.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
||||
|
||||
template <typename T>
|
||||
struct TolerableValue {
|
||||
HOSTDEVICE T operator()(const T& x) const {
|
||||
PADDLE_ASSERT(std::is_floating_point<T>::value);
|
||||
const T kApproInf = 1e20;
|
||||
if (x == INFINITY) return kApproInf;
|
||||
if (x == -INFINITY) return -kApproInf;
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// UNDERSTAND: something like take_along_axis in numpy.
|
||||
template <typename T>
|
||||
static void CPUTakeAlongD1(const platform::DeviceContext& ctx,
|
||||
const framework::Tensor& array,
|
||||
const framework::Tensor& index,
|
||||
framework::Tensor* value) {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
|
||||
// UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K)
|
||||
PADDLE_ENFORCE(index.dims().size() == 2 && array.dims().size() == 2 &&
|
||||
index.dims()[0] == array.dims()[0] &&
|
||||
index.dims() == value->dims());
|
||||
|
||||
const auto batch_size = index.dims()[0];
|
||||
const auto num_take = index.dims()[1];
|
||||
const auto array_dims = array.dims();
|
||||
const auto idx_dims = index.dims();
|
||||
|
||||
// UNDERSTAND: no allocations here
|
||||
const T* p_array = array.data<T>();
|
||||
const int64_t* p_index = index.data<int64_t>();
|
||||
T* p_value = value->data<T>();
|
||||
|
||||
// src slice size
|
||||
const auto array_slice_size = array_dims[1];
|
||||
|
||||
// index slice size
|
||||
const auto idx_slice_size = idx_dims[1];
|
||||
const auto value_slice_size = idx_slice_size;
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
for (int j = 0; j < num_take; ++j) {
|
||||
auto array_index = p_index[i * idx_slice_size + j];
|
||||
p_value[i * value_slice_size + j] =
|
||||
p_array[i * array_slice_size + array_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate
|
||||
// indices, scatter is done in += way.
|
||||
template <typename T>
|
||||
static void CPUPutAlongD1(const platform::DeviceContext& ctx,
|
||||
framework::Tensor* array,
|
||||
const framework::Tensor& index,
|
||||
const framework::Tensor& value) {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
|
||||
// UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K)
|
||||
PADDLE_ENFORCE(index.dims().size() == 2 && array->dims().size() == 2 &&
|
||||
index.dims()[0] == array->dims()[0] &&
|
||||
index.dims() == value.dims());
|
||||
const auto batch_size = index.dims()[0];
|
||||
const auto num_put = index.dims()[1];
|
||||
auto array_dims = array->dims();
|
||||
auto idx_dims = index.dims();
|
||||
|
||||
// UNDERSTAND: no allocations here
|
||||
T* p_array = array->data<T>();
|
||||
const int64_t* p_index = index.data<int64_t>();
|
||||
const T* p_value = value.data<T>();
|
||||
|
||||
// slice sizes
|
||||
const auto array_slice_size = array_dims[1];
|
||||
const auto idx_slice_size = idx_dims[1];
|
||||
const auto value_slice_size = idx_slice_size;
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
for (int j = 0; j < num_put; ++j) {
|
||||
auto array_index = p_index[i * idx_slice_size + j];
|
||||
p_array[i * array_slice_size + array_index] +=
|
||||
p_value[i * value_slice_size + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UNDERSTAND: compute accidentdal hits from samples and minus corresponding
|
||||
// logits by a float max, here 1e20
|
||||
template <typename T>
|
||||
static void compute_remove_accidental_hits(const platform::DeviceContext& ctx,
|
||||
framework::Tensor* sampled_logits,
|
||||
const framework::Tensor& samples,
|
||||
const int num_true) {
|
||||
const auto batch_size = sampled_logits->dims()[0];
|
||||
const auto num_sampled_classes = sampled_logits->dims()[1];
|
||||
T* sampled_logits_data = sampled_logits->data<T>();
|
||||
const auto samples_data = samples.data<int64_t>();
|
||||
|
||||
std::unordered_set<int64_t> tmp_true_labels;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
tmp_true_labels.clear();
|
||||
tmp_true_labels.insert(samples_data + i * num_sampled_classes,
|
||||
samples_data + i * num_sampled_classes + num_true);
|
||||
for (int j = num_true; j < num_sampled_classes; ++j) {
|
||||
const auto idx = i * num_sampled_classes + j;
|
||||
if (tmp_true_labels.find(samples_data[idx]) != tmp_true_labels.end())
|
||||
sampled_logits_data[idx] -= 1e20;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class SampleLogitsKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
using Tensor = framework::Tensor;
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
|
||||
"This kernel only runs on CPU.");
|
||||
VLOG(3) << "Enter SampleLogitsKernel";
|
||||
// get necessary inputs
|
||||
const Tensor* logits = context.Input<Tensor>("Logits");
|
||||
const Tensor* labels = context.Input<Tensor>("Labels");
|
||||
|
||||
// get necessary outputs
|
||||
Tensor* samples = context.Output<Tensor>("Samples");
|
||||
Tensor* probabilities = context.Output<Tensor>("Probabilities");
|
||||
Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
|
||||
Tensor* sampled_labels = context.Output<Tensor>("SampledLabels");
|
||||
|
||||
// shapes
|
||||
const auto batch_size = logits->dims()[0];
|
||||
const auto num_classes = logits->dims()[1];
|
||||
const auto labels_dim = labels->dims();
|
||||
const auto num_true = labels_dim[1];
|
||||
const auto samples_dim = samples->dims();
|
||||
|
||||
// attrs
|
||||
const auto num_samples = context.Attr<int>("num_samples");
|
||||
const bool use_customized_samples =
|
||||
context.Attr<bool>("use_customized_samples");
|
||||
const bool remove_accidental_hits =
|
||||
context.Attr<bool>("remove_accidental_hits");
|
||||
|
||||
// device contexts
|
||||
auto& dev_ctx =
|
||||
context.template device_context<platform::CPUDeviceContext>();
|
||||
|
||||
// UNDERSTAND: allocate memories for temporaries
|
||||
sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
|
||||
auto sampled_labels_data =
|
||||
sampled_labels->mutable_data<int64_t>(labels_dim, context.GetPlace());
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
for (int j = 0; j < num_true; ++j) {
|
||||
sampled_labels_data[i * num_true + j] = j;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_customized_samples) {
|
||||
const Tensor* customized_samples =
|
||||
context.Input<Tensor>("CustomizedSamples");
|
||||
const Tensor* customized_probabilities =
|
||||
context.Input<Tensor>("CustomizedProbabilities");
|
||||
samples->ShareDataWith(*customized_samples);
|
||||
probabilities->ShareDataWith(*customized_probabilities);
|
||||
} else {
|
||||
samples->mutable_data<int64_t>(context.GetPlace());
|
||||
probabilities->mutable_data<T>(samples_dim, context.GetPlace());
|
||||
// UNDERSTAND: sampling
|
||||
const auto seed = context.Attr<int>("seed");
|
||||
auto sampler_with_prob =
|
||||
math::SampleWithProb<platform::CPUDeviceContext, T>();
|
||||
sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed),
|
||||
num_samples, labels, samples, probabilities);
|
||||
}
|
||||
|
||||
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
|
||||
CPUTakeAlongD1<T>(dev_ctx, *logits, *samples, sampled_logits);
|
||||
if (remove_accidental_hits) {
|
||||
compute_remove_accidental_hits<T>(dev_ctx, sampled_logits, *samples,
|
||||
num_true);
|
||||
}
|
||||
|
||||
// subtracted sampled logits with logQ(y|x)
|
||||
auto probs = EigenMatrix<T>::From(*probabilities);
|
||||
auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
|
||||
smp_logits.device(*dev_ctx.eigen_device()) =
|
||||
(smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
|
||||
.unaryExpr(TolerableValue<T>());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class SampleLogitsGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
using Tensor = framework::Tensor;
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto logits_grad = context.Output<Tensor>(framework::GradVarName("Logits"));
|
||||
const Tensor* samples = context.Input<Tensor>("Samples");
|
||||
const Tensor* sampled_logits_grad =
|
||||
context.Input<Tensor>(framework::GradVarName("SampledLogits"));
|
||||
logits_grad->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto& dev_ctx =
|
||||
context.template device_context<platform::CPUDeviceContext>();
|
||||
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
|
||||
set_zero(dev_ctx, logits_grad, static_cast<T>(0));
|
||||
|
||||
// UNDERSTAND: scatter it back to logit_grad
|
||||
CPUPutAlongD1<T>(dev_ctx, logits_grad, *samples, *sampled_logits_grad);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue