Feature/expand params in auto-generated pybind functions for dygraph operators (#23181)

* expand parameters, test=develop

* support resnet, test=develop

* fix resnet, test=develop

* support duplicable out, test=develop

* support ptb

* fix bugs, test=develop

* support null input, test=develop

* fix bugs, test=develop

* fix batchNorm is_test, test=develop

* refine code, test=develop

* follow comments, test=develop

* follow comments, test=develop

* follow comments, test=develop

* follow comments, test=develop
Leo Chen authored 5 years ago, committed by GitHub
parent 9474d140de
commit 488b2387e2

@@ -18,13 +18,45 @@
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/imperative.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs,
const py::args& args) {
PADDLE_ENFORCE_EQ(
args.size() % 2, 0,
platform::errors::InvalidArgument(
"The number of arguments for arributes should be even."));
for (size_t i = 0; i < args.size(); i += 2) {
auto name = args[i].cast<std::string>();
auto value = args[i + 1].cast<framework::Attribute>();
(*attrs)[name] = value;
}
}
static inline std::vector<std::shared_ptr<imperative::VarBase>>
ConstructDuplicableOutput(const size_t num) {
auto tracer = imperative::GetCurrentTracer();
std::vector<std::shared_ptr<imperative::VarBase>> res;
res.reserve(num);
for (size_t i = 0; i < num; i++) {
auto var_base_name = tracer->GenerateUniqueName();
res.emplace_back(new imperative::VarBase(var_base_name));
}
return res;
}
} // namespace pybind
} // namespace paddle
// This include must be the last line
#include "paddle/fluid/pybind/op_function_impl.h"

File diff suppressed because it is too large.

@@ -15,7 +15,7 @@
from __future__ import print_function
from .. import core
from ..framework import Variable, convert_np_dtype_to_dtype_
from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator
from ..layers.layer_function_generator import OpProtoHolder
from . import to_variable, no_grad
@@ -42,17 +42,11 @@ def monkey_patch_math_varbase():
@no_grad
def create_tensor(value, dtype, shape):
value = float(value)
inputs = {}
attrs = {
'dtype': dtype,
'shape': shape,
'value': value,
'force_cpu': False
}
outs = core.ops.fill_constant(inputs, attrs)
outs['Out'][0].stop_gradient = True
return outs['Out'][0]
out = _varbase_creator(dtype=dtype)
out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
'value', value, 'force_cpu', False)
out.stop_gradient = True
return out
def create_scalar(value, dtype):
return create_tensor(value, dtype, shape=[1])
@@ -102,19 +96,11 @@ def monkey_patch_math_varbase():
print("new var's dtype is: {}, numpy dtype is {}".format(new_variable.dtype, new_variable.numpy().dtype))
"""
inputs = {'X': [self]}
attrs = {
"in_dtype": self.dtype,
"out_dtype": convert_np_dtype_to_dtype_(dtype)
}
outs = core.ops.cast(inputs, attrs)
return outs['Out'][0]
return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype',
convert_np_dtype_to_dtype_(dtype))
def _scalar_elementwise_op_(var, scale, bias):
inputs = {'X': [var]}
attrs = {"scale": scale, "bias": bias}
outs = core.ops.scale(inputs, attrs)
return outs['Out'][0]
return core.ops.scale(var, 'scale', scale, 'bias', bias)
def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0)
@@ -208,11 +194,8 @@ def monkey_patch_math_varbase():
other_var = tmp
axis = -1
op = getattr(core.ops, op_type)
inputs = {'X': [self], 'Y': [other_var]}
attrs = {'axis': axis}
outs = op(inputs, attrs)
return outs['Out'][0]
math_op = getattr(core.ops, op_type)
return math_op(self, other_var, 'axis', axis)
comment = OpProtoHolder.instance().get_op_proto(op_type).comment
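
A hedged illustration of the patched math call sites above (not part of the diff; the tensors are made up): binary ops now take both operands positionally plus flattened attrs, and cast hands back the converted VarBase instead of an output dict.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.arange(6, dtype='float32').reshape(2, 3))
    y = fluid.dygraph.to_variable(np.ones([2, 3], dtype='float32'))
    # Elementwise op: both inputs positional, then flattened attrs.
    z = core.ops.elementwise_add(x, y, 'axis', -1)
    # Cast: the converted VarBase is the direct return value.
    d = core.ops.cast(x, 'in_dtype', x.dtype,
                      'out_dtype', core.VarDesc.VarType.FP64)
    print(z.numpy().sum(), d.dtype)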

File diff suppressed because it is too large.

@@ -33,15 +33,15 @@ def _append_activation_in_dygraph(input,
"""
if not act:
return input
attrs = {}
if (use_cudnn is not None) and use_cudnn:
attrs['use_cudnn'] = use_cudnn
if (use_mkldnn is not None) and use_mkldnn:
attrs['use_mkldnn'] = use_mkldnn
inputs = {"X": [input]}
attrs = ()
if use_cudnn:
attrs = ('use_cudnn', use_cudnn)
if use_mkldnn:
attrs += ('use_mkldnn', use_mkldnn)
act_op = getattr(core.ops, act)
res = act_op(inputs, attrs)
return res['Out'][0]
return act_op(input, *attrs)
@dygraph_only
@@ -58,7 +58,4 @@ def _append_bias_in_dygraph(input, bias=None, axis=1):
if not bias:
return input
attrs = {'axis': axis}
inputs = {'X': [input], 'Y': [bias]}
outs = core.ops.elementwise_add(inputs, attrs)
return outs['Out'][0]
return core.ops.elementwise_add(input, bias, 'axis', axis)

@@ -253,10 +253,8 @@ def generate_activation_fn(op_type):
def func(x, name=None):
if in_dygraph_mode():
inputs = {'X': [x]}
op = getattr(core.ops, op_type)
outs = op(inputs)
return outs['Out'][0]
return op(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
op_type)

@@ -238,13 +238,13 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
if not soft_label:
return cross_entropy2(input, label, ignore_index)
if in_dygraph_mode():
return core.ops.cross_entropy(input, label, "soft_label", soft_label,
"ignore_index", ignore_index)
inputs = {'X': [input], 'Label': [label]}
attrs = {"soft_label": soft_label, "ignore_index": ignore_index}
if in_dygraph_mode():
outs = core.ops.cross_entropy(inputs, attrs)
return outs['Y'][0]
check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
'cross_entropy')
helper = LayerHelper('cross_entropy', **locals())
@@ -255,13 +255,13 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
inputs = {'X': [input], 'Label': [label]}
attrs = {'ignore_index': ignore_index}
if in_dygraph_mode():
outs = core.ops.cross_entropy2(inputs, attrs)
return outs['Y'][0]
loss, _, _ = core.ops.cross_entropy2(input, label, 'ignore_index',
ignore_index)
return loss
inputs = {'X': [input], 'Label': [label]}
attrs = {'ignore_index': ignore_index}
check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
'cross_entropy2')
helper = LayerHelper('cross_entropy2', **locals())
@@ -1233,21 +1233,22 @@ def softmax_with_cross_entropy(logits,
out = fluid.layers.softmax_with_cross_entropy(
logits=fc, label=label)
"""
if in_dygraph_mode():
softmax, loss = core.ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode, 'axis',
axis)
if not return_softmax:
return loss
else:
return loss, softmax
attrs = {
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode,
'axis': axis
}
if in_dygraph_mode():
inputs = {'Logits': [logits], 'Label': [label]}
outs = core.ops.softmax_with_cross_entropy(inputs, attrs)
if not return_softmax:
return outs['Loss'][0]
else:
return outs['Loss'][0], outs['Softmax'][0]
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
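
To make the multi-output convention above concrete, a small sketch (not part of the diff; the data is random and only illustrative): output slots come back as a tuple in declaration order, here (Softmax, Loss), so callers unpack rather than index a dict.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

with fluid.dygraph.guard():
    logits = fluid.dygraph.to_variable(
        np.random.rand(4, 10).astype('float32'))
    label = fluid.dygraph.to_variable(
        np.random.randint(0, 10, size=(4, 1)).astype('int64'))
    # Same attribute order as the dygraph branch above.
    softmax, loss = core.ops.softmax_with_cross_entropy(
        logits, label, 'soft_label', False, 'ignore_index', -100,
        'numeric_stable_mode', True, 'axis', -1)
    print(loss.numpy().shape)  # (4, 1)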

@@ -74,24 +74,15 @@ def accuracy(input, label, k=1, correct=None, total=None):
#[array([0.6666667], dtype=float32)]
"""
if in_dygraph_mode():
topk_out, topk_indices = nn.topk(input, k=k)
inputs = {
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
}
acc_out = _varbase_creator(dtype="float32")
if correct is None:
correct = _varbase_creator(dtype="int64")
correct = _varbase_creator(dtype="int32")
if total is None:
total = _varbase_creator(dtype="int64")
outputs = {
"Accuracy": [acc_out],
"Correct": [correct],
"Total": [total]
}
outs = core.ops.accuracy(inputs, {}, outputs)
return outs['Accuracy'][0]
total = _varbase_creator(dtype="int32")
topk_out, topk_indices = nn.topk(input, k=k)
_acc, _, _ = core.ops.accuracy(topk_out, topk_indices, label, correct,
total)
return _acc
helper = LayerHelper("accuracy", **locals())
check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
@@ -99,9 +90,9 @@ def accuracy(input, label, k=1, correct=None, total=None):
topk_out, topk_indices = nn.topk(input, k=k)
acc_out = helper.create_variable_for_type_inference(dtype="float32")
if correct is None:
correct = helper.create_variable_for_type_inference(dtype="int64")
correct = helper.create_variable_for_type_inference(dtype="int32")
if total is None:
total = helper.create_variable_for_type_inference(dtype="int64")
total = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(
type="accuracy",
inputs={

File diff suppressed because it is too large.

@@ -255,15 +255,12 @@ def concat(input, axis=0, name=None):
"""
if in_dygraph_mode():
inputs = {'X': input}
if isinstance(axis, Variable):
axis = axis.numpy()
assert axis.shape == (
1, ), "axis of type Variable should have shape [1]"
axis = axis[0]
attrs = {'axis': axis}
outs = core.ops.concat(inputs, attrs)
return outs['Out'][0]
return core.ops.concat(input, 'axis', axis)
if not isinstance(input, list):
warnings.warn(
@@ -586,12 +583,13 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
shape))
else:
shape = list(shape.numpy().astype(int))
attrs['shape'] = shape
dtype = convert_np_dtype_to_dtype_(dtype)
if out is None:
out = _varbase_creator(dtype=dtype)
attrs['dtype'] = out.dtype
outputs = {'Out': [out]}
outs = core.ops.fill_constant({}, attrs, outputs)
core.ops.fill_constant(out, 'value',
float(value), 'force_cpu', force_cpu, 'dtype',
dtype, 'str_value', attrs['str_value'], 'shape',
shape)
out.stop_gradient = True
return out
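
A brief sketch of the fill_constant pattern above (not part of the diff; the dtype and shape values are arbitrary): the pre-created VarBase is passed in as the output slot, with attributes following as flattened pairs. _varbase_creator is the framework-internal helper the diff itself uses.

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import _varbase_creator

with fluid.dygraph.guard():
    out = _varbase_creator(dtype=core.VarDesc.VarType.FP32)
    # The output VarBase is handed in positionally; the op fills it
    # according to the flattened attribute pairs.
    core.ops.fill_constant(out, 'dtype', core.VarDesc.VarType.FP32,
                           'shape', [2, 3], 'value', 1.0, 'force_cpu', False)
    out.stop_gradient = True
    print(out.numpy())  # a 2x3 tensor of ones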

@@ -889,16 +889,11 @@ class SGDOptimizer(Optimizer):
@no_grad
def _append_optimize_op(self, block, param_and_grad):
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
inputs = {
"Param": [param_and_grad[0]],
"Grad": [param_and_grad[1]],
"LearningRate": [self._create_param_lr(param_and_grad)]
}
attrs = {}
outputs = {'ParamOut': [param_and_grad[0]]}
outs = core.ops.sgd(inputs, attrs, outputs)
return outs['ParamOut'][0]
core.ops.sgd(param_and_grad[0], lr, param_and_grad[1],
param_and_grad[0])
return None
assert isinstance(block, framework.Block)
# create the optimize op
@@ -907,7 +902,7 @@ class SGDOptimizer(Optimizer):
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": self._create_param_lr(param_and_grad)
"LearningRate": lr
},
outputs={"ParamOut": param_and_grad[0]},
stop_gradient=True)
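
A rough sketch of the dygraph SGD call above (not part of the diff; it uses ad-hoc to_variable tensors rather than real Parameters, purely to show the argument order): output slots are passed positionally after the inputs, so handing param in again as ParamOut applies the update in place.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

with fluid.dygraph.guard():
    param = fluid.dygraph.to_variable(np.ones([4], dtype='float32'))
    grad = fluid.dygraph.to_variable(np.full([4], 0.5, dtype='float32'))
    lr = fluid.dygraph.to_variable(np.array([0.1], dtype='float32'))
    # Argument order mirrors the call above: Param, LearningRate, Grad, ParamOut.
    core.ops.sgd(param, lr, grad, param)
    print(param.numpy())  # each element: 1.0 - 0.1 * 0.5 = 0.95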
@@ -1009,24 +1004,27 @@ class MomentumOptimizer(Optimizer):
velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0])
attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
_, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1],
velocity_acc, lr, param_and_grad[0],
velocity_acc, 'mu', self._momentum,
'use_nesterov', self._use_nesterov)
return None
attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
inputs = {
"Param": [param_and_grad[0]],
"Grad": [param_and_grad[1]],
"Velocity": [velocity_acc],
"LearningRate": [self._create_param_lr(param_and_grad)]
"LearningRate": [lr]
}
outputs = {
"ParamOut": [param_and_grad[0]],
"VelocityOut": [velocity_acc]
}
if framework.in_dygraph_mode():
core.ops.momentum(inputs, attrs, outputs)
return None
# create the momentum optimize op
momentum_op = block.append_op(
type=self.type,
@@ -1849,12 +1847,27 @@ class AdamOptimizer(Optimizer):
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
lr = self._create_param_lr(param_and_grad)
# create the adam optimize op
if framework.in_dygraph_mode():
_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
_beta2 = self._beta2 if not isinstance(
self._beta2, Variable) else self._beta2.numpy().item(0)
_, _, _, _, _ = core.ops.adam(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
1000, 'beta1', _beta1, 'beta2', _beta2)
return None
inputs = {
"Param": [param_and_grad[0]],
"Grad": [param_and_grad[1]],
"LearningRate": [self._create_param_lr(param_and_grad)],
"LearningRate": [lr],
"Moment1": [moment1],
"Moment2": [moment2],
"Beta1Pow": [beta1_pow_acc],
@@ -1882,10 +1895,6 @@ class AdamOptimizer(Optimizer):
else:
attrs['beta2'] = self._beta2
if framework.in_dygraph_mode():
core.ops.adam(inputs, attrs, outputs)
return None
adam_op = block.append_op(
type=self.type,
inputs=inputs,

@@ -54,7 +54,7 @@ def _create_regularization_of_grad(param, grad, regularization=None):
inputs = {"X": [grad, regularization_term]}
outputs = {"Out": [new_grad]}
if in_dygraph_mode():
core.ops.sum(inputs, {}, outputs)
new_grad = core.ops.sum([grad, regularization_term])
else:
grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
@@ -183,8 +183,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
attrs = {"scale": self._regularization_coeff}
if framework.in_dygraph_mode():
outs = core.ops.scale(inputs, attrs)
return outs['Out'][0]
return core.ops.scale(param, "scale", self._regularization_coeff)
else:
decay = block.create_var(
dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

@@ -112,9 +112,9 @@ class InstanceNorm(fluid.dygraph.Layer):
def forward(self, input):
if fluid.in_dygraph_mode():
inputs = {'X': [input], 'Scale': [self.scale], 'Bias': [self.bias]}
attrs = {'epsilon': self.epsilon}
return fluid.core.ops.instance_norm(inputs, attrs)['Y'][0]
out, _, _ = fluid.core.ops.instance_norm(
input, self.scale, self.bias, 'epsilon', self.epsilon)
return out
else:
return fluid.layers.instance_norm(
input,

@@ -28,8 +28,7 @@ class TestTracedLayer(fluid.dygraph.Layer):
super(TestTracedLayer, self).__init__(name_scope)
def forward(self, input):
inputs = {'X': [input] if isinstance(input, fluid.Variable) else input}
return core.ops.relu(inputs)['Out'][0]
return core.ops.relu(input)
class TestVariable(unittest.TestCase):
@@ -47,9 +46,7 @@ class TestVariable(unittest.TestCase):
x.stop_gradient = False
res1 = layers.elementwise_add(x, y)
inputs = {'X': [x], 'Y': [y]}
res2 = core.ops.elementwise_add(inputs)['Out'][0]
res2 = core.ops.elementwise_add(x, y)
self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
@@ -61,9 +58,7 @@ class TestVariable(unittest.TestCase):
y = fluid.dygraph.to_variable(b)
res1 = layers.elementwise_mul(x, y)
inputs = {'X': [x], 'Y': [y]}
res2 = core.ops.elementwise_mul(inputs)['Out'][0]
res2 = core.ops.elementwise_mul(x, y)
self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
@@ -73,9 +68,7 @@ class TestVariable(unittest.TestCase):
x = fluid.dygraph.to_variable(a)
res1 = layers.relu(x)
inputs = {'X': [x]}
res2 = core.ops.relu(inputs)['Out'][0]
res2 = core.ops.relu(x)
self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
@@ -88,8 +81,7 @@ class TestVariable(unittest.TestCase):
x.stop_gradient = False
y.stop_gradient = False
inputs = {'X': [x], 'Y': [y]}
loss = core.ops.elementwise_mul(inputs)['Out'][0]
loss = core.ops.elementwise_mul(x, y)
loss.backward()
x_grad = x.gradient()
@@ -104,7 +96,7 @@ class TestVariable(unittest.TestCase):
a = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
x = fluid.dygraph.to_variable(a)
res_dygraph, static_layer = TracedLayer.trace(
layer, inputs=[x]) # dygraph out
layer, inputs=x) # dygraph out
res_static_graph = static_layer([x])[0]
self.assertTrue(
