Add cuda support for unique op (#27646)

* unique op for cuda is added

* add support for cuda

* Add cuda support for unique op.

* Add support for int32_t and int64_t.

* For old version, process by cpu

* Add VisitDataType for thrust
my_2.0rc
AshburnLee 5 years ago committed by GitHub
parent bbc2add703
commit c3a3df6466
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -63,6 +63,11 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64);
// For the use of thrust, as index-type elements can be only integers.
#define _ForEachDataTypeTiny_(callback) \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64);
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
@ -107,6 +112,20 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackSmall
}
template <typename Visitor>
inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
#define VisitDataTypeCallbackTiny(cpp_type, proto_type) \
do { \
if (type == proto_type) { \
visitor.template apply<cpp_type>(); \
return; \
} \
} while (0)
_ForEachDataTypeTiny_(VisitDataTypeCallbackTiny);
#undef VisitDataTypeCallbackTiny
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,

@ -87,9 +87,17 @@ class UniqueOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
// Return CPUPlace when Attr("is_sorted") is false. Because it means
// that fluid.layers.unique is called, but there is no cuda kernel.
if (!ctx.Attr<bool>("is_sorted")) {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
} else {
// new version paddle.unique is called.
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
}
};

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save