|
|
|
@ -32,7 +32,20 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
if (x->numel() == 0) return;
|
|
|
|
|
GPUGather<T>(ctx.device_context(), *x, *index, output);
|
|
|
|
|
const auto &index_type = index->type();
|
|
|
|
|
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
|
|
|
|
|
index_type == framework::proto::VarType::INT64;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
index_type_match,
|
|
|
|
|
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
|
|
|
|
|
paddle::framework::DataTypeToString(index_type),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
|
|
|
|
|
if (index_type == framework::proto::VarType::INT32) {
|
|
|
|
|
GPUGather<T, int>(ctx.device_context(), *x, *index, output);
|
|
|
|
|
} else if (index_type == framework::proto::VarType::INT64) {
|
|
|
|
|
GPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -42,7 +55,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Compute: reads the "Index" input and the gradient variable of "Out", zero
// fills a device view named dxt, and then calls GPUScatterAssign with the
// gradient variable of "X" as the destination.
// NOTE(review): this span looks like a flattened diff (a hunk header appears
// part-way through and old/new versions of some lines sit side by side), so
// not every line below necessarily exists in the real file -- confirm against
// the actual source before relying on it.
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
// Enforce that the execution place reported by the context is a GPU place.
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
|
// NOTE(review): the next two declarations fetch the same "Index" input under
// two spellings (Index / index); presumably the old and new side of a rename,
// of which only one would be present in the real file -- TODO confirm.
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *index = ctx.Input<Tensor>("Index");
|
|
|
|
|
// dX: output tensor for the gradient variable of "X".
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
// dO: input tensor holding the gradient variable of "Out".
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
// NOTE(review): hunk boundary -- the lines that declare `dxt` and `place`
// (used below) are not visible in this view.
@ -52,7 +65,21 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
.eigen_device();
|
|
|
|
|
// Fill dxt with zeros of type T on `place` (dxt is presumably an Eigen view of
// dX -- its definition is outside the visible lines).
dxt.device(place) = dxt.constant(static_cast<T>(0));
|
|
|
|
|
// Nothing to scatter when the incoming gradient has no elements.
if (dO->numel() == 0) return;
|
|
|
|
|
// NOTE(review): this call without an explicit index type uses the `Index`
// spelling and looks like the pre-change line that the typed dispatch below
// replaces -- TODO confirm it is not present in the real file.
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
|
|
|
|
|
|
|
|
|
|
|
// Validate the element type of the index tensor: only INT32 and INT64 pass the
// check below; any other type trips PADDLE_ENFORCE with a message naming the
// actual and the two accepted types.
const auto &index_type = index->type();
|
|
|
|
|
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
|
|
|
|
|
index_type == framework::proto::VarType::INT64;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
index_type_match,
|
|
|
|
|
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
|
|
|
|
|
paddle::framework::DataTypeToString(index_type),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
|
|
|
|
|
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
|
|
|
|
|
// Dispatch on the index element type: the second template argument of
// GPUScatterAssign is int for INT32 and int64_t for INT64.
if (index_type == framework::proto::VarType::INT32) {
|
|
|
|
|
GPUScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX);
|
|
|
|
|
} else if (index_type == framework::proto::VarType::INT64) {
|
|
|
|
|
GPUScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|