|
|
|
@ -213,7 +213,7 @@ def _debug_string_(proto, throw_on_error=True):
|
|
|
|
|
return proto.__str__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Variable(core.VarBase):
|
|
|
|
|
class Variable(object):
|
|
|
|
|
"""
|
|
|
|
|
In Fluid, every input and output of an operator is a variable. In most
|
|
|
|
|
cases, variables are used for holding different kinds of data or training
|
|
|
|
@ -277,7 +277,6 @@ class Variable(core.VarBase):
|
|
|
|
|
stop_gradient=False,
|
|
|
|
|
is_data=False,
|
|
|
|
|
**kwargs):
|
|
|
|
|
core.VarBase.__init__(self)
|
|
|
|
|
self.block = block
|
|
|
|
|
self.error_clip = error_clip
|
|
|
|
|
|
|
|
|
@ -357,6 +356,9 @@ class Variable(core.VarBase):
|
|
|
|
|
self.op = None
|
|
|
|
|
self.stop_gradient = stop_gradient
|
|
|
|
|
self.is_data = is_data
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
self._ivar = core.VarBase()
|
|
|
|
|
self._ivar.desc = self.desc
|
|
|
|
|
|
|
|
|
|
def _numpy(self):
|
|
|
|
|
scope = _imperative_tracer().get_scope(self.block.desc)
|
|
|
|
@ -365,10 +367,10 @@ class Variable(core.VarBase):
|
|
|
|
|
|
|
|
|
|
def _backward(self):
|
|
|
|
|
scope = _imperative_tracer().get_scope(self.block.desc)
|
|
|
|
|
self._run_backward(scope)
|
|
|
|
|
self._ivar._run_backward(scope)
|
|
|
|
|
|
|
|
|
|
def _gradient(self):
|
|
|
|
|
return np.array(self._grad())
|
|
|
|
|
return np.array(self._ivar._grad())
|
|
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
|
return self.to_string(True)
|
|
|
|
@ -516,7 +518,7 @@ class OpProtoHolder(object):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Operator(core.OpBase):
|
|
|
|
|
class Operator(object):
|
|
|
|
|
"""
|
|
|
|
|
In Fluid, all the operation are represented by Operator, and Operator
|
|
|
|
|
is regarded as a build in an instruction of a Block. Users can use the
|
|
|
|
@ -572,7 +574,6 @@ class Operator(core.OpBase):
|
|
|
|
|
inputs=None,
|
|
|
|
|
outputs=None,
|
|
|
|
|
attrs=None):
|
|
|
|
|
core.OpBase.__init__(self)
|
|
|
|
|
self.block = block
|
|
|
|
|
self.desc = desc
|
|
|
|
|
# note: not add self.attrs here:
|
|
|
|
@ -612,7 +613,6 @@ class Operator(core.OpBase):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
self.inputs = []
|
|
|
|
|
if inputs is not None:
|
|
|
|
|
for in_proto in proto.inputs:
|
|
|
|
|
found = find_name(inputs, in_proto.name)
|
|
|
|
@ -639,13 +639,6 @@ class Operator(core.OpBase):
|
|
|
|
|
else:
|
|
|
|
|
self.desc.set_input(in_proto.name, [])
|
|
|
|
|
|
|
|
|
|
for inp in inputs.values():
|
|
|
|
|
if isinstance(inp, Variable):
|
|
|
|
|
self.inputs.append(inp)
|
|
|
|
|
elif isinstance(inp, list) or isinstance(inp, tuple):
|
|
|
|
|
self.inputs.extend(inp[:])
|
|
|
|
|
|
|
|
|
|
self.outputs = []
|
|
|
|
|
if outputs is not None:
|
|
|
|
|
given = set()
|
|
|
|
|
need = set()
|
|
|
|
@ -674,12 +667,6 @@ class Operator(core.OpBase):
|
|
|
|
|
arg.op = self
|
|
|
|
|
self.desc.set_output(out_proto.name, out_arg_names)
|
|
|
|
|
|
|
|
|
|
for out in outputs.values():
|
|
|
|
|
if isinstance(out, Variable):
|
|
|
|
|
self.outputs.append(out)
|
|
|
|
|
elif isinstance(out, list) or isinstance(out, tuple):
|
|
|
|
|
self.outputs.extend(out[:])
|
|
|
|
|
|
|
|
|
|
if op_attrs is not None:
|
|
|
|
|
if not isinstance(op_attrs, dict):
|
|
|
|
|
raise TypeError("'attrs' should be a dict.")
|
|
|
|
@ -694,6 +681,23 @@ class Operator(core.OpBase):
|
|
|
|
|
if self._has_kernel(type):
|
|
|
|
|
self.desc.infer_var_type(self.block.desc)
|
|
|
|
|
self.desc.infer_shape(self.block.desc)
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
self.iop = core.OpBase()
|
|
|
|
|
self.iop.desc = self.desc
|
|
|
|
|
self.inputs = []
|
|
|
|
|
if inputs is not None:
|
|
|
|
|
for inp in inputs.values():
|
|
|
|
|
if isinstance(inp, Variable):
|
|
|
|
|
self.inputs.append(inp)
|
|
|
|
|
elif isinstance(inp, list) or isinstance(inp, tuple):
|
|
|
|
|
self.inputs.extend(inp[:])
|
|
|
|
|
self.outputs = []
|
|
|
|
|
if outputs is not None:
|
|
|
|
|
for out in outputs.values():
|
|
|
|
|
if isinstance(out, Variable):
|
|
|
|
|
self.outputs.append(out)
|
|
|
|
|
elif isinstance(out, list) or isinstance(out, tuple):
|
|
|
|
|
self.outputs.extend(out[:])
|
|
|
|
|
|
|
|
|
|
def _has_kernel(self, op_type):
|
|
|
|
|
return op_type not in self.OP_WITHOUT_KERNEL_SET
|
|
|
|
@ -1246,7 +1250,8 @@ class Block(object):
|
|
|
|
|
op_desc = self.desc.append_op()
|
|
|
|
|
op = Operator(block=self, desc=op_desc, *args, **kwargs)
|
|
|
|
|
if _in_imperative_mode():
|
|
|
|
|
_imperative_tracer().trace(op, op.inputs, op.outputs, self.desc)
|
|
|
|
|
_imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
|
|
|
|
|
[v._ivar for v in op.outputs], self.desc)
|
|
|
|
|
self.ops.append(op)
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|