@@ -27,8 +27,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-DEFINE_bool(debug_print, true, "run debug mode");
-
 // UNDERSTAND: something like take_along_axis in numpy.
 template <typename T>
 __global__ void GPUTakeAlongD1(size_t size, const int batch_size,
@@ -108,32 +106,6 @@ template <typename T>
 class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
  public:
   using Tensor = framework::Tensor;
-  template <typename type>
-  void Print(const Tensor& t, std::string name) const {
-    if (!FLAGS_debug_print) {
-      return;
-    }
-    VLOG(1) << name << " size = " << t.numel();
-    size_t size = t.numel();
-    const type* d = t.data<type>();
-#ifdef PADDLE_WITH_CUDA
-    std::vector<type> vec;
-    platform::DeviceContextPool::Instance().Get(t.place())->Wait();
-    if (platform::is_gpu_place(t.place())) {
-      vec.resize(size);
-      cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
-      d = vec.data();
-    }
-#endif
-    VLOG(1) << name << " data_ptr = " << static_cast<const void*>(d);
-    std::string out;
-    for (size_t i = 0; i < size; i++) {
-      out += std::to_string(d[i]);
-      out += ",";
-    }
-    VLOG(1) << out;
-  }
-
   void Compute(const framework::ExecutionContext& context) const override {
     // get necessary inputs
     const Tensor* logits = context.Input<Tensor>("Logits");
@@ -189,12 +161,9 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
       // UNDERSTAND: sampling
       const auto seed = context.Attr<int>("seed");
       auto sampler_with_prob = math::GPUSampleWithProb<T>();
-      Print<int64_t>(*samples, std::string("samples1"));
       sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq,
                         num_samples, label, samples, probabilities);
     }
-    Print<int64_t>(*samples, std::string("samples2"));
-    Print<T>(*probabilities, std::string("probabilities"));
 
     // UNDERSTAND: gather sampled logits and remove accidental hits if needed
     const auto num_take = samples->dims()[1];
@@ -216,7 +185,6 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
         T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
         size, batch_size, array_slice_size, idx_slice_size, p_array, p_index,
         p_value);
-    Print<T>(*sampled_logits, std::string("sampled_logits"));
 
     if (remove_accidental_hits) {
       const size_t size = batch_size * (num_true + num_samples);
@@ -224,8 +192,6 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
       gpu_compute_remove_accidental_hits<
           T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
           size, num_true, idx_slice_size, p_index, p_value);
-      Print<T>(*sampled_logits,
-               std::string("sampled_logits_remove_accidental_hits"));
     }
 
     // subtracted sampled logits with logQ(y|x)
@@ -234,7 +200,6 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
     smp_logits.device(*dev_ctx.eigen_device()) =
         (smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
             .unaryExpr(TolerableValue<T>());
-    Print<T>(*sampled_logits, std::string("sampled_logits_res"));
   }
 };
 
@@ -242,32 +207,6 @@ template <typename T>
 class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
  public:
   using Tensor = framework::Tensor;
-  template <typename type>
-  void Print(const Tensor& t, std::string name) const {
-    if (!FLAGS_debug_print) {
-      return;
-    }
-    VLOG(1) << name << " size = " << t.numel();
-    size_t size = t.numel();
-    const type* d = t.data<type>();
-#ifdef PADDLE_WITH_CUDA
-    std::vector<type> vec;
-    platform::DeviceContextPool::Instance().Get(t.place())->Wait();
-    if (platform::is_gpu_place(t.place())) {
-      vec.resize(size);
-      cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
-      d = vec.data();
-    }
-#endif
-    VLOG(1) << name << " data_ptr = " << static_cast<const void*>(d);
-    std::string out;
-    for (size_t i = 0; i < size; i++) {
-      out += std::to_string(d[i]);
-      out += ",";
-    }
-    VLOG(1) << out;
-  }
-
   void Compute(const framework::ExecutionContext& context) const override {
     auto logits_grad = context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* samples = context.Input<Tensor>("Samples");
@@ -298,13 +237,10 @@ class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
     const size_t size = batch_size;
     int grid = (size + threads - 1) / threads;
 
-    Print<T>(*sampled_logits_grad, std::string("sampled_logits_grad"));
-    Print<int64_t>(*samples, std::string("samples"));
     GPUPutAlongD1<
         T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
         size, batch_size, array_slice_size, idx_slice_size, p_array, p_index,
         p_value);
-    Print<T>(*logits_grad, std::string("logits_grad"));
   }
 };
 
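
Context for reviewers: the hunks above only show the launch of GPUTakeAlongD1, not its body. Below is a minimal CUDA sketch of the take_along_axis-style gather it performs; the name TakeAlongD1Sketch and the grid-stride loop are illustrative assumptions, while the parameter names and meanings are taken from the launch sites in the diff (size == batch_size * idx_slice_size).

#include <cstddef>
#include <cstdint>

// Gather along dim 1: p_value[i][j] = p_array[i][p_index[i][j]], i.e.
// numpy's take_along_axis(array, index, axis=1) for a [batch_size, *] array.
// batch_size is kept only to mirror the launch signature above.
template <typename T>
__global__ void TakeAlongD1Sketch(size_t size, const int batch_size,
                                  const int array_slice_size,
                                  const int idx_slice_size, const T* p_array,
                                  const int64_t* p_index, T* p_value) {
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += blockDim.x * gridDim.x) {
    const size_t i = idx / idx_slice_size;  // row in the batch
    p_value[idx] = p_array[i * array_slice_size + p_index[idx]];
  }
}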
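The Eigen hunk commented "subtracted sampled logits with logQ(y|x)" is the standard sampled-softmax correction. Assuming Q(j | x_i) is the probability under which class j was drawn for row i (the probabilities output above), it computes

  sampled_logits[i][j] <- sampled_logits[i][j] - log Q(j | x_i)

so that a softmax over the sampled subset approximates the full softmax; TolerableValue appears to clamp the infinities that log(0) would otherwise produce.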
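Likewise for GPUPutAlongD1 in the gradient kernel: the launch above uses size == batch_size, i.e. one thread per row, which lets duplicate indices within a row be accumulated serially without atomics. A sketch under the same assumptions (PutAlongD1Sketch is a hypothetical name; the real kernel body is outside these hunks):

// Scatter-add along dim 1, the inverse of the gather above:
// p_array[i][p_index[i][j]] += p_value[i][j], one thread per row i.
// Here size == batch_size, so batch_size mirrors the launch signature.
template <typename T>
__global__ void PutAlongD1Sketch(size_t size, const int batch_size,
                                 const int array_slice_size,
                                 const int idx_slice_size, T* p_array,
                                 const int64_t* p_index, const T* p_value) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    for (int j = 0; j < idx_slice_size; ++j) {
      p_array[i * array_slice_size + p_index[i * idx_slice_size + j]] +=
          p_value[i * idx_slice_size + j];
    }
  }
}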