add python inference api (#15248)

add python inference api
inference-pre-release-gpu
flame 6 years ago committed by GitHub
parent 59ab98c9a8
commit d60751fb71

@ -45,6 +45,7 @@ paddle.fluid.AsyncExecutor.save_model ArgSpec(args=['self', 'save_path'], vararg
paddle.fluid.AsyncExecutor.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.CompiledProgram.__init__ ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=None)
paddle.fluid.CompiledProgram.with_data_parallel ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.CompiledProgram.with_inference_optimize ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None)
paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None
paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.GradientScaleStrategy, arg0: int) -> None
paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.ReduceStrategy, arg0: int) -> None
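For reference, the new `with_inference_optimize` entry takes a single positional `config` argument with no default. A minimal sketch of the call shape (the program and config values here are placeholders, not part of this commit):

# Sketch only, not part of the diff: illustrates the ArgSpec added above.
import paddle.fluid as fluid

prog = fluid.Program()  # placeholder program
compiled = fluid.compiler.CompiledProgram(prog)
# `config` must be a fluid.core.NativeConfig or fluid.core.AnalysisConfig,
# as asserted in compiler.py further down in this diff.
compiled.with_inference_optimize(fluid.core.NativeConfig())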

@ -45,6 +45,7 @@ using contrib::AnalysisConfig;
class AnalysisPredictor : public PaddlePredictor {
public:
explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {}
~AnalysisPredictor();
bool Init(const std::shared_ptr<framework::Scope> &parent_scope,
const std::shared_ptr<framework::ProgramDesc> &program = nullptr);
@ -95,7 +96,6 @@ class AnalysisPredictor : public PaddlePredictor {
template <typename T>
void GetFetchOne(const framework::LoDTensor &fetchs,
PaddleTensor *output_data);
~AnalysisPredictor();
// Some more detailed tests; they are made friends of the predictor so that
// all the details can be tested.

@ -1,10 +1,11 @@
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune
feed_fetch_method pass_builder parallel_executor profiler layer scope_pool
tracer)
tracer analysis_predictor)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc)
if(WITH_PYTHON)
if(WITH_AMD_GPU)

File diff suppressed because it is too large

@ -0,0 +1,23 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {
void BindInferenceApi(pybind11::module *m);
} // namespace pybind
} // namespace paddle

@ -49,6 +49,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h" // NOLINT
@ -1083,9 +1084,9 @@ All parameter, weight, gradient are variables in Paddle.
BindRecordIOWriter(&m);
BindAsyncExecutor(&m);
BindGraph(&m);
BindNode(&m);
BindInferenceApi(&m);
}
} // namespace pybind
} // namespace paddle
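With `BindInferenceApi(&m)` registered here, the inference types and the `create_paddle_predictor` factory become attributes of `paddle.fluid.core`. A hedged smoke check, using only names that appear elsewhere in this diff:

# Sketch only, not part of the diff. Every name below is referenced in the
# Python changes of this commit (compiler.py, executor.py, the word2vec test),
# so it is assumed to be bound by inference_api.cc.
import paddle.fluid as fluid

for name in ('NativeConfig', 'AnalysisConfig', 'PaddleTensor', 'PaddleBuf',
             'PaddleDType', 'create_paddle_predictor'):
    assert hasattr(fluid.core, name), name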

@ -24,6 +24,8 @@ __all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy']
ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy
InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig
def _place_obj(place):
@ -70,6 +72,7 @@ class CompiledProgram(object):
self._executor = None
self._compiled = False
self._is_data_parallel = False
self._is_inference = False
def with_data_parallel(self,
loss_name=None,
@ -109,10 +112,24 @@ class CompiledProgram(object):
self._build_strategy = BuildStrategy()
return self
def _with_distributed(self):
raise NotImplementedError()
def with_inference_optimize(self, config):
""" Add inference optimize
Args:
config: instance of `NativeConfig` or `AnalysisConfig` to create predictor
Returns:
self
"""
assert any([
isinstance(config, InferNativeConfig),
isinstance(config, InferAnalysisConfig)
])
self._is_data_parallel = False
self._is_inference = True
self._infer_config = config
return self
def _with_inference_optimize(self):
def _with_distributed(self):
raise NotImplementedError()
def _compile_data_parallel(self):
@ -177,6 +194,10 @@ class CompiledProgram(object):
if self._loss_name else six.u(''), self._scope, self._local_scopes,
self._exec_strategy, self._build_strategy)
def _compile_inference(self):
assert self._is_data_parallel is False
return core.create_paddle_predictor(self._infer_config)
def _compile(self, scope, place):
"""Compile the program based on the configs.
@ -200,6 +221,8 @@ class CompiledProgram(object):
self._place = place
if self._is_data_parallel:
self._executor = self._compile_data_parallel()
elif self._is_inference:
self._executor = self._compile_inference()
else:
p = _place_obj(self._place)
self._executor = core.Executor(p)
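Taken together, `with_inference_optimize` records the config and `_compile` routes to `_compile_inference`, which swaps the usual `core.Executor` for a native predictor from `core.create_paddle_predictor`. A sketch of the resulting code path, assuming `exe`, `inference_program`, and `infer_inputs` are prepared as in the word2vec test at the end of this commit:

# Sketch only, not part of the diff; mirrors the word2vec test below.
import paddle.fluid as fluid

infer_config = fluid.core.NativeConfig()
infer_config.model_dir = 'word2vec.inference.model'  # path used by the test
infer_config.use_gpu = False

compiled_program = fluid.compiler.CompiledProgram(inference_program)
compiled_program.with_inference_optimize(infer_config)
# Executor.run detects _is_inference and forwards `feed` (a list of
# PaddleTensor) to the native predictor created in _compile_inference.
infer_outputs = exe.run(compiled_program, feed=infer_inputs)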

@ -27,6 +27,8 @@ from .. import compat as cpt
__all__ = ['Executor', 'global_scope', 'scope_guard']
g_scope = core.Scope()
InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig
def global_scope():
@ -533,6 +535,8 @@ class Executor(object):
fetch_list=fetch_list,
fetch_var_name=fetch_var_name,
return_numpy=return_numpy)
elif program._is_inference:
return self._run_inference(program._executor, feed)
else:
# TODO(panyx0718): Can compile program to optimize executor
# performance.
@ -590,3 +594,6 @@ class Executor(object):
if return_numpy:
outs = as_numpy(outs)
return outs
def _run_inference(self, exe, feed):
return exe.run(feed)
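Note that the inference branch hands `feed` straight to the native predictor, so the caller passes a list of `PaddleTensor` rather than the usual `{var_name: ndarray}` dict. A hypothetical converter (the helper name is mine; it mirrors `to_infer_tensor` in the word2vec test below and assumes `PaddleDType.FLOAT32` is exposed alongside the `INT64` value used there):

# Sketch only, not part of the diff.
import numpy as np
import paddle.fluid as fluid

def ndarray_to_paddle_tensor(arr, lod=None):
    # Hypothetical helper: wraps a numpy array as a PaddleTensor feed entry.
    t = fluid.core.PaddleTensor()
    t.shape = list(arr.shape)
    t.data = fluid.core.PaddleBuf(arr)
    if lod is not None:
        t.lod = lod
    # Assumes FLOAT32 exists in PaddleDType next to the INT64 used in the test.
    t.dtype = (fluid.core.PaddleDType.INT64
               if arr.dtype == np.int64 else fluid.core.PaddleDType.FLOAT32)
    return t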

@ -195,9 +195,34 @@ def infer(use_cuda, save_dirname=None):
},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].recursive_sequence_lengths())
def to_infer_tensor(lod_tensor):
infer_tensor = fluid.core.PaddleTensor()
infer_tensor.lod = lod_tensor.lod()
infer_tensor.data = fluid.core.PaddleBuf(np.array(lod_tensor))
infer_tensor.shape = lod_tensor.shape()
infer_tensor.dtype = fluid.core.PaddleDType.INT64
return infer_tensor
infer_inputs = [first_word, second_word, third_word, fourth_word]
infer_inputs = [to_infer_tensor(t) for t in infer_inputs]
infer_config = fluid.core.NativeConfig()
infer_config.model_dir = 'word2vec.inference.model'
infer_config.use_gpu = use_cuda
if use_cuda:
infer_config.device = 0
infer_config.fraction_of_gpu_memory = 0.15
compiled_program = fluid.compiler.CompiledProgram(inference_program)
compiled_program.with_inference_optimize(infer_config)
assert compiled_program._is_inference is True
infer_outputs = exe.run(compiled_program, feed=infer_inputs)
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
infer_out = infer_outputs[0].data.float_data()
for a, b in zip(np_data[0], infer_out):
g_a = float("{:.6g}".format(a))
g_b = float("{:.6g}".format(b))
assert g_a == g_b
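As a follow-up to the comparison above, the predictor's output can be turned back into a numpy array using only the members the test already exercises (`data.float_data()` and `shape`); a short sketch, not part of the diff:

# Sketch only, not part of the diff.
out = infer_outputs[0]
out_np = np.array(out.data.float_data()).reshape(out.shape)
print("Predictor output shape: ", out_np.shape)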
def main(use_cuda, is_sparse, is_parallel):
