Enable device switching automatically for serveral operators (#8684)

optimizer
qingqing01 7 years ago committed by GitHub
parent ae2026e134
commit 9e1ec8c919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,6 +41,14 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDist", dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
platform::CPUPlace());
}
};
template <typename T>

@ -62,7 +62,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("Scores")->type()),
ctx.device_context());
platform::CPUPlace());
}
};

@ -67,6 +67,14 @@ class PriorBoxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
platform::CPUPlace());
}
};
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {

Loading…
Cancel
Save