@@ -532,8 +532,7 @@ bool OpSupportGPU(const std::string& op_type) {
 
 class RuntimeInferShapeContext : public InferShapeContext {
  public:
-  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope,
-                           const RuntimeContext& ctx)
+  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
       : op_(op), ctx_(ctx) {}
 
   bool HasInput(const std::string& name) const override {
@@ -901,7 +900,7 @@ static void CheckTensorNANOrInf(const std::string& op_type,
 
 void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
                                            const platform::Place& place,
                                            const RuntimeContext& ctx) const {
-  RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx);
+  RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
   this->InferShape(&infer_shape_ctx);
 }
@@ -966,7 +965,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 
   if (!all_kernels_must_compute_runtime_shape_) {
-    RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
+    RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
     this->InferShape(&infer_shape_ctx);
   }
 
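Note: the three hunks above are one refactor. RuntimeInferShapeContext no longer takes a Scope, because shape inference only ever reads inputs and outputs through the pre-resolved RuntimeContext, so both construction sites drop the scope argument. Below is a minimal, self-contained sketch of the resulting shape of the class. The stand-in types, the HasInput body, and the main() driver are assumptions for illustration, not the real PaddlePaddle definitions.

// Minimal sketch of the post-diff constructor: the Scope parameter is gone,
// and only the RuntimeContext (pre-resolved input/output variables) remains.
#include <string>
#include <unordered_map>
#include <vector>

struct Variable {};     // stand-in for framework::Variable
struct OperatorBase {}; // stand-in for framework::OperatorBase

// Stand-in for framework::RuntimeContext: variables already looked up,
// so no Scope access is needed during shape inference.
struct RuntimeContext {
  std::unordered_map<std::string, std::vector<Variable*>> inputs;
  std::unordered_map<std::string, std::vector<Variable*>> outputs;
};

class RuntimeInferShapeContext {
 public:
  // Matches the new signature introduced by the diff.
  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
      : op_(op), ctx_(ctx) {}

  // Hypothetical body: answers HasInput purely from the RuntimeContext,
  // which is what makes the Scope parameter removable.
  bool HasInput(const std::string& name) const {
    auto it = ctx_.inputs.find(name);
    return it != ctx_.inputs.end() && it->second.size() == 1 &&
           it->second[0] != nullptr;
  }

 private:
  const OperatorBase& op_;
  const RuntimeContext& ctx_;
};

int main() {
  OperatorBase op;
  Variable x;
  RuntimeContext ctx;
  ctx.inputs["X"] = {&x};
  // Two-argument construction, as in the updated call sites.
  RuntimeInferShapeContext infer_shape_ctx(op, ctx);
  return infer_shape_ctx.HasInput("X") ? 0 : 1;
}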