From 8760d23c7dbcb4ad5a5b941aca5917514467c86d Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 10 Dec 2018 13:09:28 +0000 Subject: [PATCH 1/8] feature/py_func --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/py_func_op.cc | 131 ++++++++++++++++++++++++++ paddle/fluid/operators/py_func_op.h | 25 +++++ paddle/fluid/pybind/pybind.cc | 21 +++++ python/paddle/fluid/layers/nn.py | 112 +++++++++++++++++++++- 5 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/py_func_op.cc create mode 100644 paddle/fluid/operators/py_func_op.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 257bfc0a3f..9379122faf 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -82,7 +82,7 @@ endif() # op_library(unstack_op DEPS stack_op) # op_library(tensor_array_to_tensor_op DEPS concat_op) -set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS}) +set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS} python pybind) set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") cc_test(gather_test SRCS gather_test.cc DEPS tensor) diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc new file mode 100644 index 0000000000..86914f3060 --- /dev/null +++ b/paddle/fluid/operators/py_func_op.cc @@ -0,0 +1,131 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
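Conceptually, the new file begun above maintains a process-global registry of Python callables addressed by an integer id. A minimal Python sketch of that bookkeeping (illustrative only; the real implementation below is C++ and, in this first patch, mutex-guarded):

    # Hypothetical Python equivalent of g_py_callables and its accessors.
    _py_callables = []

    def append_python_callable_object_and_return_id(py_obj):
        # Store the callable; the returned index becomes the op's "handle".
        _py_callables.append(py_obj)
        return len(_py_callables) - 1

    def get_python_callable_object(i):
        assert i < len(_py_callables), "Invalid python callable id"
        return _py_callables[i]
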
+ +#include "paddle/fluid/operators/py_func_op.h" +#include +#include +#include +#include "Python.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +namespace py = pybind11; + +static std::mutex g_py_callables_mtx; +static std::vector g_py_callables; + +size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { + std::lock_guard guard(g_py_callables_mtx); + g_py_callables.emplace_back(py_obj); + return g_py_callables.size() - 1; +} + +static py::object *GetPythonCallableObject(size_t i) { + std::lock_guard guard(g_py_callables_mtx); + PADDLE_ENFORCE_LT(i, g_py_callables.size()); + return &g_py_callables[i]; +} + +void DoCallPythonFunc(py::object *callable, const std::string &func_token, + const std::vector &ins, + std::vector *out) { + py::gil_scoped_acquire guard{}; + py::tuple in_args(ins.size()); + for (size_t i = 0; i < ins.size(); ++i) { + in_args[i] = py::cast(ins[i]); + } + + auto ret = (*callable)(func_token, *in_args); + auto ret_tuple = py::cast(ret); + PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match"); + for (size_t i = 0; i < out->size(); ++i) { + try { + auto *out_tensor = py::cast(ret_tuple[i]); + PADDLE_ENFORCE_NOT_NULL(out_tensor, + "Output tensor should not be nullptr"); + (*out)[i]->set_lod(out_tensor->lod()); + (*out)[i]->ShareDataWith(*out_tensor); + } catch (py::cast_error &) { + PADDLE_THROW("Output %d is not LoDTensor", i); + } + } +} + +class PyFuncOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist"); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist"); + } +}; + +class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Inputs of py_func op.").AsDuplicable(); + AddOutput("Out", "Outputs of py_func op").AsDuplicable(); + AddAttr("token", "function token"); + AddAttr("handle_idx", "handle index").SetDefault(0); + AddComment(R"DOC("PyFunc Op")DOC"); + } +}; + +class PyFuncOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + protected: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &in_arg_names = Inputs("X"); + auto &out_arg_names = Outputs("Out"); + + std::vector inputs(in_arg_names.size()); + for (size_t i = 0; i < in_arg_names.size(); ++i) { + auto &in_tensor = + scope.FindVar(in_arg_names[i])->Get(); + if (platform::is_gpu_place(in_tensor.place())) { + framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]); + } else { + inputs[i].ShareDataWith(in_tensor); + } + inputs[i].set_lod(in_tensor.lod()); + } + + std::vector outputs(out_arg_names.size()); + for (size_t i = 0; i < out_arg_names.size(); ++i) { + auto *out_tensor = + scope.FindVar(out_arg_names[i])->GetMutable(); + outputs[i] = out_tensor; + } + + auto &token = Attr("token"); + auto handle_idx = static_cast(Attr("handle_idx")); + auto *py_callable = GetPythonCallableObject(handle_idx); + VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx " + << handle_idx; + DoCallPythonFunc(py_callable, token, inputs, &outputs); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, + ops::PyFuncOpShapeInference, + paddle::framework::EmptyGradOpMaker); diff --git 
a/paddle/fluid/operators/py_func_op.h b/paddle/fluid/operators/py_func_op.h new file mode 100644 index 0000000000..e85fa6b5bc --- /dev/null +++ b/paddle/fluid/operators/py_func_op.h @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace operators { + +size_t AppendPythonCallableObjectAndReturnId(pybind11::object py_obj); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 58ef3da0b2..58da2cea34 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -36,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/framework/version.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" @@ -100,6 +101,12 @@ PYBIND11_MODULE(core, m) { BindException(&m); + m.def( + "append_python_callable_object_and_return_id", + [](py::object py_obj) -> size_t { + return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj); + }); + py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) @@ -525,6 +532,20 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Place") .def(py::init<>()) + .def("is_cpu_place", + [](platform::Place &self) { return platform::is_cpu_place(self); }) + .def("is_gpu_place", + [](platform::Place &self) { return platform::is_gpu_place(self); }) + .def("is_cuda_pinned_place", + [](platform::Place &self) { + return platform::is_cuda_pinned_place(self); + }) + .def("gpu_device_id", + [](platform::Place &self) { + PADDLE_ENFORCE(platform::is_gpu_place(self), + "gpu_device_id() only supports in CUDAPlace"); + return boost::get(self).device; + }) .def("set_place", [](platform::Place &self, const platform::CPUPlace &cpu_place) { self = cpu_place; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4833212d31..92cd53a6c3 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -18,10 +18,12 @@ All layers just related to the neural network. 
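The pybind hook added above is the bridge the Python layer relies on: it hands a callable over to C++ and receives back the integer handle that the operator later uses to call into Python. A hedged sketch of the call site (the real one appears in fluid.layers.py_func below; `dispatcher` is a hypothetical stand-in):

    import paddle.fluid.core as core

    def dispatcher(token, *tensors):
        # stub: look up the function registered under `token` and call it
        return None

    handle_idx = core.append_python_callable_object_and_return_id(dispatcher)
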
from __future__ import print_function import numpy as np +import six import os +import inspect from ..layer_helper import LayerHelper from ..initializer import Normal, Constant -from ..framework import Variable, OpProtoHolder +from ..framework import Variable, OpProtoHolder, Program from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ from .tensor import concat @@ -172,6 +174,7 @@ __all__ = [ 'merge_selected_rows', 'get_tensor_from_selected_rows', 'lstm', + 'py_func', ] kIgnoreIndex = -100 @@ -9082,3 +9085,110 @@ def get_tensor_from_selected_rows(x, name=None): outputs={'Out': out}, attrs={}) return out + + +@templatedoc() +def py_func(func, x, out, backward_func=None): + """ + """ + + class PyFuncRegister(object): + _main_program_to_register = dict() + + @classmethod + def get_instance(cls, prog=None): + if prog is None: + prog = fluid.default_main_program() + + if not isinstance(prog, Program): + raise ValueError("prog must be None or type of Program") + + ret = cls._main_program_to_register.get(prog, None) + if ret is None: + ret = PyFuncRegister() + ret._idx = core.append_python_callable_object_and_return_id(ret) + ret._token_func_dict = dict() + ret._func_token_dict = dict() + cls._main_program_to_register[prog] = ret + + return ret + + @property + def handle_idx(self): + return self._idx + + def unique_token(self, func): + return self._register_func(func) + + def _register_func(self, func): + if func is None: + raise ValueError("func cannot be None") + + token = self._func_token_dict.get(func, None) + if token is not None: + return token + + token = unique_name.generate('py_func_op_token') + self._token_func_dict[token] = func + self._func_token_dict[func] = token + return token + + def __call__(self, token, *args): + func = self._token_func_dict.get(token, None) + if func is None: + raise ValueError("func has not been registered") + + arg_list = inspect.getargspec(func) + kwargs = dict() + idx = 0 + for arg in arg_list[0]: + kwargs[arg] = args[idx] + idx += 1 + + args = args[idx:] + ret0 = func(*args, **kwargs) + if ret0 is None: + return None + + if not isinstance(ret0, (list, tuple)): + ret0 = (ret0, ) + + ret = [] + for i in six.moves.range(len(ret0)): + if isinstance(ret0[i], core.LoDTensor): + ret.append(ret0[i]) + continue + + if isinstance(ret0[i], np.ndarray): + r = ret0[i] + else: + r = np.array(ret0[i]) + + t = core.LoDTensor() + t.set(r, core.CPUPlace()) + ret.append(t) + + return tuple(ret) + + helper = LayerHelper('py_func', **locals()) + if isinstance(x, Variable): + x = [x] + + if isinstance(out, Variable): + out = [out] + + for each_out in out: + if len(each_out.shape) == 0: + raise ValueError( + 'users should infer shapes of outputs of py_func op manually') + + py_func_reg = PyFuncRegister.get_instance(helper.main_program) + token = py_func_reg.unique_token(func) + + helper.append_op( + type='py_func', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'handle_idx': py_func_reg.handle_idx, + 'token': token}) + return out From e240ba291853856d29790ecd3b6c5493c5ab2cd3 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 12 Dec 2018 03:16:34 +0000 Subject: [PATCH 2/8] implement backward test=develop --- paddle/fluid/framework/op_desc.cc | 2 + paddle/fluid/framework/op_desc.h | 2 + paddle/fluid/framework/operator.cc | 5 + paddle/fluid/framework/shape_inference.h | 5 + paddle/fluid/operators/py_func_op.cc | 127 ++++++++++++++++++++--- paddle/fluid/pybind/protobuf.cc | 2 +- python/paddle/fluid/layers/nn.py | 39 
++++--- 7 files changed, 154 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e8ecd90502..f8a9340df5 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -34,6 +34,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { public: CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block); + InferShapeOpPtr GetOp() const override { return &op_; } + bool HasInput(const std::string &name) const override; bool HasOutput(const std::string &name) const override; diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 30c8a26c3d..3b3f50bfa7 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -121,6 +121,8 @@ class OpDesc { BlockDesc *Block() { return this->block_; } + const BlockDesc *Block() const { return this->block_; } + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index c6f3254e9f..188ab120be 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -481,6 +481,8 @@ class RuntimeInferShapeContext : public InferShapeContext { RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} + InferShapeOpPtr GetOp() const override { return &op_; } + bool HasInput(const std::string& name) const override { // has only one input const auto& ins = op_.Inputs(); @@ -879,6 +881,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( t = &(var->Get().value()); } if (t != nullptr) { + PADDLE_ENFORCE(t->IsInitialized(), + "Input %s(%s) does not exist in Operator %s", + input.first, ipt_name, DebugString()); int tmp = static_cast(ToDataType(t->type())); PADDLE_ENFORCE( tmp == data_type || data_type == -1, diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index d73cca121e..2f95ab353e 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -25,7 +25,10 @@ limitations under the License. 
*/ namespace paddle { namespace framework { +class OperatorBase; + using InferShapeVarPtr = boost::variant; +using InferShapeOpPtr = boost::variant; class InferShapeContext { public: @@ -38,6 +41,8 @@ class InferShapeContext { std::vector GetOutputsVarType( const std::string &name) const; + virtual InferShapeOpPtr GetOp() const = 0; + virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 86914f3060..46a6125f97 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -24,34 +24,34 @@ namespace operators { namespace py = pybind11; -static std::mutex g_py_callables_mtx; static std::vector g_py_callables; size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { - std::lock_guard guard(g_py_callables_mtx); g_py_callables.emplace_back(py_obj); return g_py_callables.size() - 1; } static py::object *GetPythonCallableObject(size_t i) { - std::lock_guard guard(g_py_callables_mtx); PADDLE_ENFORCE_LT(i, g_py_callables.size()); return &g_py_callables[i]; } -void DoCallPythonFunc(py::object *callable, const std::string &func_token, - const std::vector &ins, - std::vector *out) { +void CallPythonFunc(py::object *callable, const std::string &func_token, + const std::vector &ins, + std::vector *out) { py::gil_scoped_acquire guard{}; py::tuple in_args(ins.size()); for (size_t i = 0; i < ins.size(); ++i) { - in_args[i] = py::cast(ins[i]); + in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr); } auto ret = (*callable)(func_token, *in_args); auto ret_tuple = py::cast(ret); PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match"); for (size_t i = 0; i < out->size(); ++i) { + if ((*out)[i] == nullptr) { + continue; + } try { auto *out_tensor = py::cast(ret_tuple[i]); PADDLE_ENFORCE_NOT_NULL(out_tensor, @@ -67,8 +67,43 @@ void DoCallPythonFunc(py::object *callable, const std::string &func_token, class PyFuncOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(!ctx->IsRuntime(), + "Infer shape cannot be called in runtime."); PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist"); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist"); + + auto *op = boost::get(ctx->GetOp()); + auto *block = op->Block(); + // No need to infer shape in forward part + if (block->ForwardBlockID() < 0) { + return; + } + + PADDLE_ENFORCE(!ctx->Attrs().Get("token").empty(), + "Function token cannot be empty"); + + const std::string kGradVarSuffix = framework::kGradVarSuffix; + auto out_vars = ctx->GetOutputVarPtrs("Out"); + for (auto &out_var : out_vars) { + auto *out_var_desc = boost::get(out_var); + auto out_name = out_var_desc->Name(); + if (out_name == framework::kEmptyVarName || + out_name.size() < kGradVarSuffix.size()) { + continue; + } + + size_t len = out_name.size() - kGradVarSuffix.size(); + if (out_name.substr(len) == kGradVarSuffix) { + auto fwd_var_name = out_name.substr(0, len); + auto *in_var_desc = block->FindVarRecursive(fwd_var_name); + PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", + fwd_var_name); + out_var_desc->SetShape(in_var_desc->GetShape()); + out_var_desc->SetDataType(in_var_desc->GetDataType()); + out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel()); + out_var_desc->SetType(in_var_desc->GetType()); + } + } } }; @@ -77,12 +112,68 @@ 
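The grad-op maker added in the hunk below wires the backward op's inputs as all forward inputs, then all forward outputs, then the gradients of the forward outputs; its outputs are the gradients of the forward inputs. For a single-input, single-output py_func this yields the calling convention sketched here (names illustrative):

    import numpy as np

    def tanh_backward(x, y, dy):
        # Arguments arrive in order: forward input, forward output,
        # gradient of the forward output. The return value becomes the
        # gradient of the forward input.
        return np.array(dy) * (1.0 - np.square(np.array(y)))
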
class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Inputs of py_func op.").AsDuplicable(); AddOutput("Out", "Outputs of py_func op").AsDuplicable(); - AddAttr("token", "function token"); - AddAttr("handle_idx", "handle index").SetDefault(0); + AddAttr("handle_idx", "Index of the registered py_func handle") + .SetDefault(0); + AddAttr("token", "Token of function token to be called") + .SetDefault(""); + AddAttr("backward_token", + "Token of backward function to be called") + .SetDefault(""); AddComment(R"DOC("PyFunc Op")DOC"); } }; +class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { + public: + using framework::GradOpDescMakerBase::GradOpDescMakerBase; + + std::vector> operator()() const override { + auto &fwd_attrs = Attrs(); + if (fwd_attrs.at("backward_token").empty()) { + return {}; + } + + std::unique_ptr grad_op(new framework::OpDesc()); + grad_op->SetType("py_func"); + + framework::AttributeMap bwd_attrs; + bwd_attrs["token"] = fwd_attrs.at("backward_token"); + bwd_attrs["backward_token"] = std::string(""); + grad_op->SetAttrMap(bwd_attrs); + + auto bwd_in = Input("X"); + auto fwd_out = Output("Out"); + auto fwd_out_grad = OutputGrad("Out"); + bwd_in.insert(bwd_in.end(), fwd_out.begin(), fwd_out.end()); + bwd_in.insert(bwd_in.end(), fwd_out_grad.begin(), fwd_out_grad.end()); + + auto bwd_out = InputGrad("X", false); + + if (VLOG_IS_ON(10)) { + std::string in_str = "PyFunc Grad Input: "; + for (auto &in : bwd_in) { + in_str += in; + in_str += " "; + } + VLOG(10) << in_str; + + std::string out_str = "PyFunc Grad Output: "; + for (auto &out : bwd_out) { + out_str += out; + out += " "; + } + VLOG(10) << out_str; + } + + grad_op->SetInput("X", bwd_in); + grad_op->SetOutput("Out", InputGrad("X", false)); + + std::vector> ret(1); + ret[0] = std::move(grad_op); + return ret; + } +}; + class PyFuncOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; @@ -95,8 +186,14 @@ class PyFuncOp : public framework::OperatorBase { std::vector inputs(in_arg_names.size()); for (size_t i = 0; i < in_arg_names.size(); ++i) { - auto &in_tensor = - scope.FindVar(in_arg_names[i])->Get(); + auto in_var = scope.FindVar(in_arg_names[i]); + if (in_var == nullptr) { + continue; + } + auto &in_tensor = in_var->Get(); + if (!in_tensor.IsInitialized()) { + continue; + } if (platform::is_gpu_place(in_tensor.place())) { framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]); } else { @@ -107,8 +204,9 @@ class PyFuncOp : public framework::OperatorBase { std::vector outputs(out_arg_names.size()); for (size_t i = 0; i < out_arg_names.size(); ++i) { + auto *out_var = scope.FindVar(out_arg_names[i]); auto *out_tensor = - scope.FindVar(out_arg_names[i])->GetMutable(); + out_var ? 
out_var->GetMutable() : nullptr; outputs[i] = out_tensor; } @@ -117,7 +215,7 @@ class PyFuncOp : public framework::OperatorBase { auto *py_callable = GetPythonCallableObject(handle_idx); VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx " << handle_idx; - DoCallPythonFunc(py_callable, token, inputs, &outputs); + CallPythonFunc(py_callable, token, inputs, &outputs); } }; @@ -127,5 +225,4 @@ class PyFuncOp : public framework::OperatorBase { namespace ops = paddle::operators; REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, - ops::PyFuncOpShapeInference, - paddle::framework::EmptyGradOpMaker); + ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index ac406b27b5..4b218fb3a2 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) { .def("infer_var_type", &pd::OpDesc::InferVarType) .def("set_is_target", &pd::OpDesc::SetIsTarget) .def("serialize_to_string", SerializeMessage) - .def("block", &pd::OpDesc::Block, + .def("block", [](pd::OpDesc &self) { return self.Block(); }, pybind11::return_value_policy::reference); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 92cd53a6c3..66c98c935d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9096,12 +9096,9 @@ def py_func(func, x, out, backward_func=None): _main_program_to_register = dict() @classmethod - def get_instance(cls, prog=None): - if prog is None: - prog = fluid.default_main_program() - + def get_instance(cls, prog): if not isinstance(prog, Program): - raise ValueError("prog must be None or type of Program") + raise TypeError("prog must be type of Program") ret = cls._main_program_to_register.get(prog, None) if ret is None: @@ -9155,6 +9152,10 @@ def py_func(func, x, out, backward_func=None): ret = [] for i in six.moves.range(len(ret0)): + if ret0[i] is None: + ret.append(None) + continue + if isinstance(ret0[i], core.LoDTensor): ret.append(ret0[i]) continue @@ -9175,20 +9176,34 @@ def py_func(func, x, out, backward_func=None): x = [x] if isinstance(out, Variable): - out = [out] + out_list = [out] + else: + out_list = out + + if func is None or not hasattr(func, '__call__'): + raise TypeError('Input func must be a function') - for each_out in out: + if backward_func is not None and not hasattr(backward_func, '__call__'): + raise TypeError('Input backward_func must be a function') + + for each_out in out_list: if len(each_out.shape) == 0: raise ValueError( - 'users should infer shapes of outputs of py_func op manually') + 'Output shapes of py_func op should be provided by users manually' + ) py_func_reg = PyFuncRegister.get_instance(helper.main_program) - token = py_func_reg.unique_token(func) + forward_token = py_func_reg.unique_token(func) + backward_token = py_func_reg.unique_token( + backward_func) if backward_func is not None else '' helper.append_op( type='py_func', inputs={'X': x}, - outputs={'Out': out}, - attrs={'handle_idx': py_func_reg.handle_idx, - 'token': token}) + outputs={'Out': out_list}, + attrs={ + 'handle_idx': py_func_reg.handle_idx, + 'token': forward_token, + 'backward_token': backward_token + }) return out From 8b9d33fa1e7c3d592d9f3976634eaf87155f1a49 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 12 Dec 2018 12:32:26 +0000 Subject: [PATCH 3/8] add unittest and fix bug add API.spec test=develop --- paddle/fluid/API.spec | 1 + 
paddle/fluid/operators/py_func_op.cc | 123 ++++++--- python/paddle/fluid/layers/nn.py | 242 +++++++++++------- .../fluid/tests/unittests/test_py_func_op.py | 145 +++++++++++ 4 files changed, 374 insertions(+), 137 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_py_func_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 2722ea078e..b3f7593be3 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -197,6 +197,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) +paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 46a6125f97..32c44c3bc2 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -26,26 +26,35 @@ namespace py = pybind11; static std::vector g_py_callables; +const char kForwardPythonCallableId[] = "forward_callable_id"; +const char kBackwardPythonCallableId[] = "backward_callable_id"; +const char kPyFuncBackwardSkipVars[] = "backward_skip_vars"; + size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { g_py_callables.emplace_back(py_obj); return g_py_callables.size() - 1; } static py::object *GetPythonCallableObject(size_t i) { - PADDLE_ENFORCE_LT(i, g_py_callables.size()); + PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id"); return &g_py_callables[i]; } -void CallPythonFunc(py::object *callable, const std::string &func_token, +std::string PythonObjectToString(const py::object &py_callable) { + py::gil_scoped_acquire guard; + return py::str(*py_callable); +} + +void CallPythonFunc(py::object *callable, const std::vector &ins, std::vector *out) { - py::gil_scoped_acquire guard{}; + py::gil_scoped_acquire guard; py::tuple in_args(ins.size()); for (size_t i = 0; i < ins.size(); ++i) { in_args[i] = ins[i].IsInitialized() ? 
py::cast(ins[i]) : py::cast(nullptr); } - auto ret = (*callable)(func_token, *in_args); + auto ret = (*callable)(*in_args); auto ret_tuple = py::cast(ret); PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match"); for (size_t i = 0; i < out->size(); ++i) { @@ -55,7 +64,7 @@ void CallPythonFunc(py::object *callable, const std::string &func_token, try { auto *out_tensor = py::cast(ret_tuple[i]); PADDLE_ENFORCE_NOT_NULL(out_tensor, - "Output tensor should not be nullptr"); + "Output tensor %d should not be nullptr", i); (*out)[i]->set_lod(out_tensor->lod()); (*out)[i]->ShareDataWith(*out_tensor); } catch (py::cast_error &) { @@ -69,26 +78,23 @@ class PyFuncOpShapeInference : public framework::InferShapeBase { void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(!ctx->IsRuntime(), "Infer shape cannot be called in runtime."); - PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist"); - PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist"); + PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"), + "Input(X) or Output(Out) must exist"); + PADDLE_ENFORCE_GE(ctx->Attrs().Get(kForwardPythonCallableId), 0, + "Function id cannot be less than 0"); auto *op = boost::get(ctx->GetOp()); auto *block = op->Block(); - // No need to infer shape in forward part - if (block->ForwardBlockID() < 0) { - return; - } - - PADDLE_ENFORCE(!ctx->Attrs().Get("token").empty(), - "Function token cannot be empty"); - const std::string kGradVarSuffix = framework::kGradVarSuffix; auto out_vars = ctx->GetOutputVarPtrs("Out"); for (auto &out_var : out_vars) { auto *out_var_desc = boost::get(out_var); + if (out_var_desc == nullptr) { + continue; + } auto out_name = out_var_desc->Name(); if (out_name == framework::kEmptyVarName || - out_name.size() < kGradVarSuffix.size()) { + out_name.size() <= kGradVarSuffix.size()) { continue; } @@ -98,6 +104,8 @@ class PyFuncOpShapeInference : public framework::InferShapeBase { auto *in_var_desc = block->FindVarRecursive(fwd_var_name); PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", fwd_var_name); + VLOG(10) << "Infer shape of Out(" << out_name << ") as Input(" + << in_var_desc->Name() << ")"; out_var_desc->SetShape(in_var_desc->GetShape()); out_var_desc->SetDataType(in_var_desc->GetDataType()); out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel()); @@ -112,13 +120,15 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Inputs of py_func op.").AsDuplicable(); AddOutput("Out", "Outputs of py_func op").AsDuplicable(); - AddAttr("handle_idx", "Index of the registered py_func handle") + AddAttr(kForwardPythonCallableId, + "Index of registered forward Python function.") .SetDefault(0); - AddAttr("token", "Token of function token to be called") - .SetDefault(""); - AddAttr("backward_token", - "Token of backward function to be called") - .SetDefault(""); + AddAttr(kBackwardPythonCallableId, + "Index of registered backward Python function") + .SetDefault(-1); + AddAttr>(kPyFuncBackwardSkipVars, + "Unused forward in/out in backward op") + .SetDefault(std::vector()); AddComment(R"DOC("PyFunc Op")DOC"); } }; @@ -129,7 +139,8 @@ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { std::vector> operator()() const override { auto &fwd_attrs = Attrs(); - if (fwd_attrs.at("backward_token").empty()) { + // no backward op when backward_id is less than 0 + if (boost::get(fwd_attrs.at(kBackwardPythonCallableId)) < 0) { return {}; } @@ -137,36 +148,65 
@@ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { grad_op->SetType("py_func"); framework::AttributeMap bwd_attrs; - bwd_attrs["token"] = fwd_attrs.at("backward_token"); - bwd_attrs["backward_token"] = std::string(""); + bwd_attrs[kForwardPythonCallableId] = + fwd_attrs.at(kBackwardPythonCallableId); + bwd_attrs[kBackwardPythonCallableId] = -1; grad_op->SetAttrMap(bwd_attrs); - auto bwd_in = Input("X"); - auto fwd_out = Output("Out"); - auto fwd_out_grad = OutputGrad("Out"); - bwd_in.insert(bwd_in.end(), fwd_out.begin(), fwd_out.end()); - bwd_in.insert(bwd_in.end(), fwd_out_grad.begin(), fwd_out_grad.end()); + // All forward inputs + auto fwd_ins = Input("X"); + // All forward outputs + auto fwd_outs = Output("Out"); + + // For memory reused, some inputs/output in forward part may be not needed + // in backward part + // Just skip these vars + auto &backward_skip_var_list = boost::get>( + fwd_attrs.at(kPyFuncBackwardSkipVars)); + std::unordered_set backward_skip_var_set( + backward_skip_var_list.begin(), backward_skip_var_list.end()); + std::vector bwd_ins; + bwd_ins.reserve(fwd_ins.size() + fwd_outs.size()); + for (auto &fwd_in : fwd_ins) { + if (backward_skip_var_set.count(fwd_in) == 0) { + bwd_ins.emplace_back(fwd_in); + } + } + + for (auto &fwd_out : fwd_outs) { + if (backward_skip_var_set.count(fwd_out) == 0) { + bwd_ins.emplace_back(fwd_out); + } + } + + // Backward OG cannot be skipped + // But in Python side, if OG is kEmptyVarName, input tensor would be None + auto fwd_out_grads = OutputGrad("Out"); + bwd_ins.reserve(bwd_ins.size() + fwd_out_grads.size()); + bwd_ins.insert(bwd_ins.end(), fwd_out_grads.begin(), fwd_out_grads.end()); - auto bwd_out = InputGrad("X", false); + // Backward IG cannot be skipped + // But in Python side, if IG is not needed, users can just return None + auto bwd_outs = InputGrad("X", false); if (VLOG_IS_ON(10)) { std::string in_str = "PyFunc Grad Input: "; - for (auto &in : bwd_in) { + for (auto &in : bwd_ins) { in_str += in; in_str += " "; } VLOG(10) << in_str; std::string out_str = "PyFunc Grad Output: "; - for (auto &out : bwd_out) { + for (auto &out : bwd_outs) { out_str += out; - out += " "; + out_str += " "; } VLOG(10) << out_str; } - grad_op->SetInput("X", bwd_in); - grad_op->SetOutput("Out", InputGrad("X", false)); + grad_op->SetInput("X", bwd_ins); + grad_op->SetOutput("Out", bwd_outs); std::vector> ret(1); ret[0] = std::move(grad_op); @@ -210,12 +250,11 @@ class PyFuncOp : public framework::OperatorBase { outputs[i] = out_tensor; } - auto &token = Attr("token"); - auto handle_idx = static_cast(Attr("handle_idx")); - auto *py_callable = GetPythonCallableObject(handle_idx); - VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx " - << handle_idx; - CallPythonFunc(py_callable, token, inputs, &outputs); + auto callable_id = static_cast(Attr(kForwardPythonCallableId)); + auto *py_callable = GetPythonCallableObject(callable_id); + VLOG(10) << "Call py_func_op with id " << callable_id << ": " + << PythonObjectToString(*py_callable); + CallPythonFunc(py_callable, inputs, &outputs); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 66c98c935d..95f046c614 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9087,104 +9087,140 @@ def get_tensor_from_selected_rows(x, name=None): return out -@templatedoc() -def py_func(func, x, out, backward_func=None): - """ - """ - - class PyFuncRegister(object): - _main_program_to_register = dict() - - 
@classmethod - def get_instance(cls, prog): - if not isinstance(prog, Program): - raise TypeError("prog must be type of Program") - - ret = cls._main_program_to_register.get(prog, None) - if ret is None: - ret = PyFuncRegister() - ret._idx = core.append_python_callable_object_and_return_id(ret) - ret._token_func_dict = dict() - ret._func_token_dict = dict() - cls._main_program_to_register[prog] = ret - - return ret - - @property - def handle_idx(self): - return self._idx - - def unique_token(self, func): - return self._register_func(func) - - def _register_func(self, func): - if func is None: - raise ValueError("func cannot be None") - - token = self._func_token_dict.get(func, None) - if token is not None: - return token - - token = unique_name.generate('py_func_op_token') - self._token_func_dict[token] = func - self._func_token_dict[func] = token - return token - - def __call__(self, token, *args): - func = self._token_func_dict.get(token, None) - if func is None: - raise ValueError("func has not been registered") - - arg_list = inspect.getargspec(func) - kwargs = dict() - idx = 0 - for arg in arg_list[0]: - kwargs[arg] = args[idx] - idx += 1 - - args = args[idx:] - ret0 = func(*args, **kwargs) - if ret0 is None: - return None - - if not isinstance(ret0, (list, tuple)): - ret0 = (ret0, ) - - ret = [] - for i in six.moves.range(len(ret0)): - if ret0[i] is None: - ret.append(None) - continue - - if isinstance(ret0[i], core.LoDTensor): - ret.append(ret0[i]) - continue +class PyFuncWrapper(object): + _register_funcs = [] + + def __init__(self, func): + if func is None or not hasattr(func, '__call__'): + raise TypeError('func must be a Python function') + + self._func = func + # find named args using reflection + self._named_args = inspect.getargspec(self._func)[0] + self._id = core.append_python_callable_object_and_return_id(self) + ''' + Why record self here? + + 1. For debug usage. Users can call + :code:`py_func.registered_func(idx)` method + to find the registered function coresponding + to :code:`idx`. + + 2. For increasing reference count of self. + It seems that to release Python object + whose reference count is 1 would cause + segmentation fault error in C++ side. + May be lack of Python GC in C++ side? + ''' + PyFuncWrapper._register_funcs.append(self) + + @classmethod + def registered_func(cls, idx): + return cls._register_funcs[idx]._func + + @classmethod + def registered_func_num(cls): + return len(cls._register_funcs) + + @property + def id(self): + return self._id + + def __call__(self, *args): + kwargs = dict() + idx = 0 + for arg in self._named_args: + kwargs[arg] = args[idx] + idx += 1 + + ret0 = self._func(*args[idx:], **kwargs) + if ret0 is None: + return None + + if not isinstance(ret0, (list, tuple)): + ret0 = (ret0, ) + + ret = [] + for i in six.moves.range(len(ret0)): + if ret0[i] is None: + ret.append(None) + continue + + if isinstance(ret0[i], core.LoDTensor): + ret.append(ret0[i]) + continue + + if isinstance(ret0[i], np.ndarray): + r = ret0[i] + else: + r = np.array(ret0[i]) - if isinstance(ret0[i], np.ndarray): - r = ret0[i] - else: - r = np.array(ret0[i]) + t = core.LoDTensor() + t.set(r, core.CPUPlace()) + ret.append(t) - t = core.LoDTensor() - t.set(r, core.CPUPlace()) - ret.append(t) + return tuple(ret) - return tuple(ret) +@templatedoc() +def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): + """ + PyFunc Operator. + + User can use :code:`py_func` to register operators in Python side. 
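For example, a minimal forward-only usage sketch (variable and function names are illustrative):

    import numpy as np
    import paddle.fluid as fluid

    def np_relu(x):
        # `x` arrives as a LoDTensor; a numpy array may be returned.
        return np.maximum(np.array(x), 0.0)

    x = fluid.layers.data(name='x', shape=[32], dtype='float32')
    # py_func cannot infer output shapes/dtypes, so `out` is created first.
    out = fluid.default_main_program().current_block().create_var(
        name='relu_out', dtype='float32', shape=[-1, 32])
    fluid.layers.py_func(func=np_relu, x=x, out=out)
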
+ The inputs of :code:`func` is :code:`LoDTensor` and outputs can be + numpy array or :code:`LoDTensor`. Paddle would call the registered + :code:`func` in forward part, and call :code:`backward_func` in + backward part (if :code:`backward_func` is not None). + + User should set the right data type and shape of :code:`out` before + calling this function. However, data types and shapes of gradients of + :code:`out` and :code:`x` would be infered automatically. + + The orders of inputs of :code:`backward_func` would be: forward input + :code:`x`, forward output :code:`out` and backward input gradient of + :code:`out`. If some variables of :code:`out` have no gradient, the input + tensor would be None in Python side. If some variables of :code:`in` have + no gradient, users should return None. + + Args: + func (callable): forward Python function. + x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`. + out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`. + Paddle cannot infer shapes and data types of :code:`out`. Users + should create :code:`out` beforehand. + backward_func (callable|None): backward Python function. + None means no backward. Default None. + skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)): + Variables that are not needed in :code:`backward_func` inputs. + These variables must be any of :code:`x` and :code:`out`. + If set, these vars would not be inputs of :code:`backward_func`, + Only useful when :code:`backward_func` is not None. Default None. + + Returns: + out (Variable|list(Variable)|tuple(Variable)): input :code:`out` + """ helper = LayerHelper('py_func', **locals()) - if isinstance(x, Variable): + if x is None: + x = [] + elif isinstance(x, Variable): x = [x] + elif not isinstance(x, (list, tuple)): + raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)') - if isinstance(out, Variable): + if out is None: + out_list = [] + elif isinstance(out, Variable): out_list = [out] - else: + elif isinstance(out, (list, tuple)): out_list = out + else: + raise TypeError( + 'Output must be Variable/list(Variable)/tuple(Variable)') - if func is None or not hasattr(func, '__call__'): - raise TypeError('Input func must be a function') - - if backward_func is not None and not hasattr(backward_func, '__call__'): - raise TypeError('Input backward_func must be a function') + fwd_func_id = PyFuncWrapper(func).id + bwd_func_id = PyFuncWrapper( + backward_func).id if backward_func is not None else -1 for each_out in out_list: if len(each_out.shape) == 0: @@ -9192,18 +9228,34 @@ def py_func(func, x, out, backward_func=None): 'Output shapes of py_func op should be provided by users manually' ) - py_func_reg = PyFuncRegister.get_instance(helper.main_program) - forward_token = py_func_reg.unique_token(func) - backward_token = py_func_reg.unique_token( - backward_func) if backward_func is not None else '' + backward_skip_vars = set() + if backward_func is not None and skip_vars_in_backward_input is not None: + if isinstance(skip_vars_in_backward_input, Variable): + skip_vars_in_backward_input = [skip_vars_in_backward_input] + + fwd_in_out = [v.name for v in x] + fwd_in_out.extend([v.name for v in out_list]) + fwd_in_out = set(fwd_in_out) + backward_skip_vars = set() + for v in skip_vars_in_backward_input: + if not v.name in fwd_in_out: + raise ValueError( + 'Variable {} is not found in forward inputs and outputs' + .format(v.name)) + backward_skip_vars.add(v.name) helper.append_op( type='py_func', inputs={'X': x}, 
outputs={'Out': out_list}, attrs={ - 'handle_idx': py_func_reg.handle_idx, - 'token': forward_token, - 'backward_token': backward_token + 'forward_callable_id': fwd_func_id, + 'backward_callable_id': bwd_func_id, + 'backward_skip_vars': list(backward_skip_vars) }) return out + + +# For debug usage +py_func.registered_func = PyFuncWrapper.registered_func +py_func.registered_func_num = PyFuncWrapper.registered_func_num diff --git a/python/paddle/fluid/tests/unittests/test_py_func_op.py b/python/paddle/fluid/tests/unittests/test_py_func_op.py new file mode 100644 index 0000000000..0f03368b7e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_py_func_op.py @@ -0,0 +1,145 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle +import unittest +import six +import numpy as np + + +def tanh(x): + return np.tanh(x) + + +def tanh_grad(y, dy): + return np.array(dy) * (1 - np.square(np.array(y))) + + +def cross_entropy(logits, labels): + logits = np.array(logits) + labels = np.array(labels) + M = logits.shape[0] + N = logits.shape[1] + ret = np.ndarray([M, 1]).astype(logits.dtype) + for idx in six.moves.range(M): + ret[idx][0] = -np.log(logits[idx][labels[idx][0]]) + return ret + + +def cross_entropy_grad(logits, labels, bwd_dout): + logits = np.array(logits) + labels = np.array(labels) + bwd_dout = np.array(bwd_dout) + M = logits.shape[0] + N = logits.shape[1] + dlogits = np.zeros([M, N]).astype(logits.dtype) + for idx in six.moves.range(M): + dlogits[idx][labels[idx][0]] = -bwd_dout[idx] / logits[idx][labels[idx][ + 0]] + return dlogits, None + + +def simple_fc_net(img, label, use_py_func_op): + hidden = img + for idx in range(4): + hidden = fluid.layers.fc( + hidden, + size=200, + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + if use_py_func_op: + hidden = fluid.layers.tanh(hidden) + else: + new_hidden = fluid.default_main_program().current_block( + ).create_var( + name='hidden_{}'.format(idx), + dtype='float32', + shape=hidden.shape) + hidden = fluid.layers.py_func( + func=tanh, + x=hidden, + out=new_hidden, + backward_func=tanh_grad, + skip_vars_in_backward_input=hidden) + + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + if not use_py_func_op: + loss = fluid.layers.cross_entropy(input=prediction, label=label) + else: + loss = fluid.default_main_program().current_block().create_var( + name='loss', dtype='float32', shape=[-1, 1]) + fluid.layers.py_func( + func=cross_entropy, + x=[prediction, label], + out=loss, + backward_func=cross_entropy_grad, + skip_vars_in_backward_input=loss) + loss = fluid.layers.mean(loss) + return loss + + +def reader(): + for _ in six.moves.range(100): + yield np.random.random([784]), np.random.random_integers( + size=[1], low=0, high=9) + + +def test_main(use_cuda, use_py_func_op): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return None + + with fluid.program_guard(fluid.Program(), 
fluid.Program()): + with fluid.scope_guard(fluid.core.Scope()): + fluid.default_main_program().random_seed = 1 + fluid.default_startup_program().random_seed = 1 + np.random.seed(1) + + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + loss = simple_fc_net(img, label, use_py_func_op) + optimizer = fluid.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + r = paddle.batch(reader, batch_size=10) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + ret = [] + for epoch_id in six.moves.range(2): + for d in r(): + L, = exe.run(feed=feeder.feed(d), fetch_list=[loss]) + ret.append(L[0]) + + return np.array(ret) + + +class TestPyFuncOp(unittest.TestCase): + def test_loss_diff(self): + losses = [] + for use_cuda in [True, False]: + for use_py_func_op in [True, False]: + L = test_main(use_cuda, use_py_func_op) + if L is not None: + losses.append(L) + + for idx in six.moves.range(len(losses) - 1): + max_diff = np.max(np.abs(losses[idx] - losses[0])) + self.assertAlmostEqual(max_diff, 0, delta=1e-3) + + +if __name__ == '__main__': + unittest.main() From e7c5c9d2de9f51e5d403c879a31c0297e0f40656 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 12 Dec 2018 12:41:47 +0000 Subject: [PATCH 4/8] remove unnecesary code test=develop --- paddle/fluid/pybind/pybind.cc | 14 -------------- python/paddle/fluid/layers/nn.py | 2 +- .../fluid/tests/unittests/test_py_func_op.py | 2 +- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 38b1308330..348a073915 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -572,20 +572,6 @@ All parameter, weight, gradient are variables in Paddle. 
py::class_(m, "Place") .def(py::init<>()) - .def("is_cpu_place", - [](platform::Place &self) { return platform::is_cpu_place(self); }) - .def("is_gpu_place", - [](platform::Place &self) { return platform::is_gpu_place(self); }) - .def("is_cuda_pinned_place", - [](platform::Place &self) { - return platform::is_cuda_pinned_place(self); - }) - .def("gpu_device_id", - [](platform::Place &self) { - PADDLE_ENFORCE(platform::is_gpu_place(self), - "gpu_device_id() only supports in CUDAPlace"); - return boost::get(self).device; - }) .def("set_place", [](platform::Place &self, const platform::CPUPlace &cpu_place) { self = cpu_place; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6555025001..d71368644d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -23,7 +23,7 @@ import os import inspect from ..layer_helper import LayerHelper from ..initializer import Normal, Constant -from ..framework import Variable, OpProtoHolder, Program +from ..framework import Variable, OpProtoHolder from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ from .tensor import concat diff --git a/python/paddle/fluid/tests/unittests/test_py_func_op.py b/python/paddle/fluid/tests/unittests/test_py_func_op.py index 0f03368b7e..c71f2bdea8 100644 --- a/python/paddle/fluid/tests/unittests/test_py_func_op.py +++ b/python/paddle/fluid/tests/unittests/test_py_func_op.py @@ -59,7 +59,7 @@ def simple_fc_net(img, label, use_py_func_op): size=200, bias_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(value=1.0))) - if use_py_func_op: + if not use_py_func_op: hidden = fluid.layers.tanh(hidden) else: new_hidden = fluid.default_main_program().current_block( From deb0d41cea15db2b24aff269e2f84bd68eeaa919 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 12 Dec 2018 13:33:30 +0000 Subject: [PATCH 5/8] fix cmake fix cmake again test=develop --- paddle/fluid/operators/CMakeLists.txt | 9 +++++--- paddle/fluid/operators/py_func_op.cc | 4 ++-- paddle/fluid/operators/py_func_op.h | 2 +- paddle/fluid/pybind/CMakeLists.txt | 3 +++ python/paddle/fluid/layers/nn.py | 31 ++++++++++----------------- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9379122faf..23508ebe7c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -42,8 +42,7 @@ if (WITH_DISTRIBUTE) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) endif() -register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) - +register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) # warpctc_op needs cudnn 7 above if (WITH_GPU AND NOT WIN32) @@ -82,7 +81,7 @@ endif() # op_library(unstack_op DEPS stack_op) # op_library(tensor_array_to_tensor_op DEPS concat_op) -set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS} python pybind) +set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS}) set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") cc_test(gather_test SRCS gather_test.cc DEPS tensor) @@ -94,4 +93,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) +if (WITH_PYTHON) + cc_library(py_func_op 
SRCS py_func_op.cc DEPS op_registry python pybind) +endif() + set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 32c44c3bc2..90a6433366 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -22,7 +22,7 @@ namespace paddle { namespace operators { -namespace py = pybind11; +namespace py = ::pybind11; static std::vector g_py_callables; @@ -30,7 +30,7 @@ const char kForwardPythonCallableId[] = "forward_callable_id"; const char kBackwardPythonCallableId[] = "backward_callable_id"; const char kPyFuncBackwardSkipVars[] = "backward_skip_vars"; -size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { +size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) { g_py_callables.emplace_back(py_obj); return g_py_callables.size() - 1; } diff --git a/paddle/fluid/operators/py_func_op.h b/paddle/fluid/operators/py_func_op.h index e85fa6b5bc..4ba06bf598 100644 --- a/paddle/fluid/operators/py_func_op.h +++ b/paddle/fluid/operators/py_func_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { -size_t AppendPythonCallableObjectAndReturnId(pybind11::object py_obj); +size_t AppendPythonCallableObjectAndReturnId(const ::pybind11::object &py_obj); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index b8954cb126..b75790e4fe 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,5 +1,8 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer) +if(WITH_PYTHON) + list(APPEND PYBIND_DEPS py_func_op) +endif() set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc) if(WITH_PYTHON) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d71368644d..debe0ff0c9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9173,31 +9173,22 @@ class PyFuncWrapper(object): kwargs[arg] = args[idx] idx += 1 - ret0 = self._func(*args[idx:], **kwargs) - if ret0 is None: - return None - - if not isinstance(ret0, (list, tuple)): - ret0 = (ret0, ) + func_ret = self._func(*args[idx:], **kwargs) + if not isinstance(func_ret, (list, tuple)): + func_ret = (func_ret, ) ret = [] - for i in six.moves.range(len(ret0)): - if ret0[i] is None: - ret.append(None) - continue - - if isinstance(ret0[i], core.LoDTensor): - ret.append(ret0[i]) + for each_ret in func_ret: + if each_ret is None or isinstance(each_ret, core.LoDTensor): + ret.append(each_ret) continue - if isinstance(ret0[i], np.ndarray): - r = ret0[i] - else: - r = np.array(ret0[i]) + if not isinstance(each_ret, np.ndarray): + each_ret = np.array(each_ret) - t = core.LoDTensor() - t.set(r, core.CPUPlace()) - ret.append(t) + tensor = core.LoDTensor() + tensor.set(each_ret, core.CPUPlace()) + ret.append(tensor) return tuple(ret) From f0df62f136396794556a121344a719e4c6fb62ef Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Thu, 13 Dec 2018 09:44:40 +0000 Subject: [PATCH 6/8] add more unittest case test=develop --- paddle/fluid/operators/py_func_op.cc | 33 +++++++++++------- paddle/fluid/pybind/pybind.cc | 2 +- python/paddle/fluid/layers/nn.py | 34 +++++++++++++------ .../fluid/tests/unittests/test_py_func_op.py | 17 +++++++++- 4 files changed, 60 insertions(+), 26 deletions(-) diff --git 
a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 5d1aa7d7e6..1bee3d9351 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -35,6 +35,9 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) { return g_py_callables.size() - 1; } +// Return py::object* instead of py::object +// Returning py::object would cause reference count increasing +// but without GIL, reference count in Python may not be safe static py::object *GetPythonCallableObject(size_t i) { PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id"); return &g_py_callables[i]; @@ -47,7 +50,7 @@ static std::string PythonObjectToString(const py::object &py_callable) { static void CallPythonFunc(py::object *callable, const std::vector &ins, - std::vector *out) { + std::vector *outs) { py::gil_scoped_acquire guard; py::tuple in_args(ins.size()); for (size_t i = 0; i < ins.size(); ++i) { @@ -57,8 +60,8 @@ static void CallPythonFunc(py::object *callable, auto ret = (*callable)(*in_args); auto ret_tuple = py::cast(ret); size_t ret_num = py::len(ret_tuple); - size_t out_num = out->size(); - if (ret_num != out_num) { + size_t out_num = outs->size(); + if (UNLIKELY(ret_num != out_num)) { // Python function has no return values or returns None // In this case, ret_num = 1 && ret[0] == None && out_num should be 0 // Otherwise, ret_num must be equal to out_num @@ -69,17 +72,18 @@ static void CallPythonFunc(py::object *callable, } for (size_t i = 0; i < out_num; ++i) { - if ((*out)[i] == nullptr) { + auto *out = (*outs)[i]; + if (out == nullptr) { continue; } try { - auto *out_tensor = py::cast(ret_tuple[i]); - PADDLE_ENFORCE_NOT_NULL(out_tensor, + auto *py_out_tensor = py::cast(ret_tuple[i]); + PADDLE_ENFORCE_NOT_NULL(py_out_tensor, "Output tensor %d should not be nullptr", i); - (*out)[i]->set_lod(out_tensor->lod()); - (*out)[i]->ShareDataWith(*out_tensor); + out->set_lod(py_out_tensor->lod()); + out->ShareDataWith(*py_out_tensor); } catch (py::cast_error &) { - PADDLE_THROW("Output %d is not LoDTensor", i); + PADDLE_THROW("The %d-th output must be LoDTensor", i); } } } @@ -94,6 +98,10 @@ class PyFuncOpShapeInference : public framework::InferShapeBase { PADDLE_ENFORCE_GE(ctx->Attrs().Get(kForwardPythonCallableId), 0, "Function id cannot be less than 0"); + // Transverse all outputs + // If name of any output ends with @GRAD, + // set its shape, dtype, lod_level, type to be the same as + // the correponding forward variable auto *op = boost::get(ctx->GetOp()); auto *block = op->Block(); const std::string kGradVarSuffix = framework::kGradVarSuffix; @@ -115,7 +123,7 @@ class PyFuncOpShapeInference : public framework::InferShapeBase { auto *in_var_desc = block->FindVarRecursive(fwd_var_name); PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", fwd_var_name); - VLOG(10) << "Infer shape of Out(" << out_name << ") as Input(" + VLOG(10) << "Infer shape of Output(" << out_name << ") as Input(" << in_var_desc->Name() << ")"; out_var_desc->SetShape(in_var_desc->GetShape()); out_var_desc->SetDataType(in_var_desc->GetDataType()); @@ -135,7 +143,7 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { "Index of registered forward Python function.") .SetDefault(0); AddAttr(kBackwardPythonCallableId, - "Index of registered backward Python function") + "Index of registered backward Python function.") .SetDefault(-1); AddAttr>(kPyFuncBackwardSkipVars, "Unused forward in/out in backward op") @@ -170,8 +178,7 @@ class 
PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { auto fwd_outs = Output("Out"); // For memory reused, some inputs/output in forward part may be not needed - // in backward part - // Just skip these vars + // in backward part. Skipping these vars helps to save memory auto &backward_skip_var_list = boost::get>( fwd_attrs.at(kPyFuncBackwardSkipVars)); std::unordered_set backward_skip_var_set( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 348a073915..208efbea4a 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -104,7 +104,7 @@ PYBIND11_MODULE(core, m) { BindException(&m); m.def( - "append_python_callable_object_and_return_id", + "_append_python_callable_object_and_return_id", [](py::object py_obj) -> size_t { return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj); }); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index db7ec9d021..3cd0a2887e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9137,8 +9137,13 @@ class PyFuncRegistry(object): self._func = func # find named args using reflection - self._named_args = inspect.getargspec(self._func)[0] - self._id = core.append_python_callable_object_and_return_id(self) + args = inspect.getargspec(self._func) + if len(args[0]) == 0 and args[1] is None and args[2] is None: + # Function with no inputs + self._named_args = None + else: + self._named_args = args[0] + self._id = core._append_python_callable_object_and_return_id(self) ''' Why record self here? @@ -9168,13 +9173,16 @@ class PyFuncRegistry(object): return self._id def __call__(self, *args): - kwargs = dict() - idx = 0 - for arg in self._named_args: - kwargs[arg] = args[idx] - idx += 1 + if self._named_args is None: + func_ret = self._func() + else: + kwargs = dict() + idx = 0 + for arg in self._named_args: + kwargs[arg] = args[idx] + idx += 1 + func_ret = self._func(*args[idx:], **kwargs) - func_ret = self._func(*args[idx:], **kwargs) if not isinstance(func_ret, (list, tuple)): func_ret = (func_ret, ) @@ -9207,14 +9215,18 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): User should set the right data type and shape of :code:`out` before calling this function. However, data types and shapes of gradients of - :code:`out` and :code:`x` would be infered automatically. + :code:`out` and :code:`x` would be inferred automatically. - The orders of inputs of :code:`backward_func` would be: forward input - :code:`x`, forward output :code:`out` and backward input gradient of + Input orders of :code:`backward_func` would be: forward inputs + :code:`x`, forward outputs :code:`out` and backward input gradients of :code:`out`. If some variables of :code:`out` have no gradient, the input tensor would be None in Python side. If some variables of :code:`in` have no gradient, users should return None. + This function can also be used to debug the running network. User can + add a :code:`py_func` operator without output, and print input + :code:`x` inside :code:`func`. + Args: func (callable): forward Python function. x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`. 
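A standalone sketch of the signature detection that PyFuncRegistry.__init__ performs in the hunk above — it uses only the standard library, and the helper name _named_args_of is illustrative rather than part of the patch:

import inspect

def _named_args_of(func):
    # Mirrors PyFuncRegistry.__init__: a function with no positional
    # arguments, no *args and no **kwargs is treated as an input-less
    # callable (named args become None); otherwise the positional
    # argument names are recorded for keyword binding in __call__.
    args = inspect.getargspec(func)
    if len(args[0]) == 0 and args[1] is None and args[2] is None:
        return None
    return args[0]

def no_input():
    return 1.0

def tanh_grad(y, dy):
    return dy * (1 - y * y)

assert _named_args_of(no_input) is None
assert _named_args_of(tanh_grad) == ['y', 'dy']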
diff --git a/python/paddle/fluid/tests/unittests/test_py_func_op.py b/python/paddle/fluid/tests/unittests/test_py_func_op.py index 491bbc2190..943ad3ed22 100644 --- a/python/paddle/fluid/tests/unittests/test_py_func_op.py +++ b/python/paddle/fluid/tests/unittests/test_py_func_op.py @@ -25,6 +25,14 @@ if fluid.core.is_compiled_with_cuda(): os.environ['CPU_NUM'] = str(dev_cnt) +def dummy_func_with_no_input(): + return float(1.0) + + +def dummy_func_with_no_output(x): + pass + + def tanh(x): return np.tanh(x) @@ -86,13 +94,20 @@ def simple_fc_net(img, label, use_py_func_op): else: loss = fluid.default_main_program().current_block().create_var( name='loss', dtype='float32', shape=[-1, 1]) - fluid.layers.py_func( + loss = fluid.layers.py_func( func=cross_entropy, x=[prediction, label], out=loss, backward_func=cross_entropy_grad, skip_vars_in_backward_input=loss) + dummy_var = fluid.default_main_program().current_block().create_var( + name='test_tmp_var', dtype='float32', shape=[1]) + fluid.layers.py_func( + func=dummy_func_with_no_input, x=None, out=dummy_var) + + fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None) + loss = fluid.layers.mean(loss) return loss From dc8847af876e678e23d0c0125bedd5cfae47ec9b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 18 Dec 2018 08:36:44 +0000 Subject: [PATCH 7/8] add examples and comments test=develop --- paddle/fluid/operators/py_func_op.cc | 77 ++++++++++++++++++---------- python/paddle/fluid/layers/nn.py | 41 +++++++++++++++ 2 files changed, 92 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 1bee3d9351..a2895b5404 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -43,9 +43,12 @@ static py::object *GetPythonCallableObject(size_t i) { return &g_py_callables[i]; } -static std::string PythonObjectToString(const py::object &py_callable) { +static std::string PythonFuncDebugString(const py::object &py_callable) { py::gil_scoped_acquire guard; - return py::str(*py_callable); + std::string wrapper_func_str = py::str(py_callable); + auto inner_func = py_callable.attr("_func"); + std::string inner_func_str = py::str(inner_func); + return inner_func_str + " wrapped by " + wrapper_func_str; } static void CallPythonFunc(py::object *callable, @@ -93,15 +96,29 @@ class PyFuncOpShapeInference : public framework::InferShapeBase { void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(!ctx->IsRuntime(), "Infer shape cannot be called in runtime."); + + /** + * X or Out can be empty, so that py_func can be more flexible + * to support Python functions with no input or no output + */ PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"), "Input(X) or Output(Out) must exist"); PADDLE_ENFORCE_GE(ctx->Attrs().Get(kForwardPythonCallableId), 0, "Function id cannot be less than 0"); - // Transverse all outputs - // If name of any output ends with @GRAD, - // set its shape, dtype, lod_level, type to be the same as - // the correponding forward variable + /** + * Traverse all outputs, check if name of any output ends with @GRAD. + * If found, set its shape, dtype, lod_level, type to be the same as + * the corresponding forward variable + * + * Why not get input dims from InferShapeContext? + * Because some variables in forward inputs/outputs may not be needed + * in backward. Those variables are not inside InferShapeContext. + * + * InferShape would be only called in compile time. 
During runtime, + * the shapes of outputs should be guaranteed by user-defined Python + * functions. */ auto *op = boost::get(ctx->GetOp()); auto *block = op->Block(); const std::string kGradVarSuffix = framework::kGradVarSuffix; @@ -113,7 +130,7 @@ class PyFuncOpShapeInference : public framework::InferShapeBase { } auto out_name = out_var_desc->Name(); if (out_name == framework::kEmptyVarName || - out_name.size() <= kGradVarSuffix.size()) { + out_name.size() < kGradVarSuffix.size()) { continue; } @@ -152,7 +169,28 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { } }; +/** + * There are several benefits when the backward op of a py_func op is + * itself a py_func op. + * + * - Less code is needed, since the backward code is almost + * the same as the forward code. + * + * - High-order derivatives are supported, so that py_func is + * differentiable to any order. + */ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { + private: + static std::string DebugString(const std::vector &strs) { + if (strs.empty()) return ""; + std::string ret = strs[0]; + for (size_t i = 1; i < strs.size(); ++i) { + ret += " "; + ret += strs[i]; + } + return ret; + } + public: using framework::GradOpDescMakerBase::GradOpDescMakerBase; @@ -207,21 +245,8 @@ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { // But in Python side, if IG is not needed, users can just return None auto bwd_outs = InputGrad("X", false); - if (VLOG_IS_ON(10)) { - std::string in_str = "PyFunc Grad Input: "; - for (auto &in : bwd_ins) { - in_str += in; - in_str += " "; - } - VLOG(10) << in_str; - - std::string out_str = "PyFunc Grad Output: "; - for (auto &out : bwd_outs) { - out_str += out; - out_str += " "; - } - VLOG(10) << out_str; - } + VLOG(10) << "PyFunc Grad Input: " << DebugString(bwd_ins); + VLOG(10) << "PyFunc Grad Output: " << DebugString(bwd_outs); grad_op->SetInput("X", bwd_ins); grad_op->SetOutput("Out", bwd_outs); @@ -245,6 +270,7 @@ class PyFuncOp : public framework::OperatorBase { std::vector inputs(in_arg_names.size()); for (size_t i = 0; i < in_arg_names.size(); ++i) { auto in_var = scope.FindVar(in_arg_names[i]); + // When py_func op is called in backward, in_var may be null if (in_var == nullptr) { continue; } @@ -263,15 +289,14 @@ std::vector outputs(out_arg_names.size()); for (size_t i = 0; i < out_arg_names.size(); ++i) { auto *out_var = scope.FindVar(out_arg_names[i]); - auto *out_tensor = + outputs[i] = out_var ?
out_var->GetMutable() : nullptr; - outputs[i] = out_tensor; } auto callable_id = static_cast(Attr(kForwardPythonCallableId)); auto *py_callable = GetPythonCallableObject(callable_id); - VLOG(10) << "Call py_func_op with id " << callable_id << ": " - << PythonObjectToString(*py_callable); + VLOG(10) << "Call Python function with id " << callable_id << ": " + << PythonFuncDebugString(*py_callable); CallPythonFunc(py_callable, inputs, &outputs); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3cd0a2887e..ab3fb1e97e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9243,6 +9243,47 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): Returns: out (Variable|list(Variable)|tuple(Variable)): input :code:`out` + + Examples: + + >>> import paddle.fluid as fluid + >>> import six + >>> + >>> def create_tmp_var(name, dtype, shape): + >>> return fluid.default_main_program().current_block().create_var( + >>> name=name, dtype=dtype, shape=shape) + >>> + >>> # tanh activation has been provided by Paddle C++ op + >>> # Here, we only use tanh to be an example to show the usage + >>> # of py_func + >>> def tanh(x): + >>> return np.tanh(x) + >>> + >>> # forward input x is skipped + >>> def tanh_grad(y, dy): + >>> return np.array(dy) * (1 - np.square(np.array(y))) + >>> + >>> def debug_func(x): + >>> print(x) + >>> + >>> def simple_net(img, label): + >>> hidden = img + >>> for idx in six.moves.range(4): + >>> hidden = fluid.layers.fc(hidden, size=200) + >>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx), + >>> dtype=hidden.dtype, shape=hidden.shape) + >>> + >>> # user-defined layers with forward and backward + >>> hidden = fluid.layers.py_func(func=tanh, x=hidden, + >>> out=new_hidden, backward_func=tanh_grad, + >>> skip_vars_in_backward_input=hidden) + >>> + >>> # user-defined debug layers to print variables + >>> fluid.layers.py_func(func=debug_func, x=hidden, out=None) + >>> + >>> prediction = fluid.layers.fc(hidden, size=10, act='softmax') + >>> loss = fluid.layers.cross_entropy(input=prediction, label=label) + >>> return fluid.layers.mean(loss) """ helper = LayerHelper('py_func', **locals()) if x is None: From 490eb9061f7d3bd19240fbff8465a2d5e4f25204 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Thu, 20 Dec 2018 08:45:43 +0000 Subject: [PATCH 8/8] polish infer shape of py_func op test=develop --- paddle/fluid/framework/op_desc.cc | 2 - paddle/fluid/framework/operator.cc | 2 - paddle/fluid/framework/shape_inference.h | 3 - paddle/fluid/operators/py_func_op.cc | 79 ++++++++++++------------ 4 files changed, 41 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 0faf9fe054..dde642764f 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -34,8 +34,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { public: CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block); - InferShapeOpPtr GetOp() const override { return &op_; } - bool HasInput(const std::string &name) const override; bool HasOutput(const std::string &name) const override; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 222b261e2a..66055e6f1d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -481,8 +481,6 @@ class RuntimeInferShapeContext : public InferShapeContext { RuntimeInferShapeContext(const OperatorBase& op, 
const Scope& scope) : op_(op), scope_(scope) {} - InferShapeOpPtr GetOp() const override { return &op_; } - bool HasInput(const std::string& name) const override { // has only one input const auto& ins = op_.Inputs(); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 2f95ab353e..55349376ba 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -28,7 +28,6 @@ namespace framework { class OperatorBase; using InferShapeVarPtr = boost::variant; -using InferShapeOpPtr = boost::variant; class InferShapeContext { public: @@ -41,8 +40,6 @@ class InferShapeContext { std::vector GetOutputsVarType( const std::string &name) const; - virtual InferShapeOpPtr GetOp() const = 0; - virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index a2895b5404..a6b1c738af 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -91,66 +91,68 @@ static void CallPythonFunc(py::object *callable, } } -class PyFuncOpShapeInference : public framework::InferShapeBase { +class PyFuncOpVarTypInference : public framework::VarTypeInference { public: - void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(!ctx->IsRuntime(), - "Infer shape cannot be called in runtime."); + void operator()(const framework::OpDesc &op, + framework::BlockDesc *block) const override { + auto &outs = op.Outputs(); + bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty()); + + auto &ins = op.Inputs(); + bool has_in = (ins.count("X") > 0 && !ins.at("X").empty()); /** * X or Out can be empty, so that py_func can be more flexible * to support Python functions with no input or no output */ - PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"), - "Input(X) or Output(Out) must exist"); - PADDLE_ENFORCE_GE(ctx->Attrs().Get(kForwardPythonCallableId), 0, + PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist"); + + PADDLE_ENFORCE_GE(boost::get(op.GetAttr(kForwardPythonCallableId)), 0, "Function id cannot be less than 0"); + if (!has_out) return; + /** * Traverse all outputs, check if name of any output ends with @GRAD. * If found, set its shape, dtype, lod_level, type to be the same as * the corresponding forward variable - * - * Why not get input dims from InferShapeContext? - * Because some variables in forward inputs/outputs may not be needed - * in backward. Those variables are not inside InferShapeContext. - * - * InferShape would be only called in compile time. During runtime, - * the shapes of outputs should be guaranteed by user-defined Python - * functions. 
During runtime, - * the shapes of outputs should be guaranteed by user-defined Python - * functions.
*/ - auto *op = boost::get(ctx->GetOp()); - auto *block = op->Block(); const std::string kGradVarSuffix = framework::kGradVarSuffix; - auto out_vars = ctx->GetOutputVarPtrs("Out"); - for (auto &out_var : out_vars) { - auto *out_var_desc = boost::get(out_var); - if (out_var_desc == nullptr) { - continue; - } - auto out_name = out_var_desc->Name(); - if (out_name == framework::kEmptyVarName || - out_name.size() < kGradVarSuffix.size()) { + auto &out_var_names = outs.at("Out"); + for (auto &out_var_name : out_var_names) { + if (out_var_name == framework::kEmptyVarName || + out_var_name.size() < kGradVarSuffix.size()) { continue; } - size_t len = out_name.size() - kGradVarSuffix.size(); - if (out_name.substr(len) == kGradVarSuffix) { - auto fwd_var_name = out_name.substr(0, len); - auto *in_var_desc = block->FindVarRecursive(fwd_var_name); - PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", + size_t len = out_var_name.size() - kGradVarSuffix.size(); + if (out_var_name.substr(len) == kGradVarSuffix) { + auto fwd_var_name = out_var_name.substr(0, len); + auto *out_var_desc = block->FindVarRecursive(out_var_name); + auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name); + PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found", + out_var_name); + PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found", fwd_var_name); - VLOG(10) << "Infer shape of Output(" << out_name << ") as Input(" - << in_var_desc->Name() << ")"; - out_var_desc->SetShape(in_var_desc->GetShape()); - out_var_desc->SetDataType(in_var_desc->GetDataType()); - out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel()); - out_var_desc->SetType(in_var_desc->GetType()); + VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" + << fwd_var_name << ")"; + out_var_desc->SetShape(fwd_var_desc->GetShape()); + out_var_desc->SetDataType(fwd_var_desc->GetDataType()); + out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel()); + out_var_desc->SetType(fwd_var_desc->GetType()); } } } }; +class PyFuncOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(!ctx->IsRuntime(), + "Infer shape cannot be called in runtime."); + } +}; + class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase { namespace ops = paddle::operators; REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, - ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker); + ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference, + ops::PyFuncOpGradDescMaker);
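For reference, the @GRAD-suffix rule that PyFuncOpVarTypInference applies can be mirrored in a few lines of Python. This is a sketch assuming framework::kGradVarSuffix is the usual "@GRAD" suffix; forward_name_of is an illustrative name, not a Paddle API:

GRAD_SUFFIX = '@GRAD'  # assumed value of framework::kGradVarSuffix

def forward_name_of(out_var_name):
    # Mirrors PyFuncOpVarTypInference: an output named "hidden_0@GRAD"
    # copies shape, dtype, lod_level and type from the forward variable
    # "hidden_0"; any other output is left for the user-defined Python
    # function to fill in at runtime.
    if len(out_var_name) < len(GRAD_SUFFIX):
        return None
    if out_var_name.endswith(GRAD_SUFFIX):
        return out_var_name[:-len(GRAD_SUFFIX)]
    return None

assert forward_name_of('hidden_0@GRAD') == 'hidden_0'
assert forward_name_of('loss') is None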
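The tanh/tanh_grad pair in the new docstring example relies on the identity d tanh(x)/dx = 1 - tanh(x)^2, computing the gradient from the forward output y instead of the forward input x (note that the example calls np.tanh, so numpy must also be imported as np there). A standalone numpy check of that identity against a central finite difference:

import numpy as np

def tanh(x):
    return np.tanh(x)

def tanh_grad(y, dy):
    # backward receives the forward output y and the output gradient dy
    return np.array(dy) * (1 - np.square(np.array(y)))

x = np.random.random((3, 4))
dy = np.random.random((3, 4))
y = tanh(x)

# central finite difference of tanh, scaled by the incoming gradient dy
eps = 1e-6
numeric = (tanh(x + eps) - tanh(x - eps)) / (2 * eps) * dy
assert np.allclose(tanh_grad(y, dy), numeric)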