parent
b5ebca47a3
commit
58ad40cc15
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template class SampleWithProb<platform::CPUDeviceContext, float>;
|
||||
template class SampleWithProb<platform::CPUDeviceContext, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,188 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <thrust/random.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
#include "paddle/fluid/operators/math/sampler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T>
|
||||
__device__ T gpu_adjust_prob(const T prob, const int num_samples,
|
||||
const int num_tries) {
|
||||
if (num_samples == num_tries) {
|
||||
return prob * num_samples;
|
||||
} else {
|
||||
return -expm1(num_tries * log1p(-prob));
|
||||
}
|
||||
}
|
||||
|
||||
class GPULogUniformSampler {
|
||||
public:
|
||||
__device__ int64_t Sample(float random, const int range,
|
||||
const float log_range) const;
|
||||
__device__ float Probability(int64_t value, const float log_range) const;
|
||||
};
|
||||
|
||||
__device__ int64_t GPULogUniformSampler::Sample(float random, const int range,
|
||||
const float log_range) const {
|
||||
// Got Log Uniform distribution from uniform distribution by
|
||||
// inverse_transform_sampling method
|
||||
const int64_t value = static_cast<int64_t>(exp(random * log_range)) - 1;
|
||||
// Mathematically, value should be <= range_, but might not be due to some
|
||||
// floating point roundoff, so we mod by range_.
|
||||
return value % range;
|
||||
}
|
||||
|
||||
__device__ float GPULogUniformSampler::Probability(
|
||||
int64_t value, const float log_range) const {
|
||||
// Given f(x) = 1/[(x+1) * log_range_]
|
||||
// The value's probability is integral of f(x) from value to (value + 1)
|
||||
return (log((value + 2.0) / (value + 1.0))) / log_range;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void SamplingCondidate(
|
||||
const size_t n, const int num_tries, const int range, const float log_range,
|
||||
const int num_true, const std::size_t num_samples,
|
||||
const int64_t* label_data, int64_t* samples_data, T* probabilities_data) {
|
||||
const int num_sampled_classes = num_true + num_samples;
|
||||
|
||||
int idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
int step_size = 0;
|
||||
GPULogUniformSampler sampler;
|
||||
|
||||
for (; idx < n; idx += blockDim.x * gridDim.x) {
|
||||
int col_idx = idx % num_sampled_classes;
|
||||
int row_idx = idx / num_sampled_classes;
|
||||
if (col_idx < num_true) {
|
||||
samples_data[idx] = label_data[row_idx * num_true + col_idx];
|
||||
} else {
|
||||
samples_data[idx] = samples_data[col_idx];
|
||||
}
|
||||
probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range);
|
||||
probabilities_data[idx] =
|
||||
gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int UniqSampler(const Sampler& sampler, const std::size_t num_samples,
|
||||
int64_t* samples_data) {
|
||||
// sample num_samles unique samples for an example, note that they are not
|
||||
// all negative samples
|
||||
std::unordered_set<int64_t> tmp_samples;
|
||||
tmp_samples.clear();
|
||||
int num_tries = 0;
|
||||
int j = 0;
|
||||
while (j < num_samples) {
|
||||
++num_tries;
|
||||
auto v = sampler.Sample();
|
||||
auto insert_ok = tmp_samples.insert(v).second;
|
||||
if (!insert_ok) {
|
||||
continue;
|
||||
}
|
||||
samples_data[j] = v;
|
||||
++j;
|
||||
}
|
||||
return num_tries;
|
||||
}
|
||||
/*
|
||||
template <typename T>
|
||||
void Print(Tensor & t, std::string name) {
|
||||
if (!FLAGS_debug_print) {
|
||||
return;
|
||||
}
|
||||
VLOG(1) << "qxz print "<< name;
|
||||
VLOG(1) << name << "size = " << t.numel();
|
||||
size_t size = t.numel();
|
||||
type *d = t.data<type>();
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::vector<type> vec;
|
||||
platform::DeviceContextPool::Instance().Get(t.place())->Wait();
|
||||
if (platform::is_gpu_place(t.place())) {
|
||||
vec.resize(size);
|
||||
cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
|
||||
d = vec.data();
|
||||
}
|
||||
#endif
|
||||
VLOG(1) << name << " data_ptr = " << static_cast<void*>(d);
|
||||
std::string out;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
out += std::to_string(d[i]);
|
||||
out += ",";
|
||||
}
|
||||
VLOG(1) << out;
|
||||
}*/
|
||||
|
||||
template <typename T>
|
||||
void GPUSampleWithProb<T>::operator()(
|
||||
const platform::CUDADeviceContext& context, const int seed,
|
||||
const int dict_size, const bool uniq, const std::size_t num_samples,
|
||||
const Tensor* L, Tensor* S, Tensor* P) {
|
||||
// UNDERSTAND: dimension issues
|
||||
const auto lbl_dim = L->dims();
|
||||
const int batch_size = lbl_dim[0];
|
||||
const int num_true = lbl_dim[1];
|
||||
const int num_sampled_classes = num_true + num_samples;
|
||||
framework::DDim ret_dim{batch_size, num_sampled_classes};
|
||||
|
||||
// UNDERSTAND: raw data view
|
||||
const int64_t* label_data = L->data<int64_t>();
|
||||
int64_t* samples_data = S->data<int64_t>();
|
||||
T* probabilities_data = P->data<T>();
|
||||
|
||||
int s_size = num_samples;
|
||||
framework::DDim s_dim{s_size};
|
||||
Tensor s;
|
||||
int64_t* s_data = s.mutable_data<int64_t>(s_dim, platform::CPUPlace());
|
||||
|
||||
math::LogUniformSampler sampler(dict_size, seed);
|
||||
|
||||
int range = dict_size;
|
||||
float log_range = log(range + 1);
|
||||
|
||||
int num_tries = UniqSampler<T>(sampler, num_samples, s_data);
|
||||
VLOG(1) << "num_tries: " << num_tries;
|
||||
PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data,
|
||||
sizeof(int64_t) * num_samples,
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
int threads = 512;
|
||||
const size_t size = batch_size * num_sampled_classes;
|
||||
int grid = (batch_size * num_sampled_classes + threads - 1) / threads;
|
||||
SamplingCondidate<T><<<grid, threads, 0, context.stream()>>>(
|
||||
size, num_tries, range, log_range, num_true, num_samples, label_data,
|
||||
samples_data, probabilities_data);
|
||||
}
|
||||
|
||||
template class GPUSampleWithProb<float>;
|
||||
template class GPUSampleWithProb<double>;
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,118 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/operators/math/sampler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
/* UNDERSTAND: utility function to adjust probability for unique sampling,
|
||||
return whatever as it is if not using unique samping */
|
||||
template <typename T>
|
||||
static T adjust_prob(const T prob, const int num_samples, const int num_tries) {
|
||||
if (num_samples == num_tries) {
|
||||
return prob * num_samples;
|
||||
} else {
|
||||
return -expm1(num_tries * log1p(-prob));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class SampleWithProb {
|
||||
public:
|
||||
void operator()(const DeviceContext& context, const Sampler& sampler,
|
||||
const std::size_t num_samples, const Tensor* L, Tensor* S,
|
||||
Tensor* P) {
|
||||
// UNDERSTAND: dimension issues
|
||||
const auto lbl_dim = L->dims();
|
||||
const int batch_size = lbl_dim[0];
|
||||
const int num_true = lbl_dim[1];
|
||||
const int num_sampled_classes = num_true + num_samples;
|
||||
framework::DDim ret_dim{batch_size, num_sampled_classes};
|
||||
|
||||
// UNDERSTAND: raw data view
|
||||
const int64_t* label_data = L->data<int64_t>();
|
||||
int64_t* samples_data =
|
||||
S->mutable_data<int64_t>(ret_dim, context.GetPlace());
|
||||
T* probabilities_data = P->mutable_data<T>(ret_dim, context.GetPlace());
|
||||
|
||||
// temp sets for unique sampling
|
||||
std::unordered_set<int64_t> tmp_samples;
|
||||
int j = 0; // column index
|
||||
// add true labels, not that efficient
|
||||
while (j < num_true) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
auto samples_index = i * num_sampled_classes + j;
|
||||
auto v = label_data[i * num_true + j];
|
||||
samples_data[samples_index] = v;
|
||||
probabilities_data[samples_index] = sampler.Probability(v);
|
||||
}
|
||||
++j;
|
||||
}
|
||||
|
||||
// sample num_samles unique samples for an example, note that they are not
|
||||
// all negative samples
|
||||
tmp_samples.clear();
|
||||
int num_tries = 0;
|
||||
while (j < num_sampled_classes) {
|
||||
++num_tries;
|
||||
auto v = sampler.Sample();
|
||||
auto insert_ok = tmp_samples.insert(v).second;
|
||||
if (!insert_ok) {
|
||||
continue;
|
||||
}
|
||||
auto p = sampler.Probability(v);
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
auto samples_index = i * num_sampled_classes + j;
|
||||
samples_data[samples_index] = v;
|
||||
probabilities_data[samples_index] = p;
|
||||
}
|
||||
++j;
|
||||
}
|
||||
|
||||
// compute Q(y|x), because of unique sampling, probabilities need to be
|
||||
// adjusted
|
||||
for (int k = 0; k < num_sampled_classes; ++k) {
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
auto samples_index = i * num_sampled_classes + k;
|
||||
probabilities_data[samples_index] = adjust_prob(
|
||||
probabilities_data[samples_index], num_samples, num_tries);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
template <typename T>
|
||||
class GPUSampleWithProb {
|
||||
public:
|
||||
void operator()(const platform::CUDADeviceContext& context, const int seed,
|
||||
const int dict_size, const bool uniq,
|
||||
const std::size_t num_samples, const Tensor* L, Tensor* S,
|
||||
Tensor* P);
|
||||
};
|
||||
#endif
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,248 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/sample_logits_op.h"
|
||||
#include "paddle/fluid/operators/math/sample_prob.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Logits",
|
||||
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
|
||||
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
|
||||
"and K is the class number.");
|
||||
AddInput("Label",
|
||||
"(Tensor) The ground truth which is a 2-D tensor. Label is a "
|
||||
"Tensor<int64> with shape [N x NT], where NT is the number of"
|
||||
"true labels for each example.");
|
||||
AddInput(
|
||||
"CustomSamples",
|
||||
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shaoe [N x "
|
||||
"S+NT]."
|
||||
"The customized sample labels with true labels at first. This tensor"
|
||||
"is only use_custom_samples is true.")
|
||||
.AsDispensable();
|
||||
AddInput(
|
||||
"CustomProbabilities",
|
||||
"(Tensor, default: Tensor<float>), A 2-D tensor with shaoe [N x S+NT]."
|
||||
"The customized sample probabilities with true labels at first. This "
|
||||
"tensor is only use_custom_samples is true.")
|
||||
.AsDispensable();
|
||||
AddOutput(
|
||||
"Samples",
|
||||
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N x "
|
||||
"S+NT]."
|
||||
"The outputs value of sampler by given the true label, where S is the "
|
||||
"number of negative sample for each example. So Samples includes NT "
|
||||
"true"
|
||||
"labels and S negative labels for each example. This will be used in"
|
||||
"backward calculation.")
|
||||
.AsIntermediate();
|
||||
AddOutput(
|
||||
"Probabilities",
|
||||
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x "
|
||||
"S+NT]."
|
||||
"The outputs value of progabilites of samples by given the true label, "
|
||||
"where S is the "
|
||||
"number of negative sample for each example. So Samples includes NT "
|
||||
"true"
|
||||
"labels and S negative labels for each example.")
|
||||
.AsIntermediate();
|
||||
AddOutput("SampledLogits",
|
||||
"(Tensor, default: Tensor<float>), A 2-D tensor with shape"
|
||||
"[N x S+NT]. The outputs value of sampled softmax, which will be"
|
||||
"used in backward calculation.")
|
||||
.AsIntermediate();
|
||||
AddOutput("SampledLabel",
|
||||
"(Tensor, default: Tensor<int64>), A 2-D tensor. The cross "
|
||||
"entropy loss with shape [N x NT].");
|
||||
AddAttr<bool>(
|
||||
"use_custom_samples",
|
||||
"An indicator whether to use custom samples with probabilities, if True"
|
||||
"the operator will use custom samples and custom probabilities"
|
||||
"otherwise, the operator will generate them by itself.")
|
||||
.SetDefault(false);
|
||||
AddAttr<bool>(
|
||||
"uniq",
|
||||
"An indicator whether to sample non-repetitive negtive labels, if True"
|
||||
"the operator will sample negtive labels without replacement."
|
||||
"otherwise, the operator will sample negtive labels with replacement.")
|
||||
.SetDefault(false);
|
||||
AddAttr<bool>(
|
||||
"remove_accidental_hits",
|
||||
"An indicator whether to remove accidental hits when samples hits true"
|
||||
"labels, the removal is implemented by subtracting the corresponding"
|
||||
"logits by float_max to subpress their softmax to be zero.")
|
||||
.SetDefault(true);
|
||||
AddAttr<int>("num_samples", "The number of negative samples.");
|
||||
AddAttr<int>("seed", "Random seed for generating samples").SetDefault(0);
|
||||
|
||||
AddComment(R"DOC(
|
||||
TODO(chenfeiyu): Write documentation for this Operator.
|
||||
Sampled Softmax With Cross Entropy Operator.
|
||||
|
||||
Cross entropy loss with sampled softmax is used as the output layer extensively.
|
||||
This operator computes the softmax normalized values for each row of the input
|
||||
tensor, after which cross-entropy loss is computed. This provides a more
|
||||
numerically stable gradient.
|
||||
|
||||
Because this operator performs a softmax on logits internally, it expects
|
||||
unscaled logits. This operator should not be used with the output of
|
||||
softmax operator since that would produce incorrect results.
|
||||
|
||||
When the attribute soft_label is set false, this operators expects mutually
|
||||
exclusive hard labels, each sample in a batch is in exactly one class with a
|
||||
probability of 1.0. Each sample in the batch will have a single label.
|
||||
|
||||
The equation is as follows:
|
||||
|
||||
1) Hard label (one-hot label, so every sample has exactly one class)
|
||||
|
||||
$$Loss_j = -\text{Logit}_{Label_j} +
|
||||
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
|
||||
j = 1,..., K$$
|
||||
|
||||
2) Soft label (each sample can have a distribution over all classes)
|
||||
|
||||
$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
|
||||
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
|
||||
j = 1,...,K$$
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class SampleLogitsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Logits"),
|
||||
"Input(Logits) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
||||
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Samples"),
|
||||
"Output(Samples) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Probabilities"),
|
||||
"Output(Probabilities) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"),
|
||||
"Output(SampledLogits) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("SampledLabel"),
|
||||
"Output(SampledLabel) should be not null.");
|
||||
|
||||
auto logits_dims = ctx->GetInputDim("Logits");
|
||||
auto labels_dims = ctx->GetInputDim("Label");
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
logits_dims.size(), 2UL,
|
||||
"The logits of softmax_with_cross_entropy should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
|
||||
"The labels should be a 2-D tensor.");
|
||||
|
||||
const int num_samples = ctx->Attrs().Get<int>("num_samples");
|
||||
const int num_sampled_classes = labels_dims[1] + num_samples;
|
||||
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
|
||||
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
|
||||
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
|
||||
ctx->SetOutputDim("SampledLabel", {logits_dims[0], labels_dims[1]});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
|
||||
framework::OpKernelType kt =
|
||||
framework::OpKernelType(data_type, ctx.device_context());
|
||||
// kt.place_ = platform::CPUPlace();
|
||||
return kt;
|
||||
}
|
||||
};
|
||||
|
||||
// UNDERSTAND: InferShape for Grad
|
||||
class SampleLogitsOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Logits"),
|
||||
"Input(Logits) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Samples"),
|
||||
"Input(Samples) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
|
||||
"Input(SampledLogits) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")),
|
||||
"Input(SampledLogits@Grad) should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
|
||||
"Output(Logits@Grad) should be not null.");
|
||||
|
||||
auto logit_dims = ctx->GetInputDim("Logits");
|
||||
auto label_dims = ctx->GetInputDim("Label");
|
||||
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
|
||||
"The label should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL,
|
||||
"The logits should be a 2-D tensor.");
|
||||
|
||||
ctx->SetOutputDim(framework::GradVarName("Logits"),
|
||||
ctx->GetInputDim("Logits"));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = framework::GetDataTypeOfVar(
|
||||
ctx.InputVar(framework::GradVarName("SampledLogits")));
|
||||
framework::OpKernelType kt =
|
||||
framework::OpKernelType(data_type, ctx.device_context());
|
||||
// kt.place_ = platform::CPUPlace();
|
||||
return kt;
|
||||
}
|
||||
};
|
||||
|
||||
// UNDERSTAND: what's the rule for making a GradMaker TODO
|
||||
class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||
auto* grad_op = new framework::OpDesc();
|
||||
grad_op->SetType("sample_logits_grad");
|
||||
grad_op->SetInput("Logits", Input("Logits"));
|
||||
grad_op->SetInput("Label", Input("Label"));
|
||||
grad_op->SetInput("Samples", Output("Samples"));
|
||||
grad_op->SetInput("SampledLogits", Output("SampledLogits"));
|
||||
grad_op->SetInput(framework::GradVarName("SampledLogits"),
|
||||
OutputGrad("SampledLogits"));
|
||||
grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
|
||||
grad_op->SetAttrMap(Attrs());
|
||||
return std::unique_ptr<framework::OpDesc>(grad_op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(sample_logits, ops::SampleLogitsOp, ops::SampleLogitsOpMaker,
|
||||
ops::SampleLogitsGradMaker);
|
||||
REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad);
|
||||
REGISTER_OP_CPU_KERNEL(sample_logits, ops::SampleLogitsKernel<float>,
|
||||
ops::SampleLogitsKernel<double>);
|
||||
REGISTER_OP_CPU_KERNEL(sample_logits_grad, ops::SampleLogitsGradKernel<float>,
|
||||
ops::SampleLogitsGradKernel<double>);
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue