Polish code

test=develop
revert-16045-imperative_remove_desc
minqiyang 6 years ago committed by ceci3
parent afc3fcd509
commit 3723dcc301

@ -156,6 +156,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
}
void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
// TODO(minqiyang): make this faster
for (auto it = ops_.begin(); it != ops_.end(); ++it) {
if (it->get() == op_desc) {
ops_.erase(it);

@ -235,6 +235,8 @@ class PYBIND11_HIDDEN OpBase {
backward_hooks_() {}
virtual ~OpBase() {
// TODO(minqiyang): remove op_desc from block_desc in tracer
//
// reset all output vars' pre op
for (auto iter : output_vars_) {
for (VarBase* var : iter.second) {
@ -242,13 +244,6 @@ class PYBIND11_HIDDEN OpBase {
}
}
// remove op desc from block desc
if (op_desc_) {
if (block_) {
block_->RemoveOpInternal(op_desc_);
}
}
// release resource
for (framework::OpDesc* desc : grad_op_descs_) {
delete desc;

@ -19,7 +19,7 @@ import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
from .core import VarDesc
from . import unique_name
from .imperative import base
from .imperative import base as imperative_base
__all__ = [
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
@ -166,7 +166,7 @@ class ConstantInitializer(Initializer):
'force_cpu': self._force_cpu or force_init_on_cpu()
},
stop_gradient=True)
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -246,7 +246,7 @@ class UniformInitializer(Initializer):
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -325,7 +325,7 @@ class NormalInitializer(Initializer):
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -404,7 +404,7 @@ class TruncatedNormalInitializer(Initializer):
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -510,7 +510,7 @@ class XavierInitializer(Initializer):
"seed": self._seed
},
stop_gradient=True)
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -611,7 +611,7 @@ class MSRAInitializer(Initializer):
"seed": self._seed
},
stop_gradient=True)
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -710,7 +710,7 @@ class BilinearInitializer(Initializer):
'shape': list(shape),
value_name: values
})
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op
@ -769,7 +769,7 @@ class NumpyArrayInitializer(Initializer):
value_name: values
},
stop_gradient=True)
if not base.enabled():
if not imperative_base.enabled():
var.op = op
return op

Loading…
Cancel
Save