Remove constraint that last dimension is forced to be 1 by adding lookup_table_v2 (#19735)
* Remove constraint that last dimension is forced to be 1 by adding lookup_table_v2 test=develop
* modify into PADDLE_ENFORCE_CUDA_SUCCESS test=develop
* Revert "modify into PADDLE_ENFORCE_CUDA_SUCCESS test=develop" (this reverts commit 8a960bfc61e51aa27c3c529df8fb90b93ebd19f9)
* move api into fluid.embedding test=develop
* fix example code test=develop
* move one_hot into fluid.one_hot
* modify api.spec test=develop
* fix loss shape test=develop
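As context for the change, here is a minimal usage sketch of the new `fluid.embedding` API (mirroring the `TestLookupTableApi` test added below); the variable names and shapes are illustrative, not part of the commit:

```python
import numpy as np
import paddle.fluid as fluid

# Ids can have any rank; no trailing dimension of 1 is required.
x = fluid.layers.data(name='x', shape=[20], dtype='int64')
emb = fluid.embedding(input=x, size=[128, 64])  # backed by lookup_table_v2

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

x_data = np.random.randint(0, 128, [2, 20]).astype('int64')
out, = exe.run(feed={'x': x_data}, fetch_list=[emb])
print(out.shape)  # (2, 20, 64): the Ids shape plus the embedding width
```

With the older `fluid.layers.embedding` (lookup_table), the same ids would first have to be reshaped to carry a trailing dimension of 1; lookup_table_v2 drops that requirement.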
parent 80e0f547bb
commit 039b9710d5
@@ -0,0 +1,192 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/lookup_table_v2_op.h"

#include <memory>

#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"

namespace paddle {
namespace operators {

class LookupTableV2Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
                      "Input(W) of LookupTableV2Op should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
                      "Input(Ids) of LookupTableV2Op should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of LookupTableV2Op should not be null.");

    auto table_dims = ctx->GetInputDim("W");
    auto ids_dims = ctx->GetInputDim("Ids");
    int ids_rank = ids_dims.size();
    VLOG(5) << "ids rank is " << ids_rank << std::endl;
    PADDLE_ENFORCE_EQ(table_dims.size(), 2);

    auto output_dims = framework::vectorize(ids_dims);
    output_dims.push_back(table_dims[1]);
    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));

    if (ctx->GetOutputsVarType("Out")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      ctx->ShareLoD("Ids", /*->*/ "Out");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("W",
             "(Tensor) The input represents embedding tensors, "
             "which is a learnable parameter.");
    AddInput("Ids",
             "An input with type int32 or int64 "
             "contains the ids to be looked up in W. "
             "The last dimension size must be 1.");
    AddOutput("Out", "The lookup results, which have the same type as W.");
    AddAttr<bool>("is_sparse",
                  "(boolean, default false) "
                  "Sparse update.")
        .SetDefault(false);
    AddAttr<bool>("is_distributed",
                  "(boolean, default false) distributed lookup table.")
        .SetDefault(false);
    AddAttr<int64_t>("padding_idx",
                     "(int64, default -1) "
                     "If the value is -1, it makes no effect to lookup. "
                     "Otherwise the given value indicates padding the output "
                     "with zeros whenever lookup encounters it in Ids.")
        .SetDefault(kNoPadding);

    // for parameter prefetch
    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
    AddAttr<std::vector<int64_t>>("height_sections",
                                  "Height for each output SelectedRows.")
        .SetDefault(std::vector<int64_t>({}));
    AddAttr<std::vector<std::string>>(
        "epmap",
        "(string vector, default 127.0.0.1:6164)"
        "Server endpoints in the order of input variables for mapping")
        .SetDefault({});
    AddAttr<std::vector<std::string>>(
        "table_names",
        "(string vector, the splited table names that will be fetched from "
        "parameter server)"
        "in the order of input variables for mapping")
        .SetDefault({});

    AddComment(R"DOC(
Lookup Table V2 Operator.

This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.

The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.

)DOC");
  }
};

DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LookupTableV2GradOpNoBuffer, "W");

class LookupTableV2GradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());

    op->SetType("lookup_table_v2_grad");

    op->SetInput("W", Input("W"));
    op->SetInput("Ids", Input("Ids"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));

    op->SetOutput(framework::GradVarName("W"), InputGrad("W"));

    op->SetAttrMap(Attrs());
    return op;
  }
};

class LookupTableV2OpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto table_dims = ctx->GetInputDim("W");
    ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = framework::GetDataTypeOfVar(
        ctx.InputVar(framework::GradVarName("Out")));
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class LookupTableV2OpGradVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
    auto attr = ctx->GetAttr("is_sparse");
    bool is_sparse = boost::get<bool>(attr);
    if (is_sparse) {
      VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
              << " is set to SelectedRows";
      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
    } else {
      VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
              << " is set to LoDTensor";
      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
    }
    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_table_v2, ops::LookupTableV2Op,
                  ops::LookupTableV2OpMaker, ops::LookupTableV2GradOpDescMaker);

REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
                  ops::LookupTableV2GradOpNoBuffer,
                  ops::LookupTableV2OpGradVarTypeInference);

REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
                       ops::LookupTableV2Kernel<double>);
REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad,
                       ops::LookupTableV2GradKernel<float>,
                       ops::LookupTableV2GradKernel<double>);
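The `InferShape` above is what lifts the old last-dimension-must-be-1 rule: the output shape is the full `Ids` shape with the embedding width `table_dims[1]` appended. A NumPy sketch of that shape rule and of the `padding_idx` behavior (reference only; the function name is mine, not part of the operator):

```python
import numpy as np

def lookup_table_v2_reference(table, ids, padding_idx=-1):
    # Out has shape ids.shape + (D,); rows looked up via padding_idx are zeroed.
    D = table.shape[1]
    out = table[ids.reshape(-1)].reshape(ids.shape + (D,)).copy()
    if padding_idx != -1:
        out[ids == padding_idx] = 0
    return out

table = np.random.rand(17, 31).astype('float32')   # W: [N, D]
ids = np.random.randint(0, 17, size=(2, 4, 5))     # arbitrary-rank Ids
assert lookup_table_v2_reference(table, ids).shape == (2, 4, 5, 31)
```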
@@ -0,0 +1,201 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_v2_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTableV2(T *output, const T *table, const int64_t *ids,
                              const int64_t N, const int64_t K, const int64_t D,
                              const int64_t padding_idx) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ENFORCE(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ENFORCE(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    T *out = output + idy * D;
    const T *tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      if (PaddingFlag) {
        if (id == padding_idx)
          out[i] = static_cast<T>(0);
        else
          out[i] = tab[i];
      } else {
        out[i] = tab[i];
      }
    }
    idy += BlockDimY * GridDimX;
  }
}

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
                                  const int64_t N, const int64_t K,
                                  const int64_t D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ENFORCE(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ENFORCE(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += BlockDimY * GridDimX;
  }
}

template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<LoDTensor>("W");
    auto *ids_t = context.Input<LoDTensor>("Ids");
    auto *output_t = context.Output<LoDTensor>("Out");
    int64_t padding_idx = context.Attr<int64_t>("padding_idx");

    auto id_name = context.Inputs("Ids").front();
    auto out_name = context.Outputs("Out").front();

    size_t N = table_t->dims()[0];
    size_t D = table_t->dims()[1];
    size_t K = ids_t->numel();

    auto *ids = ids_t->data<int64_t>();
    auto *table = table_t->data<T>();
    auto *output = output_t->mutable_data<T>(context.GetPlace());

    dim3 threads(128, 8);
    dim3 grids(8, 1);

    if (padding_idx == -1)
      LookupTableV2<
          T, 128, 8, 8,
          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
          output, table, ids, N, K, D, padding_idx);
    else
      LookupTableV2<
          T, 128, 8, 8,
          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
          output, table, ids, N, K, D, padding_idx);
  }
};

template <typename T>
class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    bool is_sparse = context.Attr<bool>("is_sparse");

    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *table = context.Input<LoDTensor>("W");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      auto stream = dev_ctx.stream();
      // copy GPU memory to CPU pinned memory
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());

      // TODO(yuyang18): Strange code here.
      memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
                   gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      d_table_value->mutable_data<T>(context.GetPlace());

      auto *d_table_data = d_table_value->data<T>();
      auto *d_output_data = d_output->data<T>();
      auto d_output_dims = d_output->dims();
      PADDLE_ENFORCE_EQ(
          d_table_value->dims(),
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);

    } else {
      auto ids_t = context.Input<LoDTensor>("Ids");
      auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      int N = d_table_t->dims()[0];
      int D = d_table_t->dims()[1];
      int K = ids_t->numel();
      const int64_t *ids = ids_t->data<int64_t>();
      const T *d_output = d_output_t->data<T>();
      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      dim3 threads(128, 8);
      dim3 grids(8, 1);
      LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
          d_table, d_output, ids, N, K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel<float>,
                        ops::LookupTableV2CUDAKernel<double>,
                        ops::LookupTableV2CUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad,
                        ops::LookupTableV2GradCUDAKernel<float>,
                        ops::LookupTableV2GradCUDAKernel<double>,
                        ops::LookupTableV2GradCUDAKernel<plat::float16>);
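For the `is_sparse` backward path above, the gradient of `W` is emitted as a `SelectedRows`: its rows are the raw ids (duplicates included) and its value block is `d(Out)` flattened to 2-D. A NumPy sketch of that construction (illustrative only; the function name is mine, and merging duplicate rows is left to later ops, as in the kernel):

```python
import numpy as np

def sparse_embedding_grad(ids, d_out):
    # rows: one entry per looked-up id (duplicates kept, as in the kernel);
    # value: d(Out) flattened to 2-D, one gradient row per id.
    rows = ids.reshape(-1).astype(np.int64)
    value = d_out.reshape(rows.shape[0], -1)
    return rows, value

ids = np.array([[0, 4, 3, 5], [1, 4, 2, 0]], dtype=np.int64)
d_out = np.ones(ids.shape + (12,), dtype=np.float32)
rows, value = sparse_embedding_grad(ids, d_out)
assert rows.shape == (8,) and value.shape == (8, 12)
```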
@@ -0,0 +1,218 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;

constexpr int64_t kNoPadding = -1;

template <typename T>
class LookupTableV2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
    auto *table_var = context.InputVar("W");

    auto id_name = context.Inputs("Ids").front();
    auto embedding_name = context.Inputs("W").front();
    auto out_name = context.Outputs("Out").front();

    // for remote prefetch
    auto epmap = context.Attr<std::vector<std::string>>("epmap");
    auto remote_prefetch = context.Attr<bool>("remote_prefetch");
    auto height_sections =
        context.Attr<std::vector<int64_t>>("height_sections");
    auto table_names = context.Attr<std::vector<std::string>>("table_names");

    if (remote_prefetch && !epmap.empty()) {
      // if epmap is not empty, then the parameter will be fetched from remote
      // parameter server

#ifdef PADDLE_WITH_DISTRIBUTE
      operators::distributed::prefetch(id_name, out_name, embedding_name, false,
                                       table_names, epmap, height_sections,
                                       context, context.scope());
#else
      PADDLE_THROW(
          "paddle is not compiled with distribute support, can not do "
          "parameter prefetch!");
#endif
    } else {
      int64_t padding_idx = context.Attr<int64_t>("padding_idx");
      int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
      int64_t ids_numel = ids_t->numel();

      if (table_var->IsType<LoDTensor>()) {
        auto *table_t = context.Input<LoDTensor>("W");
        int64_t row_number = table_t->dims()[0];
        int64_t row_width = table_t->dims()[1];

        auto *table = table_t->data<T>();
        auto *output = output_t->mutable_data<T>(context.GetPlace());

        for (int64_t i = 0; i < ids_numel; ++i) {
          if (padding_idx != kNoPadding && ids[i] == padding_idx) {
            memset(output + i * row_width, 0, row_width * sizeof(T));
          } else {
            PADDLE_ENFORCE_LT(
                ids[i], row_number,
                "Variable value (input) of OP(fluid.layers.embedding) "
                "expected >= 0 and < %ld, but got %ld. Please check input "
                "value.",
                row_number, ids[i]);
            PADDLE_ENFORCE_GE(
                ids[i], 0,
                "Variable value (input) of OP(fluid.layers.embedding) "
                "expected >= 0 and < %ld, but got %ld. Please check input "
                "value.",
                row_number, ids[i]);
            memcpy(output + i * row_width, table + ids[i] * row_width,
                   row_width * sizeof(T));
          }
        }
      } else if (table_var->IsType<SelectedRows>()) {
        const auto &table_t = table_var->Get<SelectedRows>();
        int64_t row_width = table_t.value().dims()[1];
        const auto *table = table_t.value().data<T>();
        auto *output = output_t->mutable_data<T>(context.GetPlace());

        auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
        for (int64_t i = 0; i < ids_numel; ++i) {
          if (padding_idx != kNoPadding && ids[i] == padding_idx) {
            memset(output + i * row_width, 0, row_width * sizeof(T));
          } else {
            PADDLE_ENFORCE_GE(ids[i], 0);
            auto id_index = table_t.Index(ids[i]);
            PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
            blas.VCOPY(row_width, table + id_index * row_width,
                       output + i * row_width);
          }
        }
      }
    }
  }
};

template <typename T>
class LookupTableV2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_var = context.InputVar("W");
    DDim table_dim;
    if (table_var->IsType<LoDTensor>()) {
      table_dim = context.Input<LoDTensor>("W")->dims();
    } else if (table_var->IsType<SelectedRows>()) {
      auto *table_t = context.Input<SelectedRows>("W");
      table_dim = table_t->value().dims();
    } else {
      PADDLE_THROW(
          "The parameter W of a LookupTableV2 "
          "must be either LoDTensor or SelectedRows");
    }

    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    bool is_sparse = context.Attr<bool>("is_sparse");
    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      std::vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table_dim[1]});

      d_table_value->mutable_data<T>(context.GetPlace());

      d_table->set_height(table_dim[0]);

      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table_value->data<T>();

      auto d_output_dims = d_output->dims();
      PADDLE_ENFORCE_EQ(
          d_table_value->dims(),
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

    } else {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();

      int64_t N = table_dim[0];
      int64_t D = table_dim[1];

      auto *d_output_data = d_output->data<T>();
      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());

      memset(d_table_data, 0, d_table->numel() * sizeof(T));

      for (int64_t i = 0; i < ids->numel(); ++i) {
        if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
          // the gradient of padding_idx should be 0, already done by memset, so
          // do nothing.
        } else {
          PADDLE_ENFORCE_LT(
              ids_data[i], N,
              "Variable value (input) of OP(fluid.layers.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input value.",
              N, ids_data[i]);
          PADDLE_ENFORCE_GE(
              ids_data[i], 0,
              "Variable value (input) of OP(fluid.layers.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input value.",
              N, ids_data[i]);
          for (int j = 0; j < D; ++j) {
            d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
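The dense (`is_sparse == false`) backward loop in `LookupTableV2GradKernel` zero-initializes `d(W)` and scatter-adds each row of `d(Out)` into the row selected by its id, leaving `padding_idx` rows at zero. A NumPy sketch of that accumulation (illustrative only; the helper name is mine):

```python
import numpy as np

def dense_embedding_grad(ids, d_out, vocab_size, padding_idx=-1):
    # Zero-init d(W), then accumulate each d(Out) row into the row picked by
    # its id; ids equal to padding_idx contribute no gradient.
    D = d_out.shape[-1]
    d_w = np.zeros((vocab_size, D), dtype=d_out.dtype)
    flat_out = d_out.reshape(-1, D)
    for i, idx in enumerate(ids.reshape(-1)):
        if idx != padding_idx:
            d_w[idx] += flat_out[i]
    return d_w

ids = np.array([1, 3, 1, 0], dtype=np.int64)
d_out = np.ones((4, 16), dtype=np.float32)
d_w = dense_embedding_grad(ids, d_out, vocab_size=10)
assert d_w[1].sum() == 32.0  # id 1 appears twice, so its row accumulates twice
```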
@@ -0,0 +1,216 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
import paddle.compat as cpt


class TestLookupTableOp(OpTest):
    def setUp(self):
        self.op_type = "lookup_table_v2"
        table = np.random.random((17, 31)).astype("float32")
        ids = np.random.randint(0, 17, 4).astype("int64")
        self.inputs = {'W': table, 'Ids': ids}
        self.outputs = {'Out': table[ids]}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))


class TestLookupTableOpWithTensorIds(OpTest):
    def setUp(self):
        self.op_type = "lookup_table_v2"
        table = np.random.random((17, 31)).astype("float32")
        ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype("int64")
        self.inputs = {'W': table, 'Ids': ids}
        self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))


class TestLookupTableOpWithPadding(TestLookupTableOp):
    def test_check_output(self):
        ids = np.squeeze(self.inputs['Ids'])
        padding_idx = np.random.choice(ids, 1)[0]
        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output()

    def test_check_grad(self):
        # Since paddings are not trainable and fixed in forward, the gradient of
        # paddings makes no sense and we don't test the gradient here.
        pass


class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
    def test_check_output(self):
        ids = self.inputs['Ids']
        flatten_idx = ids.flatten()
        padding_idx = np.random.choice(flatten_idx, 1)[0]
        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
        self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
        self.check_output()

    def test_check_grad(self):
        # Since paddings are not trainable and fixed in forward, the gradient of
        # paddings makes no sense and we don't test the gradient here.
        pass


class TestLookupTableWIsSelectedRows(OpTest):
    def prepare_ids(self, scope, place):
        ids_tensor = scope.var('Ids').get_tensor()
        ids_array = np.array([0, 4, 3, 5]).astype("int64")
        ids_tensor.set(ids_array, place)
        return ids_array

    def prepare_w(self, scope, place):
        rows = [0, 1, 2, 3, 4, 5, 6]
        row_numel = 12

        w_selected_rows = scope.var('W').get_selected_rows()
        w_selected_rows.set_height(len(rows))
        w_selected_rows.set_rows(rows)
        w_array = np.ones((len(rows), row_numel)).astype("float32")
        for i in range(len(rows)):
            w_array[i] *= i
        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(w_array, place)

    def create_out_tensor(self, scope, place):
        return scope.var('Out').get_tensor()

    def check_result(self, ids_array, result_array):
        # all(): return True if all elements of the iterable are true (or if the iterable is empty)
        for idx, row in enumerate(ids_array):
            assert (row == result_array[idx]).all()

    def check_with_place(self, place):
        scope = core.Scope()

        ids_array = self.prepare_ids(scope, place)

        self.prepare_w(scope, place)

        out_tensor = self.create_out_tensor(scope, place)

        # create and run lookup_table operator
        lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
        lookup_table.run(scope, place)

        # get result from Out
        result_array = np.array(out_tensor)

        self.check_result(ids_array, result_array)

    def test_w_is_selected_rows(self):
        places = [core.CPUPlace()]
        # currently only support CPU
        for place in places:
            self.check_with_place(place)


class TestLookupTableWithTensorIdsWIsSelectedRows(
        TestLookupTableWIsSelectedRows):
    def prepare_ids(self, scope, place):
        ids_tensor = scope.var('Ids').get_tensor()
        ids_array = np.random.randint(
            low=0, high=6, size=(2, 4, 3)).astype("int64")
        ids_tensor.set(ids_array, place)
        return ids_array

    def check_result(self, ids_array, result_array):
        for idx, row in np.ndenumerate(ids_array):
            assert (row == result_array[idx]).all()


class TestLookupTableIsSparse(unittest.TestCase):
    def init_data(self):
        self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
        self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")

    def get_w_grad(self, is_sparse):
        self.init_data()
        main_program = fluid.Program()
        with fluid.program_guard(main_program, fluid.Program()):
            x = fluid.layers.data(name='x', shape=[5], dtype='int64')
            y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')
            emb = fluid.input.embedding(
                input=x,
                size=[10, 16],
                param_attr=fluid.ParamAttr(
                    name="emb_weight",
                    learning_rate=10,
                    initializer=fluid.initializer.NumpyArrayInitializer(
                        self.w_data)),
                is_sparse=is_sparse)
            y = fluid.layers.reduce_sum(emb, dim=-1)

            loss = fluid.layers.square_error_cost(input=y, label=y_)
            loss = fluid.layers.mean(loss)

            sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
            sgd_optimizer.minimize(loss)

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            ret = exe.run(feed={'x': self.x_data,
                                'y_': self.y_data},
                          fetch_list=['emb_weight'],
                          return_numpy=False)
            return np.array(ret[0])

    def test_w_grad(self):
        self.w_data = np.random.random(size=(10, 16)).astype("float32")
        w_grad = self.get_w_grad(False)
        w_grad_with_sparse = self.get_w_grad(True)
        self.check_grad(w_grad, w_grad_with_sparse)

    def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
        np.testing.assert_allclose(
            w_grad1, w_grad2, rtol=tolerance, atol=tolerance)


class TestLookupTableApi(unittest.TestCase):
    def test_api(self):
        x = fluid.layers.data(name='x', shape=[20], dtype='int64')
        emb = fluid.embedding(input=x, size=[128, 64])

        place = fluid.CPUPlace()
        x_data = np.random.randint(0, 127, [2, 20]).astype("int64")

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        ret = exe.run(feed={'x': x_data, },
                      fetch_list=[emb],
                      return_numpy=False)


if __name__ == "__main__":
    unittest.main()