CompareOp's kernel device type is decided by input tensor place

CompareOp can run on CPU even other operators are running on GPU, since
opeatations like comparing control flags should be performed only on CPU
mobile_baidu
Yang Yu 7 years ago
parent c365c61ac6
commit 3187451ae7
No known key found for this signature in database
GPG Key ID: 0AC769FE2C5F8F25

@ -14,6 +14,7 @@
#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename OpComment>
@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
}
};
class CompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");

@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first, last, result, op);
}
@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
void operator()(const DeviceContext& context, InputIter1 first1,
InputIter1 last1, InputIter2 first2, OutputIter result,
BinaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first1, last1, first2, result, op);
}
};

Loading…
Cancel
Save