@ -33,7 +33,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
bool overwrite = ctx.Attr<bool>("overwrite");
Out->ShareDataWith(*X);
GPUScatterAssign<T>(ctx, *Updates, *Ids, Out, overwrite);
// use template class to support int32_t and int64_t
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op Index holds the wrong type, it holds %s, but desires to be "
"%s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
} else {
GPUScatterAssign<T, int64_t>(ctx, *Updates, *Ids, Out, overwrite);
}
}
};
@ -54,7 +69,23 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op Index holds the wrong type, it holds %s, but desires to "
"be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
// Gradient by Gather: dUpdates = dO[Ids]
if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
} else {
GPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
}
}
};