Enable device switching automatically for serveral operators (#8684)

optimizer
qingqing01 7 years ago committed by GitHub
parent ae2026e134
commit 9e1ec8c919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,6 +41,14 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDist", dims); ctx->SetOutputDim("ColToRowMatchDist", dims);
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
platform::CPUPlace());
}
}; };
template <typename T> template <typename T>

@ -62,7 +62,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
ctx.Input<framework::LoDTensor>("Scores")->type()), ctx.Input<framework::LoDTensor>("Scores")->type()),
ctx.device_context()); platform::CPUPlace());
} }
}; };

@ -67,6 +67,14 @@ class PriorBoxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
platform::CPUPlace());
}
}; };
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {

Loading…
Cancel
Save