|
|
|
@ -24,8 +24,9 @@ template <typename T>
|
|
|
|
|
class GatherOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"This kernel only runs on GPU device."));
|
|
|
|
|
auto *x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto *index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *output = ctx.Output<Tensor>("Out");
|
|
|
|
@ -35,12 +36,15 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
const auto &index_type = index->type();
|
|
|
|
|
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
|
|
|
|
|
index_type == framework::proto::VarType::INT64;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
index_type_match,
|
|
|
|
|
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
|
|
|
|
|
paddle::framework::DataTypeToString(index_type),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
|
|
|
|
|
PADDLE_ENFORCE_EQ(index_type_match, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Index holds the wrong type, it holds [%s],"
|
|
|
|
|
"but desires to be [%s] or [%s].",
|
|
|
|
|
paddle::framework::DataTypeToString(index_type),
|
|
|
|
|
paddle::framework::DataTypeToString(
|
|
|
|
|
framework::proto::VarType::INT32),
|
|
|
|
|
paddle::framework::DataTypeToString(
|
|
|
|
|
framework::proto::VarType::INT64)));
|
|
|
|
|
if (index_type == framework::proto::VarType::INT32) {
|
|
|
|
|
GPUGather<T, int>(ctx.device_context(), *x, *index, output);
|
|
|
|
|
} else if (index_type == framework::proto::VarType::INT64) {
|
|
|
|
@ -53,8 +57,9 @@ template <typename T>
|
|
|
|
|
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"This kernel only runs on GPU device."));
|
|
|
|
|
auto *index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
@ -69,12 +74,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
const auto &index_type = index->type();
|
|
|
|
|
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
|
|
|
|
|
index_type == framework::proto::VarType::INT64;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
index_type_match,
|
|
|
|
|
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
|
|
|
|
|
paddle::framework::DataTypeToString(index_type),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
|
|
|
|
|
PADDLE_ENFORCE_EQ(index_type_match, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Index holds the wrong type, it holds [%s],"
|
|
|
|
|
"but desires to be [%s] or [%s].",
|
|
|
|
|
paddle::framework::DataTypeToString(index_type),
|
|
|
|
|
paddle::framework::DataTypeToString(
|
|
|
|
|
framework::proto::VarType::INT32),
|
|
|
|
|
paddle::framework::DataTypeToString(
|
|
|
|
|
framework::proto::VarType::INT64)));
|
|
|
|
|
if (index_type == framework::proto::VarType::INT32) {
|
|
|
|
|
GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
|
|
|
|
|
ctx.Attr<bool>("overwrite"));
|
|
|
|
|