Merge pull request #15322 from velconia/imperative_resnet

Imperative Resnet
inference-pre-release-gpu
Xin Pan 6 years ago committed by GitHub
commit 58cb18d9d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,8 +19,6 @@ limitations under the License. */
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"

@ -1,5 +1,5 @@
if(WITH_PYTHON)
cc_library(layer SRCS layer.cc DEPS proto_desc operator)
cc_library(tracer SRCS tracer.cc DEPS proto_desc)
cc_library(layer SRCS layer.cc DEPS proto_desc operator device_context blas)
cc_library(tracer SRCS tracer.cc DEPS proto_desc device_context)
cc_library(engine SRCS engine.cc)
endif()

@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/imperative/layer.h"
#include <deque>
#include <limits>
#include <map>
@ -22,6 +23,9 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
@ -34,22 +38,66 @@ std::map<int, py::object> py_funcs_;
using framework::Variable;
void AddTo(Variable* src, Variable* dst) {
framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
namespace detail {
template <typename T>
class TensorAddToFunctor : public boost::static_visitor<> {
public:
TensorAddToFunctor(int64_t numel, const T* x, T* y)
: numel_(numel), x_(x), y_(y) {}
void operator()(const platform::CPUPlace& place) {
platform::CPUDeviceContext* ctx = dynamic_cast<platform::CPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto blas = operators::math::GetBlas<platform::CPUDeviceContext, T>(*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
#ifdef PADDLE_WITH_CUDA
void operator()(const platform::CUDAPlace& place) {
platform::CUDADeviceContext* ctx =
dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto blas = operators::math::GetBlas<platform::CUDADeviceContext, T>(*ctx);
blas.AXPY(numel_, 1., x_, y_);
}
#else
void operator()(const platform::CUDAPlace& place) {
PADDLE_THROW("Do NOT support gradient merge in place %s", place);
}
#endif
// there is NO blas in CUDAPinnedPlace
void operator()(const platform::CUDAPinnedPlace& place) {
PADDLE_THROW("Do NOT support gradient merge in place %s", place);
}
private:
int64_t numel_;
const T* x_;
T* y_;
};
} // namespace detail
void AddTo(Variable* src, Variable* dst, platform::Place place) {
framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
if (src_tensor->numel() == 0) {
return;
}
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
src_tensor->numel());
float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
const float* src_data = src_tensor->data<float>();
for (int64_t i = 0; i < src_tensor->numel(); ++i) {
dst_data[i] += src_data[i];
}
detail::TensorAddToFunctor<float> func(
src_tensor->numel(), src_tensor->data<float>(),
dst_tensor->mutable_data<float>(place));
boost::apply_visitor(func, place);
}
class Autograd {
@ -120,6 +168,36 @@ class Autograd {
}
};
std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
const bool blocking) const {
PADDLE_ENFORCE(var_->IsInitialized(),
"Variable must be initialized when getting numpy tensor");
std::unique_ptr<VarBase> new_var(new VarBase());
framework::LoDTensor* tensor =
new_var->var_->GetMutable<framework::LoDTensor>();
tensor->Resize(var_->Get<framework::LoDTensor>().dims());
tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
if (blocking) {
platform::DeviceContext* dev_ctx =
platform::DeviceContextPool::Instance().Get(dst_place);
framework::TensorCopySync(var_->Get<framework::LoDTensor>(), dst_place,
tensor);
dev_ctx->Wait();
} else {
framework::TensorCopy(var_->Get<framework::LoDTensor>(), dst_place, tensor);
}
if (platform::is_gpu_place(dst_place)) {
VLOG(3) << "copy tensor " << var_desc_->Name() << " from gpu";
}
return new_var;
}
framework::LoDTensor& VarBase::GradValue() {
VLOG(3) << "get var grad " << var_desc_->Name();
return *(grads_->var_->GetMutable<framework::LoDTensor>());
@ -162,9 +240,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope;
platform::CPUPlace place;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
p.op.RuntimeInferShape(scope, place, ctx);
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
}
@ -176,7 +253,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i];
framework::Variable* orig_grad = origin_outputs[i];
AddTo(grad, orig_grad);
AddTo(grad, orig_grad, place_);
delete grad;
}
}
@ -188,8 +265,10 @@ void VarBase::RunBackward() {
VLOG(3) << "start backward";
auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
float* data = grads_t->mutable_data<float>(platform::CPUPlace());
std::fill(data, data + grads_t->numel(), 1.0);
operators::math::set_constant(
*(platform::DeviceContextPool::Instance().Get(
var_->GetMutable<framework::LoDTensor>()->place())),
grads_t, 1.0);
PADDLE_ENFORCE(
grads_ ==

@ -21,17 +21,21 @@
#include <map> // NOLINT
#include <string> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace paddle {
namespace imperative {
class VarBase;
namespace py = ::pybind11;
class PreparedOp {
@ -81,6 +85,8 @@ class PreparedOp {
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
}
inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx; }
const framework::OperatorBase& op;
const framework::RuntimeContext& ctx;
framework::OperatorWithKernel::OpKernelFunc func;
@ -148,6 +154,9 @@ class VarBase {
framework::LoDTensor& GradValue();
std::unique_ptr<VarBase> NewVarBase(const platform::Place& dst_place,
const bool blocking) const;
inline std::string GradName() const {
PADDLE_ENFORCE(
var_desc_,
@ -176,7 +185,8 @@ class OpBase {
: op_desc_(nullptr),
forward_id_(-1),
grad_op_desc_(nullptr),
backward_id_(-1) {}
backward_id_(-1),
place_(platform::CPUPlace()) {}
virtual ~OpBase() {
if (grad_op_desc_) delete grad_op_desc_;
@ -193,6 +203,8 @@ class OpBase {
framework::OpDesc* grad_op_desc_;
int backward_id_;
platform::Place place_;
VarBasePtrMap input_vars_;
VarBasePtrMap output_vars_;
OpBasePtrMap pre_ops_;

@ -14,6 +14,10 @@
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace imperative {
@ -31,16 +35,38 @@ void CreateGradOp(const framework::OpDesc& op_desc,
*grad_op_desc = grad_op_descs[0].release();
}
void InitVar(framework::Variable* var, framework::Variable* grad_var) {
void InitVar(framework::Variable* var, framework::Variable* grad_var,
platform::DeviceContext* dev_ctx) {
PADDLE_ENFORCE_NOT_NULL(dev_ctx,
"Could not get valid device from forward op");
auto& var_t = var->Get<framework::LoDTensor>();
float* data =
grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
var_t.dims(), platform::CPUPlace());
std::fill(data, data + var_t.numel(), 0.0);
var_t.dims(), dev_ctx->GetPlace());
operators::math::set_constant(
*dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), 0.0);
}
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
platform::Place result = place;
for (auto it : inputs) {
for (VarBase* var : it.second) {
platform::Place tmp_place =
var->var_->Get<framework::LoDTensor>().place();
if (!platform::is_same_place(tmp_place, result)) {
PADDLE_THROW(
"Input variable should keep in the same place: %s, but get place: "
"%s of input %s instead",
result, tmp_place, it.first);
}
}
}
return result;
}
void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, framework::BlockDesc* block,
const platform::Place expected_place,
const bool stop_gradient) {
std::map<std::string, VarBase*> vars;
@ -105,10 +131,11 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
framework::Scope scope;
platform::CPUPlace place;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
p.op.RuntimeInferShape(scope, place, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
op->place_ = GetExpectedPlace(expected_place, inputs);
PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
prepared_op.func(framework::ExecutionContext(
prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
if (!stop_gradient) {
framework::OpDesc* grad_op_desc;
@ -131,7 +158,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
} else {
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_);
InitVar(var->var_, var->grads_->var_,
prepared_op.GetDeviceContext());
}
// Douts.
grad_in_vars.push_back(var->grads_->var_);
@ -144,10 +172,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
for (const std::string& grad_outvar : it.second) {
block->FindRecursiveOrCreateVar(grad_outvar);
auto var_it = grad_to_var->find(grad_outvar);
PADDLE_ENFORCE(var_it != grad_to_var->end());
PADDLE_ENFORCE(var_it != grad_to_var->end(),
"Could not found the grad op output var, should this "
"operator %s's stop gradient be True",
op_desc->Type());
VarBase* var = vars[var_it->second];
if (!var->grads_->var_->IsInitialized()) {
InitVar(var->var_, var->grads_->var_);
InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
}
grad_out_vars.push_back(var->grads_->var_);
}
@ -189,16 +220,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
for (VarBase* out : outputs) {
grad_input_vars.push_back(out->var_);
}
platform::CPUPlace place;
for (VarBase* out : outputs) {
grad_input_vars.push_back(out->grads_->var_);
if (!grad_input_vars.back()->IsInitialized()) {
InitVar(out->var_, grad_input_vars.back());
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
InitVar(out->var_, grad_input_vars.back(),
platform::DeviceContextPool::Instance().Get(place));
}
}
for (const VarBase* inp : inputs) {
grad_output_vars.push_back(inp->grads_->var_);
if (!grad_output_vars.back()->IsInitialized()) {
InitVar(inp->var_, grad_output_vars.back());
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
InitVar(inp->var_, grad_output_vars.back(),
platform::DeviceContextPool::Instance().Get(place));
}
}
}

@ -22,6 +22,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace imperative {
@ -34,21 +35,25 @@ void CreateGradOp(const framework::OpDesc& op_desc,
void InitVar(framework::Variable* var, framework::Variable* grad_var);
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs);
class Tracer {
public:
explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
virtual ~Tracer() {}
void Trace(OpBase* op,
const std::map<std::string, std::vector<VarBase*>>& inputs,
const std::map<std::string, std::vector<VarBase*>>& outputs,
framework::BlockDesc* block, const bool stop_gradient = false);
void Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs, framework::BlockDesc* block,
const platform::Place expected_place,
const bool stop_gradient = false);
std::vector<VarBase*> PyTrace(OpBase* op, const std::vector<VarBase*>& inputs,
bool stop_gradient = false);
private:
platform::Place GetPlace(const VarBasePtrMap& inputs);
framework::BlockDesc* root_block_;
};

@ -30,8 +30,9 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
"Place %s is not supported, Please re-compile with WITH_GPU "
"option",
place);
}
return it->second.get().get();
}

@ -15,18 +15,38 @@ limitations under the License. */
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace paddle {
namespace pybind {
// Bind Methods
void BindTracer(pybind11::module *m) {
void BindTracer(pybind11::module* m) {
pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
.def("__init__",
[](imperative::Tracer &self, framework::BlockDesc *root_block) {
[](imperative::Tracer& self, framework::BlockDesc* root_block) {
new (&self) imperative::Tracer(root_block);
})
.def("trace", &imperative::Tracer::Trace)
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs,
framework::BlockDesc* block,
const platform::CPUPlace expected_place,
const bool stop_gradient = false) {
self.Trace(op, inputs, outputs, block, expected_place,
stop_gradient);
})
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs,
framework::BlockDesc* block,
const platform::CUDAPlace expected_place,
const bool stop_gradient = false) {
self.Trace(op, inputs, outputs, block, expected_place,
stop_gradient);
})
.def("py_trace", &imperative::Tracer::PyTrace,
pybind11::return_value_policy::take_ownership);
}

@ -138,6 +138,22 @@ PYBIND11_MODULE(core, m) {
.def("_grad_ivar",
[](const imperative::VarBase &self) { return self.grads_; },
py::return_value_policy::reference)
.def("_copy_to",
[](const imperative::VarBase &self, const platform::CPUPlace &place,
bool blocking) {
std::unique_ptr<imperative::VarBase> new_var =
self.NewVarBase(place, blocking);
return new_var.release();
},
py::return_value_policy::take_ownership)
.def("_copy_to",
[](const imperative::VarBase &self, const platform::CUDAPlace &place,
bool blocking) {
std::unique_ptr<imperative::VarBase> new_var =
self.NewVarBase(place, blocking);
return new_var.release();
},
py::return_value_policy::take_ownership)
.def("value", [](const imperative::VarBase &self) { return self.var_; },
py::return_value_policy::reference)
.def_property(

@ -70,6 +70,7 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
_imperative_tracer_ = None
_imperative_current_expected_place_ = None
def _in_imperative_mode():
@ -80,6 +81,10 @@ def _imperative_tracer():
return _imperative_tracer_
def _current_expected_place():
return _imperative_current_expected_place_
class NameScope(object):
def __init__(self, name="", parent=None):
self._children = dict()
@ -383,8 +388,8 @@ class Variable(object):
self._ivar.stop_gradient = stop_gradient
def _numpy(self):
tensor = self._ivar.value().get_tensor()
return np.array(tensor)
new_ivar = self._ivar._copy_to(core.CPUPlace(), True)
return np.array(new_ivar.value().get_tensor())
def _backward(self):
self._ivar._run_backward()
@ -1311,6 +1316,7 @@ class Block(object):
def _trace_op(self, op, stop_gradient=False):
if _in_imperative_mode():
_imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
_imperative_current_expected_place_,
stop_gradient)
def _insert_op(self, index, *args, **kwargs):
@ -2502,5 +2508,18 @@ def _imperative_guard(tracer):
global _imperative_tracer_
tmp_trace = _imperative_tracer_
_imperative_tracer_ = tracer
yield
_imperative_tracer_ = tmp_trace
@contextlib.contextmanager
def _imperative_place_guard(place):
global _imperative_current_expected_place_
tmp_place = _imperative_current_expected_place_
_imperative_current_expected_place_ = place
yield
_imperative_current_expected_place_ = tmp_place

@ -25,18 +25,28 @@ def enabled():
@contextlib.contextmanager
def guard():
def guard(place=None):
train = framework.Program()
startup = framework.Program()
tracer = core.Tracer(train.current_block().desc)
if place is None:
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with framework.program_guard(train, startup):
with framework.unique_name.guard():
with framework._imperative_guard(tracer):
with framework._imperative_place_guard(place):
yield
def to_variable(value, block=None):
if isinstance(value, np.ndarray):
assert enabled(), "to_variable could only be called in imperative mode"
if not block:
block = framework.default_main_program().current_block()
py_var = framework.Variable(
@ -47,9 +57,7 @@ def to_variable(value, block=None):
dtype=value.dtype)
var = py_var._ivar.value()
tensor = var.get_tensor()
tensor.set(value, core.CPUPlace())
tensor.set(value, framework._current_expected_place())
return py_var
elif isinstance(value, framework.Variable):
return value
else:
raise ValueError("Unsupported type %s" % type(value))

@ -27,6 +27,7 @@ __all__ = [
'Conv2D',
'Pool2D',
'FC',
'BatchNorm',
]
@ -55,7 +56,8 @@ class Conv2D(layers.Layer):
param_attr=param_attr,
bias_attr=bias_attr,
dtype=dtype,
name=name)
name=name,
act=act)
self._groups = groups
self._stride = utils.convert_to_list(stride, 2, 'stride')
@ -141,6 +143,7 @@ class Conv2D(layers.Layer):
outputs={'Out': [pre_act]},
attrs={'axis': 1})
# Currently, we don't support inplace in imperative mode
return self._helper.append_activation(pre_act)
@ -216,6 +219,7 @@ class FC(layers.Layer):
act=None,
name=None):
super(FC, self).__init__()
self._size = size
self._num_flatten_dims = num_flatten_dims
self._dtype = dtype
@ -241,6 +245,16 @@ class FC(layers.Layer):
dtype=self._dtype,
is_bias=False)
if self._helper.bias_attr:
size = list([self._size])
self._b = self._helper.create_parameter(
attr=self._helper.bias_attr,
shape=size,
dtype=self._dtype,
is_bias=True)
else:
self._b = None
def forward(self, input):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
@ -253,28 +267,155 @@ class FC(layers.Layer):
"y_num_col_dims": 1
})
out = self._helper.create_variable_for_type_inference(self._dtype)
pre_bias = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": [tmp]},
outputs={"Out": out},
outputs={"Out": pre_bias},
attrs={"use_mkldnn": False})
bias_attr = self._helper.bias_attr
if bias_attr:
# add bias
size = list(out.shape[1:])
if not self._built:
self._b = self._helper.create_parameter(
attr=bias_attr, shape=size, dtype=out.dtype, is_bias=True)
bias_out = self._helper.create_variable_for_type_inference(
dtype=out.dtype)
if self._b:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [out],
inputs={'X': [pre_bias],
'Y': [self._b]},
outputs={'Out': [bias_out]},
attrs={'axis': 1})
out = bias_out
# add activation
return self._helper.append_activation(out)
outputs={'Out': [pre_activation]},
attrs={'axis': self._num_flatten_dims})
else:
pre_activation = pre_bias
# Currently, we don't support inplace in imperative mode
return self._helper.append_activation(pre_activation)
class BatchNorm(layers.Layer):
def __init__(self,
num_channels,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
dtype=core.VarDesc.VarType.FP32,
data_layout='NCHW',
in_place=False,
name=None,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=False,
fuse_with_relu=False,
use_global_stats=False):
super(BatchNorm, self).__init__()
assert bias_attr is not False, "bias_attr should not be False in batch_norm."
from ..layer_helper import LayerHelper
self._helper = LayerHelper(
'batch_norm',
param_attr=param_attr,
bias_attr=bias_attr,
name=name,
act=act)
if dtype == core.VarDesc.VarType.FP16:
self._dtype = core.VarDesc.VarType.FP32
else:
self._dtype = dtype
param_shape = [num_channels]
# create parameter
self._scale = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0))
# TODO(minqiyang): change stop_gradient sign to trainable to align with static graph
# # setting stop_gradient=True to reduce computation
# if use_global_stats and self._helper.param_attr.learning_rate == 0.:
# self._scale.stop_gradient = True
self._bias = self._helper.create_parameter(
attr=self._helper.bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True)
# TODO(minqiyang): change stop_gradient sign to trainable to align with static graph
# # setting stop_gradient=True to reduce computation
# if use_global_stats and self._helper.bias_attr.learning_rate == 0.:
# self._bias.stop_gradient = True
self._mean = self._helper.create_parameter(
attr=ParamAttr(
name=moving_mean_name,
initializer=Constant(0.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=self._dtype)
self._mean.stop_gradient = True
self._variance = self._helper.create_parameter(
attr=ParamAttr(
name=moving_variance_name,
initializer=Constant(1.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=self._dtype)
self._variance.stop_gradient = True
self._in_place = in_place
self._momentum = momentum
self._epsilon = epsilon
self._is_test = is_test
self._fuse_with_relu = fuse_with_relu
self._use_global_stats = use_global_stats
def _build_once(self, input):
pass
def forward(self, input):
# create output
# mean and mean_out share the same memory
mean_out = self._mean
# variance and variance out share the same memory
variance_out = self._variance
saved_mean = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
saved_variance = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="batch_norm",
inputs={
"X": input,
"Scale": self._scale,
"Bias": self._bias,
"Mean": self._mean,
"Variance": self._variance
},
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={
"momentum": self._momentum,
"epsilon": self._epsilon,
"is_test": self._is_test,
"use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats
})
# Currently, we don't support inplace in imperative mode
return self._helper.append_activation(batch_norm_out)

@ -435,7 +435,10 @@ class LayerHelper(object):
act_type = act.pop('type')
tmp = input_var
# NOTE(dzhwinter): some activation support inplace compution.
if not core.IsInplace(act_type):
# NOTE(minqiyang): currently, we don't support inplace in imperative mode
if not imperative_base.enabled() and core.IsInplace(act_type):
tmp = input_var
else:
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
self.append_op(
type=act_type,

@ -2874,7 +2874,7 @@ def batch_norm(input,
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
# setting stop_gradient=True to reduce computation
if use_global_stats and helper.bias_attr.learning_rate == 0.:
scale.stop_gradient = True
bias.stop_gradient = True
mean = helper.create_parameter(
attr=ParamAttr(
@ -5856,7 +5856,8 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
type='increment',
inputs={'X': [counter]},
outputs={'Out': [counter]},
attrs={'step': float(step)})
attrs={'step': float(step)},
stop_gradient=True)
counter.stop_gradient = True
return counter

@ -382,7 +382,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu()
})
},
stop_gradient=True)
out.stop_gradient = True
return out

@ -381,11 +381,14 @@ class Optimizer(object):
optimize_ops = []
if imperative_base.enabled():
if parameter_list is not None:
params_grads = parameter_list
parameters = parameter_list
else:
parameters = program.global_block().all_parameters()
params_grads = []
for param in parameters:
if param.stop_gradient:
continue
# create gradient variable
grad_var = Variable(
block=loss.block,

@ -84,6 +84,7 @@ list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
@ -91,6 +92,8 @@ py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
FLAGS_cudnn_deterministic=1)
if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)

@ -133,7 +133,8 @@ class TestImperative(unittest.TestCase):
x = fluid.layers.reduce_sum(fluid.layers.tanh(x1))
param_grads = fluid.backward.append_backward(
x, parameter_list=[x1.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
@ -160,7 +161,8 @@ class TestImperative(unittest.TestCase):
x = l(inp)[0]
param_grads = fluid.backward.append_backward(
x, parameter_list=[l._x_for_debug.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
@ -186,7 +188,8 @@ class TestImperative(unittest.TestCase):
out = mlp(inp)
param_grads = fluid.backward.append_backward(
out, parameter_list=[mlp._fc1._w.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
static_out, static_grad = exe.run(

@ -20,6 +20,7 @@ import sys
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
from test_imperative_base import new_program_scope
@ -58,7 +59,7 @@ class Generator(fluid.imperative.Layer):
class TestImperativeMnist(unittest.TestCase):
def test_mnist_cpu_float32(self):
def test_gan_float32(self):
seed = 90
startup = fluid.Program()
@ -115,7 +116,8 @@ class TestImperativeMnist(unittest.TestCase):
sgd = SGDOptimizer(learning_rate=1e-3)
sgd.minimize(g_loss)
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0))
static_params = dict()
with fluid.scope_guard(scope):
img = np.ones([2, 1], np.float32)

@ -145,7 +145,8 @@ class TestImperativeMnist(unittest.TestCase):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
mnist = MNIST()
sgd = SGDOptimizer(learning_rate=1e-3)

Loading…
Cancel
Save