diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 38f811c0e9..ee9f6a4805 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -19,8 +19,6 @@ limitations under the License. */
 #include
 #include
 #include
-#include "gflags/gflags.h"
-#include "glog/logging.h"
 #include "paddle/fluid/framework/data_transform.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/lod_tensor.h"
diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index a730b84a91..5db4221199 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -1,5 +1,5 @@
 if(WITH_PYTHON)
-cc_library(layer SRCS layer.cc DEPS proto_desc operator)
-cc_library(tracer SRCS tracer.cc DEPS proto_desc)
+cc_library(layer SRCS layer.cc DEPS proto_desc operator device_context blas)
+cc_library(tracer SRCS tracer.cc DEPS proto_desc device_context)
 cc_library(engine SRCS engine.cc)
 endif()
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index b7df4b8886..8029129b9a 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/imperative/layer.h"
+
 #include
 #include
 #include
@@ -22,6 +23,9 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/string/printf.h"
 
 namespace paddle {
@@ -34,22 +38,66 @@ std::map py_funcs_;
 
 using framework::Variable;
 
-void AddTo(Variable* src, Variable* dst) {
-  framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
-  framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
+namespace detail {
+
+template <typename T>
+class TensorAddToFunctor : public boost::static_visitor<> {
+ public:
+  TensorAddToFunctor(int64_t numel, const T* x, T* y)
+      : numel_(numel), x_(x), y_(y) {}
+
+  void operator()(const platform::CPUPlace& place) {
+    platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
+        platform::DeviceContextPool::Instance().Get(place));
+    auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
+    blas.AXPY(numel_, 1., x_, y_);
+  }
+
+#ifdef PADDLE_WITH_CUDA
+  void operator()(const platform::CUDAPlace& place) {
+    platform::CUDADeviceContext* ctx =
+        dynamic_cast<platform::CUDADeviceContext*>(
+            platform::DeviceContextPool::Instance().Get(place));
+    auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
+    blas.AXPY(numel_, 1., x_, y_);
+  }
+#else
+  void operator()(const platform::CUDAPlace& place) {
+    PADDLE_THROW("Gradient merge is not supported in place %s", place);
+  }
+#endif
+
+  // there is NO blas support in CUDAPinnedPlace
+  void operator()(const platform::CUDAPinnedPlace& place) {
+    PADDLE_THROW("Gradient merge is not supported in place %s", place);
+  }
+
+ private:
+  int64_t numel_;
+  const T* x_;
+  T* y_;
+};
+
+}  // namespace detail
+
+void AddTo(Variable* src, Variable* dst, platform::Place place) {
+  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
+  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
+
   // FIXME(minqiyang): loss_grad op will pass a zero grad of label
   // ugly fix for it
   if (src_tensor->numel() == 0) {
     return;
   }
+
   PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
                  "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
                  src_tensor->numel());
-  float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
-  const float* src_data = src_tensor->data<float>();
-  for (int64_t i = 0; i < src_tensor->numel(); ++i) {
-    dst_data[i] += src_data[i];
-  }
+
+  detail::TensorAddToFunctor<float> func(
+      src_tensor->numel(), src_tensor->data<float>(),
+      dst_tensor->mutable_data<float>(place));
+  boost::apply_visitor(func, place);
 }
 
 class Autograd {
@@ -120,6 +168,36 @@ class Autograd {
   }
 };
 
+std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
+                                             const bool blocking) const {
+  PADDLE_ENFORCE(var_->IsInitialized(),
+                 "Variable must be initialized when getting numpy tensor");
+
+  std::unique_ptr<VarBase> new_var(new VarBase());
+  framework::LoDTensor* tensor =
+      new_var->var_->GetMutable<framework::LoDTensor>();
+  tensor->Resize(var_->Get<framework::LoDTensor>().dims());
+  tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
+
+  if (blocking) {
+    platform::DeviceContext* dev_ctx =
+        platform::DeviceContextPool::Instance().Get(dst_place);
+
+    framework::TensorCopySync(var_->Get<framework::LoDTensor>(), dst_place,
+                              tensor);
+
+    dev_ctx->Wait();
+  } else {
+    framework::TensorCopy(var_->Get<framework::LoDTensor>(), dst_place, tensor);
+  }
+
+  if (platform::is_gpu_place(dst_place)) {
+    VLOG(3) << "copy tensor " << var_desc_->Name() << " from gpu";
+  }
+
+  return new_var;
+}
+
 framework::LoDTensor& VarBase::GradValue() {
   VLOG(3) << "get var grad " << var_desc_->Name();
   return *(grads_->var_->GetMutable<framework::LoDTensor>());
@@ -162,9 +240,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
     PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
     framework::Scope scope;
-    platform::CPUPlace place;
-    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
-    p.op.RuntimeInferShape(scope, place, ctx);
+    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
+    p.op.RuntimeInferShape(scope, place_, ctx);
     p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
   }
 
@@ -176,7 +253,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
     for (size_t i = 0; i < outputs.size(); ++i) {
       framework::Variable* grad = outputs[i];
       framework::Variable* orig_grad = origin_outputs[i];
-      AddTo(grad, orig_grad);
+      AddTo(grad, orig_grad, place_);
       delete grad;
     }
   }
@@ -188,8 +265,10 @@ void VarBase::RunBackward() {
   VLOG(3) << "start backward";
 
   auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
-  float* data = grads_t->mutable_data<float>(platform::CPUPlace());
-  std::fill(data, data + grads_t->numel(), 1.0);
+  operators::math::set_constant(
+      *(platform::DeviceContextPool::Instance().Get(
+          var_->GetMutable<framework::LoDTensor>()->place())),
+      grads_t, 1.0);
 
   PADDLE_ENFORCE(
       grads_ ==
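To see what the visitor-based `AddTo` buys end to end, here is a minimal, hypothetical Python sketch. The layer calls and the `_ivar._grad_ivar()` accessor mirror the test files later in this diff; treat it as an illustration rather than shipped API, and note that whether a leaf created by `to_variable` accumulates a gradient depends on its `stop_gradient` setting:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative.base import to_variable

with fluid.imperative.guard():  # CUDAPlace(0) if compiled with CUDA, else CPU
    x = to_variable(np.ones([2, 2], dtype='float32'))
    y = fluid.layers.elementwise_add(x, x)  # x contributes two gradient terms
    loss = fluid.layers.reduce_sum(y)
    loss._backward()
    # AddTo() merges both contributions with BLAS AXPY on the op's place,
    # so every entry of d(loss)/dx should come out as 2.0.
    grad = np.array(x._ivar._grad_ivar().value().get_tensor())
```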
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 0b1077c640..633924aa41 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -21,17 +21,21 @@
 #include  // NOLINT
 #include  // NOLINT
 #include  // NOLINT
+#include  // NOLINT
 
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/device_context.h"
 
 #include "paddle/fluid/imperative/type_defs.h"
 
 namespace paddle {
 namespace imperative {
 
+class VarBase;
+
 namespace py = ::pybind11;
 
 class PreparedOp {
@@ -81,6 +85,8 @@ class PreparedOp {
     return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
   }
 
+  inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx; }
+
  const framework::OperatorBase& op;
  const framework::RuntimeContext& ctx;
  framework::OperatorWithKernel::OpKernelFunc func;
@@ -148,6 +154,9 @@ class VarBase {
 
   framework::LoDTensor& GradValue();
 
+  std::unique_ptr<VarBase> NewVarBase(const platform::Place& dst_place,
+                                      const bool blocking) const;
+
   inline std::string GradName() const {
     PADDLE_ENFORCE(
         var_desc_,
@@ -176,7 +185,8 @@ class OpBase {
       : op_desc_(nullptr),
         forward_id_(-1),
         grad_op_desc_(nullptr),
-        backward_id_(-1) {}
+        backward_id_(-1),
+        place_(platform::CPUPlace()) {}
 
   virtual ~OpBase() {
     if (grad_op_desc_) delete grad_op_desc_;
@@ -193,6 +203,8 @@ class OpBase {
   framework::OpDesc* grad_op_desc_;
   int backward_id_;
 
+  platform::Place place_;
+
   VarBasePtrMap input_vars_;
   VarBasePtrMap output_vars_;
   OpBasePtrMap pre_ops_;
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 843fee41f3..5b87839f45 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -14,6 +14,10 @@
 
 #include "paddle/fluid/imperative/tracer.h"
 
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+
 namespace paddle {
 namespace imperative {
 
@@ -31,16 +35,38 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   *grad_op_desc = grad_op_descs[0].release();
 }
 
-void InitVar(framework::Variable* var, framework::Variable* grad_var) {
+void InitVar(framework::Variable* var, framework::Variable* grad_var,
+             platform::DeviceContext* dev_ctx) {
+  PADDLE_ENFORCE_NOT_NULL(dev_ctx,
+                          "Could not get a valid device from the forward op");
   auto& var_t = var->Get<framework::LoDTensor>();
-  float* data =
-      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
-          var_t.dims(), platform::CPUPlace());
-  std::fill(data, data + var_t.numel(), 0.0);
+  grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
+      var_t.dims(), dev_ctx->GetPlace());
+  operators::math::set_constant(
+      *dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), 0.0);
+}
+
+platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
+  platform::Place result = place;
+  for (auto it : inputs) {
+    for (VarBase* var : it.second) {
+      platform::Place tmp_place =
+          var->var_->Get<framework::LoDTensor>().place();
+      if (!platform::is_same_place(tmp_place, result)) {
+        PADDLE_THROW(
+            "All input variables should be in the same place: %s, but got "
+            "place %s for input %s instead",
+            result, tmp_place, it.first);
+      }
+    }
+  }
+
+  return result;
 }
 
 void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                    const VarBasePtrMap& outputs, framework::BlockDesc* block,
+                   const platform::Place expected_place,
                    const bool stop_gradient) {
   std::map<std::string, VarBase*> vars;
@@ -105,10 +131,11 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
   framework::Scope scope;
-  platform::CPUPlace place;
-  PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
-  p.op.RuntimeInferShape(scope, place, ctx);
-  p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+  op->place_ = GetExpectedPlace(expected_place, inputs);
+  PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
+  prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
+  prepared_op.func(framework::ExecutionContext(
+      prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
   if (!stop_gradient) {
     framework::OpDesc* grad_op_desc;
@@ -131,7 +158,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         } else {
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
-            InitVar(var->var_, var->grads_->var_);
+            InitVar(var->var_, var->grads_->var_,
+                    prepared_op.GetDeviceContext());
           }
           // Douts.
           grad_in_vars.push_back(var->grads_->var_);
@@ -144,10 +172,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       for (const std::string& grad_outvar : it.second) {
         block->FindRecursiveOrCreateVar(grad_outvar);
         auto var_it = grad_to_var->find(grad_outvar);
-        PADDLE_ENFORCE(var_it != grad_to_var->end());
+        PADDLE_ENFORCE(var_it != grad_to_var->end(),
+                       "Could not find the grad op output var; should this "
+                       "operator %s's stop_gradient be True",
+                       op_desc->Type());
         VarBase* var = vars[var_it->second];
         if (!var->grads_->var_->IsInitialized()) {
-          InitVar(var->var_, var->grads_->var_);
+          InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
         }
         grad_out_vars.push_back(var->grads_->var_);
       }
@@ -189,16 +220,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     for (VarBase* out : outputs) {
       grad_input_vars.push_back(out->var_);
     }
+
+    platform::CPUPlace place;
     for (VarBase* out : outputs) {
       grad_input_vars.push_back(out->grads_->var_);
       if (!grad_input_vars.back()->IsInitialized()) {
-        InitVar(out->var_, grad_input_vars.back());
+        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
+        InitVar(out->var_, grad_input_vars.back(),
+                platform::DeviceContextPool::Instance().Get(place));
       }
     }
+
    for (const VarBase* inp : inputs) {
      grad_output_vars.push_back(inp->grads_->var_);
      if (!grad_output_vars.back()->IsInitialized()) {
-        InitVar(inp->var_, grad_output_vars.back());
+        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
+        InitVar(inp->var_, grad_output_vars.back(),
+                platform::DeviceContextPool::Instance().Get(place));
      }
    }
  }
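The `GetExpectedPlace` check above rejects any input that lives on a different place than the tracer expects. A hypothetical Python-side illustration of the failure mode (the `_copy_to` binding appears later in this diff):

```python
with fluid.imperative.guard(place=fluid.CUDAPlace(0)):
    x = to_variable(np.ones([2, 2], dtype='float32'))         # on CUDAPlace(0)
    cpu_ivar = x._ivar._copy_to(fluid.core.CPUPlace(), True)  # now on CPU
    # Tracing an op whose input still lives on the CPU under this guard
    # would raise: "All input variables should be in the same place: ..."
```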
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index f225d8abe6..6908382155 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/imperative/engine.h"
 #include "paddle/fluid/imperative/layer.h"
+#include "paddle/fluid/platform/place.h"
 
 namespace paddle {
 namespace imperative {
@@ -34,21 +35,25 @@ void CreateGradOp(const framework::OpDesc& op_desc,
 
 void InitVar(framework::Variable* var, framework::Variable* grad_var);
 
+platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
+
 class Tracer {
  public:
   explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
 
   virtual ~Tracer() {}
 
-  void Trace(OpBase* op,
-             const std::map<std::string, std::vector<VarBase*>>& inputs,
-             const std::map<std::string, std::vector<VarBase*>>& outputs,
-             framework::BlockDesc* block, const bool stop_gradient = false);
+  void Trace(OpBase* op, const VarBasePtrMap& inputs,
+             const VarBasePtrMap& outputs, framework::BlockDesc* block,
+             const platform::Place expected_place,
+             const bool stop_gradient = false);
 
   std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
                                 bool stop_gradient = false);
 
  private:
+  platform::Place GetPlace(const VarBasePtrMap& inputs);
+
   framework::BlockDesc* root_block_;
 };
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 8f80a2d782..2493fb71c0 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -30,8 +30,9 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   auto it = device_contexts_.find(place);
   if (it == device_contexts_.end()) {
     PADDLE_THROW(
-        "'Place' is not supported, Please re-compile with WITH_GPU "
-        "option");
+        "Place %s is not supported; please re-compile with the WITH_GPU "
+        "option",
+        place);
   }
   return it->second.get().get();
 }
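With the reworded message, the device the build cannot serve is now named explicitly. A hypothetical session on a CPU-only build (the exact rendering of the place string is illustrative):

```python
with fluid.imperative.guard(place=fluid.CUDAPlace(0)):
    x = to_variable(np.ones([2], dtype='float32'))
# -> EnforceNotMet: Place CUDAPlace(0) is not supported; please re-compile
#    with the WITH_GPU option
```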
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index dbc7843caa..31c3bfa43f 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -15,18 +15,38 @@ limitations under the License. */
 #include "paddle/fluid/pybind/imperative.h"
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/imperative/tracer.h"
+#include "paddle/fluid/imperative/type_defs.h"
 
 namespace paddle {
 namespace pybind {
 
 // Bind Methods
-void BindTracer(pybind11::module *m) {
+void BindTracer(pybind11::module* m) {
   pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
       .def("__init__",
-           [](imperative::Tracer &self, framework::BlockDesc *root_block) {
+           [](imperative::Tracer& self, framework::BlockDesc* root_block) {
             new (&self) imperative::Tracer(root_block);
           })
-      .def("trace", &imperative::Tracer::Trace)
+      .def("trace",
+           [](imperative::Tracer& self, imperative::OpBase* op,
+              const imperative::VarBasePtrMap& inputs,
+              const imperative::VarBasePtrMap& outputs,
+              framework::BlockDesc* block,
+              const platform::CPUPlace expected_place,
+              const bool stop_gradient = false) {
+             self.Trace(op, inputs, outputs, block, expected_place,
+                        stop_gradient);
+           })
+      .def("trace",
+           [](imperative::Tracer& self, imperative::OpBase* op,
+              const imperative::VarBasePtrMap& inputs,
+              const imperative::VarBasePtrMap& outputs,
+              framework::BlockDesc* block,
+              const platform::CUDAPlace expected_place,
+              const bool stop_gradient = false) {
+             self.Trace(op, inputs, outputs, block, expected_place,
+                        stop_gradient);
+           })
       .def("py_trace", &imperative::Tracer::PyTrace,
            pybind11::return_value_policy::take_ownership);
 }
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c470483756..2fbd798d57 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -138,6 +138,22 @@ PYBIND11_MODULE(core, m) {
       .def("_grad_ivar",
            [](const imperative::VarBase &self) { return self.grads_; },
            py::return_value_policy::reference)
+      .def("_copy_to",
+           [](const imperative::VarBase &self, const platform::CPUPlace &place,
+              bool blocking) {
+             std::unique_ptr<imperative::VarBase> new_var =
+                 self.NewVarBase(place, blocking);
+             return new_var.release();
+           },
+           py::return_value_policy::take_ownership)
+      .def("_copy_to",
+           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
+              bool blocking) {
+             std::unique_ptr<imperative::VarBase> new_var =
+                 self.NewVarBase(place, blocking);
+             return new_var.release();
+           },
+           py::return_value_policy::take_ownership)
       .def("value", [](const imperative::VarBase &self) { return self.var_; },
            py::return_value_policy::reference)
       .def_property(
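A minimal sketch of the Python-side semantics these bindings enable (`var` is a stand-in for any imperative-mode `fluid.Variable`):

```python
arr = var._numpy()  # now routes through _copy_to(CPUPlace, blocking=True)
cpu_ivar = var._ivar._copy_to(fluid.core.CPUPlace(), True)     # blocking copy
gpu_ivar = var._ivar._copy_to(fluid.core.CUDAPlace(0), False)  # async copy
```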
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 22f505854e..2bdae60db3 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -70,6 +70,7 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
 CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
 
 _imperative_tracer_ = None
+_imperative_current_expected_place_ = None
 
 
 def _in_imperative_mode():
@@ -80,6 +81,10 @@ def _imperative_tracer():
     return _imperative_tracer_
 
 
+def _current_expected_place():
+    return _imperative_current_expected_place_
+
+
 class NameScope(object):
     def __init__(self, name="", parent=None):
         self._children = dict()
@@ -383,8 +388,8 @@ class Variable(object):
         self._ivar.stop_gradient = stop_gradient
 
     def _numpy(self):
-        tensor = self._ivar.value().get_tensor()
-        return np.array(tensor)
+        new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
+        return np.array(new_ivar.value().get_tensor())
 
     def _backward(self):
         self._ivar._run_backward()
@@ -1311,6 +1316,7 @@ class Block(object):
     def _trace_op(self, op, stop_gradient=False):
         if _in_imperative_mode():
             _imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
+                                       _imperative_current_expected_place_,
                                        stop_gradient)
 
     def _insert_op(self, index, *args, **kwargs):
@@ -2502,5 +2508,18 @@ def _imperative_guard(tracer):
     global _imperative_tracer_
     tmp_trace = _imperative_tracer_
     _imperative_tracer_ = tracer
+
     yield
+
     _imperative_tracer_ = tmp_trace
+
+
+@contextlib.contextmanager
+def _imperative_place_guard(place):
+    global _imperative_current_expected_place_
+    tmp_place = _imperative_current_expected_place_
+    _imperative_current_expected_place_ = place
+
+    yield
+
+    _imperative_current_expected_place_ = tmp_place
diff --git a/python/paddle/fluid/imperative/base.py b/python/paddle/fluid/imperative/base.py
index 5d3ebb25a9..ff3984b11f 100644
--- a/python/paddle/fluid/imperative/base.py
+++ b/python/paddle/fluid/imperative/base.py
@@ -25,18 +25,28 @@ def enabled():
 
 
 @contextlib.contextmanager
-def guard():
+def guard(place=None):
     train = framework.Program()
     startup = framework.Program()
     tracer = core.Tracer(train.current_block().desc)
+
+    if place is None:
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+
     with framework.program_guard(train, startup):
         with framework.unique_name.guard():
             with framework._imperative_guard(tracer):
-                yield
+                with framework._imperative_place_guard(place):
+                    yield
 
 
 def to_variable(value, block=None):
     if isinstance(value, np.ndarray):
+        assert enabled(), "to_variable can only be called in imperative mode"
+
         if not block:
             block = framework.default_main_program().current_block()
         py_var = framework.Variable(
@@ -47,9 +57,7 @@ def to_variable(value, block=None):
             dtype=value.dtype)
         var = py_var._ivar.value()
         tensor = var.get_tensor()
-        tensor.set(value, core.CPUPlace())
+        tensor.set(value, framework._current_expected_place())
         return py_var
     elif isinstance(value, framework.Variable):
         return value
-    else:
-        raise ValueError("Unsupported type %s" % type(value))
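Usage sketch for the extended `guard`, following directly from the logic above:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative.base import to_variable

# Default: CUDAPlace(0) when compiled with CUDA, CPUPlace otherwise.
with fluid.imperative.guard():
    x = to_variable(np.ones([2, 2], dtype='float32'))

# Or pin the whole trace to an explicit device:
with fluid.imperative.guard(place=fluid.CPUPlace()):
    y = to_variable(np.zeros([2, 2], dtype='float32'))
```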
diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py
index 03fbfe76d1..140c0ff037 100644
--- a/python/paddle/fluid/imperative/nn.py
+++ b/python/paddle/fluid/imperative/nn.py
@@ -27,6 +27,7 @@ __all__ = [
     'Conv2D',
     'Pool2D',
     'FC',
+    'BatchNorm',
 ]
 
 
@@ -55,7 +56,8 @@ class Conv2D(layers.Layer):
             param_attr=param_attr,
             bias_attr=bias_attr,
             dtype=dtype,
-            name=name)
+            name=name,
+            act=act)
 
         self._groups = groups
         self._stride = utils.convert_to_list(stride, 2, 'stride')
@@ -141,6 +143,7 @@ class Conv2D(layers.Layer):
                 outputs={'Out': [pre_act]},
                 attrs={'axis': 1})
 
+        # Currently, we don't support inplace in imperative mode
         return self._helper.append_activation(pre_act)
 
 
@@ -216,6 +219,7 @@ class FC(layers.Layer):
                  act=None,
                  name=None):
         super(FC, self).__init__()
+
         self._size = size
         self._num_flatten_dims = num_flatten_dims
         self._dtype = dtype
@@ -241,6 +245,16 @@ class FC(layers.Layer):
             dtype=self._dtype,
             is_bias=False)
 
+        if self._helper.bias_attr:
+            size = list([self._size])
+            self._b = self._helper.create_parameter(
+                attr=self._helper.bias_attr,
+                shape=size,
+                dtype=self._dtype,
+                is_bias=True)
+        else:
+            self._b = None
+
     def forward(self, input):
         tmp = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
@@ -253,28 +267,155 @@ class FC(layers.Layer):
                 "y_num_col_dims": 1
             })
 
-        out = self._helper.create_variable_for_type_inference(self._dtype)
+        pre_bias = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
             type="sum",
             inputs={"X": [tmp]},
-            outputs={"Out": out},
+            outputs={"Out": pre_bias},
             attrs={"use_mkldnn": False})
 
-        bias_attr = self._helper.bias_attr
-        if bias_attr:
-            # add bias
-            size = list(out.shape[1:])
-            if not self._built:
-                self._b = self._helper.create_parameter(
-                    attr=bias_attr, shape=size, dtype=out.dtype, is_bias=True)
-            bias_out = self._helper.create_variable_for_type_inference(
-                dtype=out.dtype)
+        if self._b:
+            pre_activation = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
             self._helper.append_op(
                 type='elementwise_add',
-                inputs={'X': [out],
+                inputs={'X': [pre_bias],
                         'Y': [self._b]},
-                outputs={'Out': [bias_out]},
-                attrs={'axis': 1})
-            out = bias_out
-        # add activation
-        return self._helper.append_activation(out)
+                outputs={'Out': [pre_activation]},
+                attrs={'axis': self._num_flatten_dims})
+        else:
+            pre_activation = pre_bias
+        # Currently, we don't support inplace in imperative mode
+        return self._helper.append_activation(pre_activation)
+
+
+class BatchNorm(layers.Layer):
+    def __init__(self,
+                 num_channels,
+                 act=None,
+                 is_test=False,
+                 momentum=0.9,
+                 epsilon=1e-05,
+                 param_attr=None,
+                 bias_attr=None,
+                 dtype=core.VarDesc.VarType.FP32,
+                 data_layout='NCHW',
+                 in_place=False,
+                 name=None,
+                 moving_mean_name=None,
+                 moving_variance_name=None,
+                 do_model_average_for_mean_and_var=False,
+                 fuse_with_relu=False,
+                 use_global_stats=False):
+        super(BatchNorm, self).__init__()
+
+        assert bias_attr is not False, "bias_attr should not be False in batch_norm."
+
+        from ..layer_helper import LayerHelper
+        self._helper = LayerHelper(
+            'batch_norm',
+            param_attr=param_attr,
+            bias_attr=bias_attr,
+            name=name,
+            act=act)
+
+        if dtype == core.VarDesc.VarType.FP16:
+            self._dtype = core.VarDesc.VarType.FP32
+        else:
+            self._dtype = dtype
+
+        param_shape = [num_channels]
+
+        # create parameter
+        self._scale = self._helper.create_parameter(
+            attr=self._helper.param_attr,
+            shape=param_shape,
+            dtype=self._dtype,
+            default_initializer=Constant(1.0))
+        # TODO(minqiyang): change stop_gradient sign to trainable to align with static graph
+        # # setting stop_gradient=True to reduce computation
+        # if use_global_stats and self._helper.param_attr.learning_rate == 0.:
+        #     self._scale.stop_gradient = True
+
+        self._bias = self._helper.create_parameter(
+            attr=self._helper.bias_attr,
+            shape=param_shape,
+            dtype=self._dtype,
+            is_bias=True)
+        # TODO(minqiyang): change stop_gradient sign to trainable to align with static graph
+        # # setting stop_gradient=True to reduce computation
+        # if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
+        #     self._bias.stop_gradient = True
+
+        self._mean = self._helper.create_parameter(
+            attr=ParamAttr(
+                name=moving_mean_name,
+                initializer=Constant(0.0),
+                trainable=False,
+                do_model_average=do_model_average_for_mean_and_var),
+            shape=param_shape,
+            dtype=self._dtype)
+        self._mean.stop_gradient = True
+
+        self._variance = self._helper.create_parameter(
+            attr=ParamAttr(
+                name=moving_variance_name,
+                initializer=Constant(1.0),
+                trainable=False,
+                do_model_average=do_model_average_for_mean_and_var),
+            shape=param_shape,
+            dtype=self._dtype)
+        self._variance.stop_gradient = True
+
+        self._in_place = in_place
+        self._momentum = momentum
+        self._epsilon = epsilon
+        self._is_test = is_test
+        self._fuse_with_relu = fuse_with_relu
+        self._use_global_stats = use_global_stats
+
+    def _build_once(self, input):
+        pass
+
+    def forward(self, input):
+        # create output
+        # mean and mean_out share the same memory
+        mean_out = self._mean
+        # variance and variance_out share the same memory
+        variance_out = self._variance
+
+        saved_mean = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype, stop_gradient=True)
+        saved_variance = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype, stop_gradient=True)
+        batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
+            self._dtype)
+
+        self._helper.append_op(
+            type="batch_norm",
+            inputs={
+                "X": input,
+                "Scale": self._scale,
+                "Bias": self._bias,
+                "Mean": self._mean,
+                "Variance": self._variance
+            },
+            outputs={
+                "Y": batch_norm_out,
+                "MeanOut": mean_out,
+                "VarianceOut": variance_out,
+                "SavedMean": saved_mean,
+                "SavedVariance": saved_variance
+            },
+            attrs={
+                "momentum": self._momentum,
+                "epsilon": self._epsilon,
+                "is_test": self._is_test,
+                "use_mkldnn": False,
+                "fuse_with_relu": self._fuse_with_relu,
+                "use_global_stats": self._use_global_stats
+            })
+
+        # Currently, we don't support inplace in imperative mode
+        return self._helper.append_activation(batch_norm_out)
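A minimal usage sketch for the new imperative `BatchNorm`, mirroring how `ConvBNLayer` in `test_imperative_resnet.py` below composes it:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.imperative.nn import BatchNorm
from paddle.fluid.imperative.base import to_variable

with fluid.imperative.guard():
    bn = BatchNorm(num_channels=64, act='relu')
    x = to_variable(np.random.random([8, 64, 32, 32]).astype('float32'))
    y = bn(x)  # batch_norm op plus a non-inplace relu from append_activation
```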
diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py
index ea9953f581..972c51938f 100644
--- a/python/paddle/fluid/layer_helper.py
+++ b/python/paddle/fluid/layer_helper.py
@@ -435,7 +435,10 @@ class LayerHelper(object):
         act_type = act.pop('type')
         tmp = input_var
         # NOTE(dzhwinter): some activation support inplace compution.
-        if not core.IsInplace(act_type):
+        # NOTE(minqiyang): currently, we don't support inplace in imperative mode
+        if not imperative_base.enabled() and core.IsInplace(act_type):
+            tmp = input_var
+        else:
             tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
         self.append_op(
             type=act_type,
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index dde0518972..617704a531 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -321,7 +321,7 @@ def append_LARS(params_grads, learning_rate, weight_decay):
         The decayed learning rate
     Examples:
         .. code-block:: python
-            
+
             learning_rate *= local_gw_ratio * sqrt(sumsq(param))
                         / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param)))
     """
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 305c475c70..e2a4c05926 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -2874,7 +2874,7 @@ def batch_norm(input,
         attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
     # setting stop_gradient=True to reduce computation
     if use_global_stats and helper.bias_attr.learning_rate == 0.:
-        scale.stop_gradient = True
+        bias.stop_gradient = True
 
     mean = helper.create_parameter(
         attr=ParamAttr(
@@ -5856,7 +5856,8 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
             type='increment',
             inputs={'X': [counter]},
             outputs={'Out': [counter]},
-            attrs={'step': float(step)})
+            attrs={'step': float(step)},
+            stop_gradient=True)
         counter.stop_gradient = True
 
     return counter
@@ -9475,7 +9476,7 @@ def teacher_student_sigmoid_loss(input,
             by the previous operator.
         label (Variable|list): the ground truth which is a 2-D tensor with
             shape [N x 1], where N is the batch size.
-        soft_max_up_bound (float): if input > soft_max_up_bound, will be bound
+        soft_max_up_bound (float): if input > soft_max_up_bound, will be bound
         soft_max_lower_bound (float): if input < soft_max_lower_bound, will be bound
 
     Returns:
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index ce9f508c9f..2153ca254f 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -382,7 +382,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
             'dtype': out.dtype,
             'value': float(value),
             'force_cpu': force_cpu or force_init_on_cpu()
-        })
+        },
+        stop_gradient=True)
     out.stop_gradient = True
     return out
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index b72b900d3b..14f4276e2f 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -301,10 +301,10 @@ class Optimizer(object):
             no_grad_set (set|None): set of Variables should be ignored.
             callbacks (list|None): list of callables to run when appending
                 backward operator for one parameter.
-            
+
         Return:
             list: list of (param, grad) pair, grad is the output of backward.
-            
+
         Examples:
             See examples in `apply_gradients`.
         """
@@ -322,10 +322,10 @@ class Optimizer(object):
 
         Args:
             params_grads (list): list of (param, grad) pair to do optimization.
-            
+
         Returns:
             list: A list of operators appended to the current program.
-            
+
         Examples:
             .. code-block:: python
 
@@ -364,7 +364,7 @@ class Optimizer(object):
 
         This method combines interface `backward()` and
         `apply_gradients()` into one.
-        
+
         Args:
             loss (Variable): loss variable to run optimizations.
             startup_program (Program): startup_program for initializing parameters
@@ -381,18 +381,21 @@ class Optimizer(object):
         optimize_ops = []
         if imperative_base.enabled():
             if parameter_list is not None:
-                params_grads = parameter_list
+                parameters = parameter_list
             else:
                 parameters = program.global_block().all_parameters()
-                params_grads = []
-                for param in parameters:
-                    # create gradient variable
-                    grad_var = Variable(
-                        block=loss.block,
-                        name=param._ivar._grad_name(),
-                        stop_gradient=True,
-                        ivar=param._ivar._grad_ivar())
-                    params_grads.append((param, grad_var))
+
+            params_grads = []
+            for param in parameters:
+                if param.stop_gradient:
+                    continue
+                # create gradient variable
+                grad_var = Variable(
+                    block=loss.block,
+                    name=param._ivar._grad_name(),
+                    stop_gradient=True,
+                    ivar=param._ivar._grad_ivar())
+                params_grads.append((param, grad_var))
             with program_guard(program, startup_program):
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
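The reworked imperative branch of `minimize` now skips parameters with `stop_gradient=True` when assembling `params_grads`. A hedged sketch of the intended call pattern (`MNIST`, `img`, and `label` are stand-ins borrowed from the test suite):

```python
from paddle.fluid.optimizer import SGDOptimizer

with fluid.imperative.guard():
    mnist = MNIST()                      # any fluid.imperative.Layer
    sgd = SGDOptimizer(learning_rate=1e-3)
    out = mnist(img)
    loss = fluid.layers.cross_entropy(input=out, label=label)
    avg_loss = fluid.layers.mean(x=loss)
    avg_loss._backward()
    sgd.minimize(avg_loss)  # builds (param, grad) pairs, skipping frozen params
```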
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 808e1e6aa8..c23dfa01e7 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -84,6 +84,7 @@ list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
 list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
 list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
 list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
+list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
 foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP})
 endforeach(TEST_OP)
@@ -91,6 +92,8 @@ py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_
 py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
 py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
 py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
+py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
+    FLAGS_cudnn_deterministic=1)
 if(WITH_DISTRIBUTE)
   py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
   set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py
index dfe4daca95..7533ab9fdb 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative.py
@@ -133,7 +133,8 @@ class TestImperative(unittest.TestCase):
             x = fluid.layers.reduce_sum(fluid.layers.tanh(x1))
             param_grads = fluid.backward.append_backward(
                 x, parameter_list=[x1.name])[0]
-            exe = fluid.Executor(fluid.CPUPlace())
+            exe = fluid.Executor(fluid.CPUPlace(
+            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
 
             static_out, static_grad = exe.run(
                 feed={inp.name: np_inp},
@@ -160,7 +161,8 @@ class TestImperative(unittest.TestCase):
             x = l(inp)[0]
             param_grads = fluid.backward.append_backward(
                 x, parameter_list=[l._x_for_debug.name])[0]
-            exe = fluid.Executor(fluid.CPUPlace())
+            exe = fluid.Executor(fluid.CPUPlace(
+            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
 
             static_out, static_grad = exe.run(
                 feed={inp.name: np_inp},
@@ -186,7 +188,8 @@ class TestImperative(unittest.TestCase):
             out = mlp(inp)
             param_grads = fluid.backward.append_backward(
                 out, parameter_list=[mlp._fc1._w.name])[0]
-            exe = fluid.Executor(fluid.CPUPlace())
+            exe = fluid.Executor(fluid.CPUPlace(
+            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             exe.run(fluid.default_startup_program())
 
             static_out, static_grad = exe.run(
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gan.py b/python/paddle/fluid/tests/unittests/test_imperative_gan.py
index 4fe286f85e..681661bfc6 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_gan.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_gan.py
@@ -20,6 +20,7 @@ import sys
 
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 from paddle.fluid.optimizer import SGDOptimizer
 from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
 from test_imperative_base import new_program_scope
@@ -58,7 +59,7 @@ class Generator(fluid.imperative.Layer):
 
 
 class TestImperativeMnist(unittest.TestCase):
-    def test_mnist_cpu_float32(self):
+    def test_gan_float32(self):
         seed = 90
 
         startup = fluid.Program()
@@ -115,7 +116,8 @@ class TestImperativeMnist(unittest.TestCase):
             sgd = SGDOptimizer(learning_rate=1e-3)
             sgd.minimize(g_loss)
 
-        exe = fluid.Executor(fluid.CPUPlace())
+        exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0))
         static_params = dict()
         with fluid.scope_guard(scope):
             img = np.ones([2, 1], np.float32)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
index 63eeae4b71..d0a5a88317 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
@@ -145,7 +145,8 @@ class TestImperativeMnist(unittest.TestCase):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            exe = fluid.Executor(fluid.CPUPlace())
+            exe = fluid.Executor(fluid.CPUPlace(
+            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
 
             mnist = MNIST()
             sgd = SGDOptimizer(learning_rate=1e-3)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
new file mode 100644
index 0000000000..87a72dd04e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
@@ -0,0 +1,370 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import unittest
+import numpy as np
+import six
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.optimizer import SGDOptimizer
+from paddle.fluid.imperative.nn import Conv2D, Pool2D, BatchNorm, FC
+from paddle.fluid.imperative.base import to_variable
+from test_imperative_base import new_program_scope
+
+batch_size = 8
+train_parameters = {
+    "input_size": [3, 224, 224],
+    "input_mean": [0.485, 0.456, 0.406],
+    "input_std": [0.229, 0.224, 0.225],
+    "learning_strategy": {
+        "name": "piecewise_decay",
+        "batch_size": batch_size,
+        "epochs": [30, 60, 90],
+        "steps": [0.1, 0.01, 0.001, 0.0001]
+    },
+    "batch_size": batch_size,
+    "lr": 0.1,
+    "total_images": 1281164,
+}
+
+
+def optimizer_setting(params):
+    ls = params["learning_strategy"]
+    if ls["name"] == "piecewise_decay":
+        if "total_images" not in params:
+            total_images = 1281167
+        else:
+            total_images = params["total_images"]
+        batch_size = ls["batch_size"]
+        step = int(total_images / batch_size + 1)
+
+        bd = [step * e for e in ls["epochs"]]
+        base_lr = params["lr"]
+        lr = []
+        lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
+        optimizer = fluid.optimizer.SGD(learning_rate=0.01)
+        # TODO(minqiyang): Add learning rate scheduler support to imperative mode
+        # optimizer = fluid.optimizer.Momentum(
+        #     learning_rate=params["lr"],
+        #     learning_rate=fluid.layers.piecewise_decay(
+        #         boundaries=bd, values=lr),
+        #     momentum=0.9,
+        #     regularization=fluid.regularizer.L2Decay(1e-4))
+
+    return optimizer
+
+
+class ConvBNLayer(fluid.imperative.Layer):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 filter_size,
+                 stride=1,
+                 groups=1,
+                 act=None):
+        super(ConvBNLayer, self).__init__()
+
+        self._conv = Conv2D(
+            num_channels=num_channels,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            act=None,
+            bias_attr=None)
+
+        self._batch_norm = BatchNorm(num_filters, act=act)
+
+    def forward(self, inputs):
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+
+        return y
+
+
+class BottleneckBlock(fluid.imperative.Layer):
+    def __init__(self, num_channels, num_filters, stride, shortcut=True):
+        super(BottleneckBlock, self).__init__()
+
+        self.conv0 = ConvBNLayer(
+            num_channels=num_channels,
+            num_filters=num_filters,
+            filter_size=1,
+            act='relu')
+        self.conv1 = ConvBNLayer(
+            num_channels=num_filters,
+            num_filters=num_filters,
+            filter_size=3,
+            stride=stride,
+            act='relu')
+        self.conv2 = ConvBNLayer(
+            num_channels=num_filters,
+            num_filters=num_filters * 4,
+            filter_size=1,
+            act=None)
+
+        if not shortcut:
+            self.short = ConvBNLayer(
+                num_channels=num_channels,
+                num_filters=num_filters * 4,
+                filter_size=1,
+                stride=stride)
+
+        self.shortcut = shortcut
+
+        self._num_channels_out = num_filters * 4
+
+    def forward(self, inputs):
+        y = self.conv0(inputs)
+        conv1 = self.conv1(y)
+        conv2 = self.conv2(conv1)
+
+        if self.shortcut:
+            short = inputs
+        else:
+            short = self.short(inputs)
+
+        y = fluid.layers.elementwise_add(x=short, y=conv2)
+
+        layer_helper = LayerHelper('elementwise_add_activation', act='relu')
+        return layer_helper.append_activation(y)
+
+
+class ResNet(fluid.imperative.Layer):
+    def __init__(self, layers=50, class_dim=102):
+        super(ResNet, self).__init__()
+
+        self.layers = layers
+        supported_layers = [50, 101, 152]
+        assert layers in supported_layers, \
+            "supported layers are {} but input layer is {}".format(supported_layers, layers)
+
+        if layers == 50:
+            depth = [3, 4, 6, 3]
+        elif layers == 101:
+            depth = [3, 4, 23, 3]
+        elif layers == 152:
+            depth = [3, 8, 36, 3]
+        num_filters = [64, 128, 256, 512]
+
+        self.conv = ConvBNLayer(
+            num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
+        self.pool2d_max = Pool2D(
+            pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
+
+        self.bottleneck_block_list = []
+        num_channels = 64
+        for block in range(len(depth)):
+            shortcut = False
+            for i in range(depth[block]):
+                bottleneck_block = BottleneckBlock(
+                    num_channels=num_channels,
+                    num_filters=num_filters[block],
+                    stride=2 if i == 0 and block != 0 else 1,
+                    shortcut=shortcut)
+                num_channels = bottleneck_block._num_channels_out
+                self.bottleneck_block_list.append(bottleneck_block)
+                shortcut = True
+
+        self.pool2d_avg = Pool2D(
+            pool_size=7, pool_type='avg', global_pooling=True)
+
+        import math
+        stdv = 1.0 / math.sqrt(2048 * 1.0)
+
+        self.out = FC(size=class_dim,
+                      act='softmax',
+                      param_attr=fluid.param_attr.ParamAttr(
+                          initializer=fluid.initializer.Uniform(-stdv, stdv)))
+
+    def forward(self, inputs):
+        y = self.conv(inputs)
+        y = self.pool2d_max(y)
+        for bottleneck_block in self.bottleneck_block_list:
+            y = bottleneck_block(y)
+        y = self.pool2d_avg(y)
+        y = self.out(y)
+        return y
+
+
+class TestImperativeResnet(unittest.TestCase):
+    def test_resnet_float32(self):
+        seed = 90
+
+        batch_size = train_parameters["batch_size"]
+        batch_num = 1
+        with fluid.imperative.guard():
+            fluid.default_startup_program().random_seed = seed
+            fluid.default_main_program().random_seed = seed
+
+            resnet = ResNet()
+            optimizer = optimizer_setting(train_parameters)
+            np.random.seed(seed)
+            import random
+            random.seed(seed)
+            train_reader = paddle.batch(
+                paddle.dataset.flowers.train(use_xmap=False),
+                batch_size=batch_size)
+
+            dy_param_init_value = {}
+            for param in fluid.default_main_program().global_block(
+            ).all_parameters():
+                dy_param_init_value[param.name] = param._numpy()
+
+            for batch_id, data in enumerate(train_reader()):
+                if batch_id >= batch_num:
+                    break
+
+                dy_x_data = np.array(
+                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
+                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
+                    batch_size, 1)
+
+                img = to_variable(dy_x_data)
+                label = to_variable(y_data)
+                label._stop_gradient = True
+
+                out = resnet(img)
+                loss = fluid.layers.cross_entropy(input=out, label=label)
+                avg_loss = fluid.layers.mean(x=loss)
+
+                dy_out = avg_loss._numpy()
+
+                if batch_id == 0:
+                    for param in fluid.default_main_program().global_block(
+                    ).all_parameters():
+                        if param.name not in dy_param_init_value:
+                            dy_param_init_value[param.name] = param._numpy()
+
+                avg_loss._backward()
+
+                dy_grad_value = {}
+                for param in fluid.default_main_program().global_block(
+                ).all_parameters():
+                    if not param.stop_gradient:
+                        np_array = np.array(param._ivar._grad_ivar().value()
+                                            .get_tensor())
+                        dy_grad_value[param.name + core.grad_var_suffix(
+                        )] = np_array
+
+                optimizer.minimize(avg_loss)
+
+                dy_param_value = {}
+                for param in fluid.default_main_program().global_block(
+                ).all_parameters():
+                    dy_param_value[param.name] = param._numpy()
+
+        with new_program_scope():
+            fluid.default_startup_program().random_seed = seed
+            fluid.default_main_program().random_seed = seed
+
+            exe = fluid.Executor(fluid.CPUPlace(
+            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
+
+            resnet = ResNet()
+            optimizer = optimizer_setting(train_parameters)
+
+            np.random.seed(seed)
+            import random
+            random.seed(seed)
+            train_reader = paddle.batch(
+                paddle.dataset.flowers.train(use_xmap=False),
+                batch_size=batch_size)
+
+            img = fluid.layers.data(
+                name='pixel', shape=[3, 224, 224], dtype='float32')
+            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+            out = resnet(img)
+            loss = fluid.layers.cross_entropy(input=out, label=label)
+            avg_loss = fluid.layers.mean(x=loss)
+            optimizer.minimize(avg_loss)
+
+            # initialize params and fetch them
+            static_param_init_value = {}
+            static_param_name_list = []
+            static_grad_name_list = []
+            for param in fluid.default_startup_program().global_block(
+            ).all_parameters():
+                static_param_name_list.append(param.name)
+            for param in fluid.default_main_program().global_block(
+            ).all_parameters():
+                if not param.stop_gradient:
+                    static_grad_name_list.append(param.name +
+                                                 core.grad_var_suffix())
+
+            out = exe.run(fluid.default_startup_program(),
+                          fetch_list=static_param_name_list)
+
+            for i in range(len(static_param_name_list)):
+                static_param_init_value[static_param_name_list[i]] = out[i]
+
+            for batch_id, data in enumerate(train_reader()):
+                if batch_id >= batch_num:
+                    break
+
+                static_x_data = np.array(
+                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
+                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
+                    [batch_size, 1])
+
+                fetch_list = [avg_loss.name]
+                fetch_list.extend(static_param_name_list)
+                fetch_list.extend(static_grad_name_list)
+                out = exe.run(fluid.default_main_program(),
+                              feed={"pixel": static_x_data,
+                                    "label": y_data},
+                              fetch_list=fetch_list)
+
+                static_param_value = {}
+                static_grad_value = {}
+                static_out = out[0]
+                param_start_pos = 1
+                grad_start_pos = len(static_param_name_list) + param_start_pos
+                for i in range(param_start_pos,
+                               len(static_param_name_list) + param_start_pos):
+                    static_param_value[static_param_name_list[
+                        i - param_start_pos]] = out[i]
+                for i in range(grad_start_pos,
+                               len(static_grad_name_list) + grad_start_pos):
+                    static_grad_value[static_grad_name_list[
+                        i - grad_start_pos]] = out[i]
+
+        self.assertTrue(np.allclose(static_out, dy_out))
+
+        self.assertEqual(len(dy_param_init_value), len(static_param_init_value))
+        for key, value in six.iteritems(static_param_init_value):
+            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
+            self.assertTrue(np.isfinite(value).all())
+            self.assertFalse(np.isnan(value).any())
+
+        self.assertEqual(len(dy_grad_value), len(static_grad_value))
+        for key, value in six.iteritems(static_grad_value):
+            self.assertTrue(np.allclose(value, dy_grad_value[key]))
+            self.assertTrue(np.isfinite(value).all())
+            self.assertFalse(np.isnan(value).any())
+
+        self.assertEqual(len(dy_param_value), len(static_param_value))
+        for key, value in six.iteritems(static_param_value):
+            self.assertTrue(np.allclose(value, dy_param_value[key]))
+            self.assertTrue(np.isfinite(value).all())
+            self.assertFalse(np.isnan(value).any())
+
+
+if __name__ == '__main__':
+    unittest.main()