Merge pull request #3620 from qingqing01/lookup_table
Add a lookup table op and a CUDA helper.
commit 3663bd881d
@@ -0,0 +1,72 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/lookup_table_op.h"

namespace paddle {
namespace operators {

class LookupTableOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &context) const override {
    auto table_t = context.Input<Tensor>("W");
    auto ids_t = context.Input<Tensor>("Ids");
    auto output_t = context.Output<Tensor>("Out");

    output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
  }
};

class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LookupTableOpMaker(framework::OpProto *proto,
                     framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("W",
             "An input representing the embedding table, "
             "which is a learnable parameter.");
    AddInput("Ids",
             "An input with type int32 or int64, "
             "containing the ids to be looked up in W.");
    AddOutput("Out", "The lookup results, which have the same type as W.");
    AddComment(
        "This operator performs lookups on the parameter W, "
        "then concatenates the results into a dense tensor.");
  }
};

class LookupTableOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &context) const override {
    auto table = context.Input<Tensor>("W");
    auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
    d_table->Resize(table->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
            lookup_table_grad, ops::LookupTableOpGrad);

REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
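For orientation, the shape inference above makes Out a tensor with Ids->dims()[0] rows and table->dims()[1] columns, where row i of Out is the row of W selected by Ids[i]. Below is a minimal host-side sketch of that contract; it is not part of the PR, uses plain buffers instead of the framework's Tensor type, and the function name is hypothetical.

// Standalone sketch of the lookup contract (not part of the PR).
// table is N x D row-major, ids has K entries; out becomes K x D with
// out[i] = table[ids[i]].
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<float> LookupSketch(const std::vector<float>& table, int N, int D,
                                const std::vector<int32_t>& ids) {
  std::vector<float> out(ids.size() * D);
  for (size_t i = 0; i < ids.size(); ++i) {
    // Assumes 0 <= ids[i] < N, as the real kernels enforce.
    std::memcpy(out.data() + i * D, table.data() + ids[i] * D,
                D * sizeof(float));
  }
  return out;
}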
@ -0,0 +1,116 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/platform/assert.h"
|
||||
#include "paddle/platform/cuda_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
|
||||
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
|
||||
__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
|
||||
const int N, const int K, const int D) {
|
||||
int idx = threadIdx.x;
|
||||
int idy = blockIdx.x + threadIdx.y * GridDimX;
|
||||
|
||||
while (idy < K) {
|
||||
int id = ids[idy];
|
||||
PADDLE_ASSERT(id >= 0);
|
||||
PADDLE_ASSERT(id < N);
|
||||
T* out = output + idy * D;
|
||||
const T* tab = table + id * D;
|
||||
for (int i = idx; i < D; i += BlockDimX) {
|
||||
out[i] = tab[i];
|
||||
}
|
||||
idy += BlockDimY * GridDimX;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
|
||||
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
|
||||
const int N, const int K, const int D) {
|
||||
int idx = threadIdx.x;
|
||||
int idy = blockIdx.x + threadIdx.y * GridDimX;
|
||||
|
||||
while (idy < K) {
|
||||
int id = ids[idy];
|
||||
PADDLE_ASSERT(id >= 0);
|
||||
PADDLE_ASSERT(id < N);
|
||||
const T* out = output + idy * D;
|
||||
T* tab = table + id * D;
|
||||
for (int i = idx; i < D; i += BlockDimX) {
|
||||
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
|
||||
}
|
||||
idy += BlockDimY * GridDimX;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class LookupTableCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto table_t = context.Input<Tensor>("W");
|
||||
auto ids_t = context.Input<Tensor>("Ids");
|
||||
auto output_t = context.Output<Tensor>("Out");
|
||||
|
||||
size_t N = table_t->dims()[0];
|
||||
size_t D = table_t->dims()[1];
|
||||
size_t K = product(ids_t->dims());
|
||||
auto ids = ids_t->data<int32_t>();
|
||||
auto table = table_t->data<T>();
|
||||
auto output = output_t->mutable_data<T>(context.GetPlace());
|
||||
|
||||
dim3 threads(128, 8);
|
||||
dim3 grids(8, 1);
|
||||
LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class LookupTableGradCUDAKernel : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto ids_t = context.Input<Tensor>("Ids");
|
||||
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
|
||||
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
|
||||
|
||||
int N = d_table_t->dims()[0];
|
||||
int D = d_table_t->dims()[1];
|
||||
int K = product(ids_t->dims());
|
||||
const int32_t* ids = ids_t->data<int32_t>();
|
||||
const T* d_output = d_output_t->data<T>();
|
||||
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
|
||||
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
|
||||
t.constant(static_cast<T>(0));
|
||||
|
||||
dim3 threads(128, 8);
|
||||
dim3 grids(8, 1);
|
||||
LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
|
||||
K, D);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
|
||||
REGISTER_OP_GPU_KERNEL(lookup_table_grad,
|
||||
ops::LookupTableGradCUDAKernel<float>);
|
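The launch configuration above pairs threads(128, 8) with grids(8, 1): the 128 x-threads stream the D columns of one row, while (blockIdx.x, threadIdx.y) pairs stride over the K rows starting at blockIdx.x + threadIdx.y * GridDimX and stepping by BlockDimY * GridDimX. The standalone host sketch below (not part of the PR; K is a made-up row count) just enumerates that striding to show every row is visited by exactly one (block, y-thread) pair.

// Standalone sketch: row coverage under GridDimX = 8, BlockDimY = 8,
// mirroring the idy striding used by LookupTable/LookupTableGrad above.
#include <cstdio>

int main() {
  const int GridDimX = 8, BlockDimY = 8;
  const int K = 100;  // hypothetical number of ids to look up
  for (int bx = 0; bx < GridDimX; ++bx) {
    for (int ty = 0; ty < BlockDimY; ++ty) {
      printf("blockIdx.x %d, threadIdx.y %d:", bx, ty);
      // Same striding as the kernels: start at bx + ty * GridDimX and
      // advance by BlockDimY * GridDimX (= 64 here), so the starts 0..63
      // partition the rows and no row is handled twice.
      for (int idy = bx + ty * GridDimX; idy < K;
           idy += BlockDimY * GridDimX) {
        printf(" %d", idy);
      }
      printf("\n");
    }
  }
  return 0;
}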
@@ -0,0 +1,75 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class LookupTableKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto table_t = context.Input<Tensor>("W");      // float tensor
    auto ids_t = context.Input<Tensor>("Ids");      // int tensor
    auto output_t = context.Output<Tensor>("Out");  // float tensor

    size_t N = table_t->dims()[0];
    size_t D = table_t->dims()[1];
    auto ids = ids_t->data<int32_t>();
    auto table = table_t->data<T>();
    auto output = output_t->mutable_data<T>(context.GetPlace());
    for (size_t i = 0; i < product(ids_t->dims()); ++i) {
      PADDLE_ENFORCE_LT(ids[i], N);
      PADDLE_ENFORCE_GE(ids[i], 0);
      memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
    }
  }
};

template <typename T>
class LookupTableGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto ids_t = context.Input<Tensor>("Ids");
    auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
    auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));

    size_t N = d_table_t->dims()[0];
    size_t D = d_table_t->dims()[1];
    auto ids = ids_t->data<int32_t>();
    const T* d_output = d_output_t->data<T>();
    T* d_table = d_table_t->mutable_data<T>(context.GetPlace());

    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
    t.device(context.GetEigenDevice<platform::CPUPlace>()) =
        t.constant(static_cast<T>(0));

    for (size_t i = 0; i < product(ids_t->dims()); ++i) {
      PADDLE_ENFORCE_LT(ids[i], N);
      PADDLE_ENFORCE_GE(ids[i], 0);
      for (size_t j = 0; j < D; ++j) {
        d_table[ids[i] * D + j] += d_output[i * D + j];
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
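One subtlety worth spelling out: if the same id appears more than once in Ids, that row of W receives the sum of the corresponding output-gradient rows, which is why the CPU gradient kernel zeroes d_table and then accumulates with +=, and why the GPU gradient kernel needs CudaAtomicAdd rather than a plain store. A standalone sketch (not part of the PR; the function name and plain buffers are hypothetical) of that accumulation rule:

// Standalone sketch of the gradient rule d_table[ids[i]] += d_out[i].
#include <cstdint>
#include <vector>

std::vector<float> LookupGradSketch(const std::vector<float>& d_out, int D,
                                    const std::vector<int32_t>& ids, int N) {
  std::vector<float> d_table(N * D, 0.0f);  // zero-initialize, like the kernel
  for (size_t i = 0; i < ids.size(); ++i) {
    for (int j = 0; j < D; ++j) {
      // Duplicate ids accumulate into the same row instead of overwriting it.
      d_table[ids[i] * D + j] += d_out[i * D + j];
    }
  }
  return d_table;
}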
@@ -0,0 +1,51 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cuda.h>

namespace paddle {
namespace platform {

#define CUDA_ATOMIC_WRAPPER(op, T) \
  __device__ __forceinline__ T CudaAtomic##op(T* address, const T val)

#define USE_CUDA_ATOMIC(op, T) \
  CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }

// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
// Hardware atomicAdd for double is available on compute capability 6.0+.
USE_CUDA_ATOMIC(Add, double);
#else
// Fall back to an atomicCAS loop on older devices.
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif

}  // namespace platform
}  // namespace paddle
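A brief usage sketch for the new helper, assuming a standalone .cu file that can include paddle/platform/cuda_helper.h; the kernel name, buffer layout, and index array are hypothetical and not part of the PR.

// Hypothetical kernel showing the intended call pattern: several threads may
// add into the same destination element, and CudaAtomicAdd keeps that
// race-free (for float it expands to the hardware atomicAdd via
// USE_CUDA_ATOMIC).
#include "paddle/platform/cuda_helper.h"

__global__ void ScatterAddSketch(float* dst, const float* src,
                                 const int* index, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    paddle::platform::CudaAtomicAdd(&dst[index[i]], src[i]);
  }
}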
@@ -0,0 +1,31 @@
import unittest
import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op


class TestLookupTableOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = 'lookup_table'
        table = np.random.random((17, 31)).astype('float32')
        ids = np.random.randint(0, 17, 4).astype('int32')
        self.inputs = {'W': table, 'Ids': ids}
        self.outputs = {'Out': table[ids]}


class TestLookupTableGradOp(GradientChecker):
    def test_grad(self):
        op = create_op('lookup_table')
        table = np.random.random((17, 31)).astype('float32')
        ids = np.random.randint(0, 17, 4).astype('int32')
        inputs = {'W': table, 'Ids': ids}
        # compare gradients
        self.compare_grad(op, inputs, set(['Ids']))
        # check gradients
        self.check_grad(op, inputs, set('W'), 'Out')


if __name__ == '__main__':
    unittest.main()