|
|
|
@ -703,8 +703,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
|
|
|
|
|
|
|
|
|
|
void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const {
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto* dev_ctx = pool.Get(place);
|
|
|
|
|
|
|
|
|
@ -758,6 +756,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
|
|
|
|
|
dev_ctx = pool.Get(expected_kernel_key.place_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope);
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
|
|
|
|
|
|
|
|
|
|
if (!transfered_inplace_vars.empty()) {
|
|
|
|
|