Merge pull request #14836 from sneaxiy/feature/py_func

Featue/py_func op
revert-15207-remove_op_handle_lock_and_fix_var
Zeng Jinle 6 years ago committed by GitHub
commit 95cbe07c40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -208,6 +208,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))

@ -123,6 +123,8 @@ class OpDesc {
BlockDesc *Block() { return this->block_; }
const BlockDesc *Block() const { return this->block_; }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {

@ -25,6 +25,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
class InferShapeContext {

@ -42,8 +42,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32)
@ -92,4 +91,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")

File diff suppressed because it is too large Load Diff

@ -0,0 +1,25 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
namespace paddle {
namespace operators {
size_t AppendPythonCallableObjectAndReturnId(const ::pybind11::object &py_obj);
} // namespace operators
} // namespace paddle

@ -1,5 +1,8 @@
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc)
if(WITH_PYTHON)

@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) {
.def("infer_var_type", &pd::OpDesc::InferVarType)
.def("set_is_target", &pd::OpDesc::SetIsTarget)
.def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", &pd::OpDesc::Block,
.def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference);
}

@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
@ -110,6 +111,12 @@ PYBIND11_MODULE(core, m) {
BindException(&m);
m.def(
"_append_python_callable_object_and_return_id",
[](py::object py_obj) -> size_t {
return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj);
});
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
.def(py::init<>())
.def("_run_backward",

@ -18,7 +18,9 @@ All layers just related to the neural network.
from __future__ import print_function
import numpy as np
import six
import os
import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable, OpProtoHolder
@ -176,6 +178,7 @@ __all__ = [
'merge_selected_rows',
'get_tensor_from_selected_rows',
'lstm',
'py_func',
'psroi_pool',
'huber_loss',
]
@ -9327,6 +9330,224 @@ def get_tensor_from_selected_rows(x, name=None):
return out
class PyFuncRegistry(object):
_register_funcs = []
def __init__(self, func):
if func is None or not callable(func):
raise TypeError('func must be a Python function')
self._func = func
# find named args using reflection
args = inspect.getargspec(self._func)
if len(args[0]) == 0 and args[1] is None and args[2] is None:
# Function with no inputs
self._named_args = None
else:
self._named_args = args[0]
self._id = core._append_python_callable_object_and_return_id(self)
'''
Why record self here?
1. For debug usage. Users can call
:code:`py_func.registered_func(idx)` method
to find the registered function corresponding
to :code:`idx`.
2. For increasing reference count of self.
It seems that to release Python object
whose reference count is 1 would cause
segmentation fault error in C++ side.
May be lack of Python GC in C++ side?
'''
PyFuncRegistry._register_funcs.append(self)
@classmethod
def registered_func(cls, idx):
return cls._register_funcs[idx]._func
@classmethod
def registered_func_num(cls):
return len(cls._register_funcs)
@property
def id(self):
return self._id
def __call__(self, *args):
if self._named_args is None:
func_ret = self._func()
else:
kwargs = dict()
idx = 0
for arg in self._named_args:
kwargs[arg] = args[idx]
idx += 1
func_ret = self._func(*args[idx:], **kwargs)
if not isinstance(func_ret, (list, tuple)):
func_ret = (func_ret, )
ret = []
for each_ret in func_ret:
if each_ret is None or isinstance(each_ret, core.LoDTensor):
ret.append(each_ret)
continue
if not isinstance(each_ret, np.ndarray):
each_ret = np.array(each_ret)
tensor = core.LoDTensor()
tensor.set(each_ret, core.CPUPlace())
ret.append(tensor)
return tuple(ret)
@templatedoc()
def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
"""
PyFunc Operator.
User can use :code:`py_func` to register operators in Python side.
The inputs of :code:`func` is :code:`LoDTensor` and outputs can be
numpy array or :code:`LoDTensor`. Paddle would call the registered
:code:`func` in forward part, and call :code:`backward_func` in
backward part (if :code:`backward_func` is not None).
User should set the right data type and shape of :code:`out` before
calling this function. However, data types and shapes of gradients of
:code:`out` and :code:`x` would be inferred automatically.
Input orders of :code:`backward_func` would be: forward inputs
:code:`x`, forward outputs :code:`out` and backward input gradients of
:code:`out`. If some variables of :code:`out` have no gradient, the input
tensor would be None in Python side. If some variables of :code:`in` have
no gradient, users should return None.
This function can also be used to debug the running network. User can
add a :code:`py_func` operator without output, and print input
:code:`x` inside :code:`func`.
Args:
func (callable): forward Python function.
x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`.
out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`.
Paddle cannot infer shapes and data types of :code:`out`. Users
should create :code:`out` beforehand.
backward_func (callable|None): backward Python function.
None means no backward. Default None.
skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)):
Variables that are not needed in :code:`backward_func` inputs.
These variables must be any of :code:`x` and :code:`out`.
If set, these vars would not be inputs of :code:`backward_func`,
Only useful when :code:`backward_func` is not None. Default None.
Returns:
out (Variable|list(Variable)|tuple(Variable)): input :code:`out`
Examples:
>>> import paddle.fluid as fluid
>>> import six
>>>
>>> def create_tmp_var(name, dtype, shape):
>>> return fluid.default_main_program().current_block().create_var(
>>> name=name, dtype=dtype, shape=shape)
>>>
>>> # tanh activation has been provided by Paddle C++ op
>>> # Here, we only use tanh to be an example to show the usage
>>> # of py_func
>>> def tanh(x):
>>> return np.tanh(x)
>>>
>>> # forward input x is skipped
>>> def tanh_grad(y, dy):
>>> return np.array(dy) * (1 - np.square(np.array(y)))
>>>
>>> def debug_func(x):
>>> print(x)
>>>
>>> def simple_net(img, label):
>>> hidden = img
>>> for idx in six.moves.range(4):
>>> hidden = fluid.layers.fc(hidden, size=200)
>>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
>>> dtype=hidden.dtype, shape=hidden.shape)
>>>
>>> # user-defined layers with forward and backward
>>> hidden = fluid.layers.py_func(func=tanh, x=hidden,
>>> out=new_hidden, backward_func=tanh_grad,
>>> skip_vars_in_backward_input=hidden)
>>>
>>> # user-defined debug layers to print variables
>>> fluid.layers.py_func(func=debug_func, x=hidden, out=None)
>>>
>>> prediction = fluid.layers.fc(hidden, size=10, act='softmax')
>>> loss = fluid.layers.cross_entropy(input=prediction, label=label)
>>> return fluid.layers.mean(loss)
"""
helper = LayerHelper('py_func', **locals())
if x is None:
x = []
elif isinstance(x, Variable):
x = [x]
elif not isinstance(x, (list, tuple)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
if out is None:
out_list = []
elif isinstance(out, Variable):
out_list = [out]
elif isinstance(out, (list, tuple)):
out_list = out
else:
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')
fwd_func_id = PyFuncRegistry(func).id
bwd_func_id = PyFuncRegistry(
backward_func).id if backward_func is not None else -1
for each_out in out_list:
if len(each_out.shape) == 0:
raise ValueError(
'Output shapes of py_func op should be provided by users manually'
)
backward_skip_vars = set()
if backward_func is not None and skip_vars_in_backward_input is not None:
if isinstance(skip_vars_in_backward_input, Variable):
skip_vars_in_backward_input = [skip_vars_in_backward_input]
fwd_in_out = [v.name for v in x]
fwd_in_out.extend([v.name for v in out_list])
fwd_in_out = set(fwd_in_out)
backward_skip_vars = set()
for v in skip_vars_in_backward_input:
if not v.name in fwd_in_out:
raise ValueError(
'Variable {} is not found in forward inputs and outputs'
.format(v.name))
backward_skip_vars.add(v.name)
helper.append_op(
type='py_func',
inputs={'X': x},
outputs={'Out': out_list},
attrs={
'forward_callable_id': fwd_func_id,
'backward_callable_id': bwd_func_id,
'backward_skip_vars': list(backward_skip_vars)
})
return out
# For debug usage
py_func.registered_func = PyFuncRegistry.registered_func
py_func.registered_func_num = PyFuncRegistry.registered_func_num
@templatedoc()
def psroi_pool(input,
rois,

@ -0,0 +1,183 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import paddle
import unittest
import six
import numpy as np
dev_cnt = 2
if fluid.core.is_compiled_with_cuda():
dev_cnt = fluid.core.get_cuda_device_count()
os.environ['CPU_NUM'] = str(dev_cnt)
def dummy_func_with_no_input():
return float(1.0)
def dummy_func_with_no_output(x):
pass
def tanh(x):
return np.tanh(x)
def tanh_grad(y, dy):
return np.array(dy) * (1 - np.square(np.array(y)))
def cross_entropy(logits, labels):
logits = np.array(logits)
labels = np.array(labels)
M = logits.shape[0]
N = logits.shape[1]
ret = np.ndarray([M, 1]).astype(logits.dtype)
for idx in six.moves.range(M):
ret[idx][0] = -np.log(logits[idx][labels[idx][0]])
return ret
def cross_entropy_grad(logits, labels, bwd_dout):
logits = np.array(logits)
labels = np.array(labels)
bwd_dout = np.array(bwd_dout)
M = logits.shape[0]
N = logits.shape[1]
dlogits = np.zeros([M, N]).astype(logits.dtype)
for idx in six.moves.range(M):
dlogits[idx][labels[idx][0]] = -bwd_dout[idx] / logits[idx][labels[idx][
0]]
return dlogits, None
def simple_fc_net(img, label, use_py_func_op):
hidden = img
for idx in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
if not use_py_func_op:
hidden = fluid.layers.tanh(hidden)
else:
new_hidden = fluid.default_main_program().current_block(
).create_var(
name='hidden_{}'.format(idx),
dtype='float32',
shape=hidden.shape)
hidden = fluid.layers.py_func(
func=tanh,
x=hidden,
out=new_hidden,
backward_func=tanh_grad,
skip_vars_in_backward_input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
if not use_py_func_op:
loss = fluid.layers.cross_entropy(input=prediction, label=label)
else:
loss = fluid.default_main_program().current_block().create_var(
name='loss', dtype='float32', shape=[-1, 1])
loss = fluid.layers.py_func(
func=cross_entropy,
x=[prediction, label],
out=loss,
backward_func=cross_entropy_grad,
skip_vars_in_backward_input=loss)
dummy_var = fluid.default_main_program().current_block().create_var(
name='test_tmp_var', dtype='float32', shape=[1])
fluid.layers.py_func(
func=dummy_func_with_no_input, x=None, out=dummy_var)
fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None)
loss = fluid.layers.mean(loss)
return loss
def reader():
for _ in six.moves.range(dev_cnt * 100):
yield np.random.random([784]), np.random.random_integers(
size=[1], low=0, high=9)
def test_main(use_cuda, use_py_func_op, use_parallel_executor):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return None
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.core.Scope()):
fluid.default_main_program().random_seed = 1
fluid.default_startup_program().random_seed = 1
np.random.seed(1)
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = simple_fc_net(img, label, use_py_func_op)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
r = paddle.batch(reader, batch_size=10)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name)
fetch_list = [loss.name]
else:
fetch_list = [loss]
ret = []
for epoch_id in six.moves.range(2):
for d in r():
L, = exe.run(feed=feeder.feed(d), fetch_list=fetch_list)
ret.append(L)
return np.array(ret)
class TestPyFuncOpUseExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = False
def test_loss_diff(self):
losses = []
for use_cuda in [True, False]:
for use_py_func_op in [True, False]:
L = test_main(use_cuda, use_py_func_op,
self.use_parallel_executor)
if L is not None:
losses.append(L)
for idx in six.moves.range(len(losses) - 1):
max_diff = np.max(np.abs(losses[idx] - losses[0]))
self.assertAlmostEqual(max_diff, 0, delta=1e-3)
class TestPyFuncOpUseParallelExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = True
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save