
Automatic Differentiation with the Tape

Automatic Differentiation

A key challenge in deep learning is to automatically derive the backward pass from the forward pass that researchers describe algorithmically. Such a derivation, i.e. a transformation of the forward pass program, has long been studied in the field known as automatic differentiation, well before the recent rise of deep learning.

The Tape

Given the forward pass program (in practice, usually written in Python), there are two strategies to derive the backward pass:

  1. from the forward pass program itself, or
  2. from the execution trace of the forward pass program, which is often known as the tape.

This article surveys systems that follow the latter strategy.

Dynamic Network

When we train a deep learning model, the tape changes every iteration as the input data change, so the backward pass has to be re-derived every iteration. This is known as a dynamic network.
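To make this concrete, here is a minimal sketch. It assumes a hypothetical `Tape` class and a toy `forward` whose loop count depends on the input value. Two inputs produce tapes of different lengths, so a backward pass derived from one tape cannot be reused for the other.

```python
class Tape:
    """Hypothetical tape: records (op_name, inputs, output) in execution order."""

    def __init__(self):
        self.records = []

    def record(self, op, inputs, output):
        self.records.append((op, inputs, output))
        return output


def forward(x, tape):
    # Data-dependent control flow: keep doubling until x exceeds 10.
    while x <= 10:
        x = tape.record("mul2", (x,), x * 2)
    return tape.record("square", (x,), x * x)


for x in (1, 6):
    tape = Tape()
    y = forward(x, tape)
    print(x, y, [op for op, _, _ in tape.records])
# 1 256 ['mul2', 'mul2', 'mul2', 'mul2', 'square']
# 6 144 ['mul2', 'square']
```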

Deep learning systems built on the idea of dynamic networks have gained popularity in recent years. This article surveys two representative systems: PyTorch and DyNet.

An Overview

Both frameworks record a tape of the computation and interpret (or compile at run time) a transformation of the tape played back in reverse. This tape is a different kind of entity than the original program.[link]

Consider the following code for a feedforward model.

x = Variable(randn(20, 1))
label = Variable(randint(1))
W_1, W_2 = Variable(randn(20, 20)), Variable(randn(10, 20))
h = matmul(W_1, x)
pred = matmul(W_2, h)
# softmax here stands for a softmax loss of pred against label
loss = softmax(pred, label)
loss.backward()

1) DyNet uses a List to encode the Tape

During the forward execution, a list of operators (in this case matmul, matmul and softmax) is recorded in the tape, along with the information needed for the backward pass, such as pointers to the inputs and outputs. The tape is then played back in reverse order at loss.backward().

digraph g {
    graph [ rankdir = "LR" ];
    node [ fontsize = "16" shape = "ellipse" ];
    edge [];
    "node0" [ label = "<f0> type: matmul | input: W_1, x | output: h" shape = "record" ];
    "node1" [ label = "<f0> type: matmul | input: W_2, h | output: pred" shape = "record" ];
    "node2" [ label = "<f0> type: softmax | input: pred, label | output: loss" shape = "record" ];
    "node0":f0 -> "node1":f0 [];
    "node1":f0 -> "node2":f0 [];
}
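The list-based tape can be sketched in a few lines. This is a minimal illustration, not DyNet's implementation. It assumes scalars in place of matrices and a squared loss in place of softmax, and `ListTape` is a hypothetical name. Each forward op appends `(op, input_ids, output_id)` to a flat list, and backward replays that list in reverse.

```python
class ListTape:
    """Hypothetical DyNet-style tape: a flat list of (op, input_ids, output_id)."""

    def __init__(self):
        self.values = []  # value of every variable, indexed by its id
        self.ops = []     # operators in forward execution order

    def var(self, value):
        self.values.append(value)
        return len(self.values) - 1

    def mul(self, a, b):
        out = self.var(self.values[a] * self.values[b])
        self.ops.append(("mul", (a, b), out))
        return out

    def square(self, a):
        out = self.var(self.values[a] ** 2)
        self.ops.append(("square", (a,), out))
        return out

    def backward(self, loss):
        grads = [0.0] * len(self.values)
        grads[loss] = 1.0
        # No topological sort needed: replay the list in reverse.
        for op, ins, out in reversed(self.ops):
            g = grads[out]
            if op == "mul":
                a, b = ins
                grads[a] += g * self.values[b]
                grads[b] += g * self.values[a]
            elif op == "square":
                (a,) = ins
                grads[a] += g * 2 * self.values[a]
        return grads


# Scalar stand-in for the feedforward example: h = W_1 * x, pred = W_2 * h.
tape = ListTape()
x, W_1, W_2 = tape.var(3.0), tape.var(2.0), tape.var(5.0)
h = tape.mul(W_1, x)
pred = tape.mul(W_2, h)
loss = tape.square(pred)  # squared loss instead of softmax
grads = tape.backward(loss)
print(grads[W_1], grads[W_2], grads[x])  # 900.0 360.0 600.0
```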


2) PyTorch uses a Node Graph to encode the Tape

The graph is composed of Variables and Functions. During the forward execution, a Variable records its creator function, e.g. h.creator = matmul0. A Function in turn records the creators of its inputs as its previous (dependent) functions prev_func, e.g. matmul1.prev_func = matmul0. At loss.backward(), a topological sort is performed over the prev_func links, and the grad ops are then executed in the sorted order.

digraph g { graph [ rankdir = "LR" ];
subgraph function {
    node [
        fontsize = "16"
        style = filled
        shape = "record"
    ];
    "matmul0" [ label = "<f0> type: matmul | prev_func: None" ];
    "matmul1" [ label = "<f0> type: matmul | prev_func: matmul" ];
    "softmax" [ label = "<f0> type: softmax | prev_func: matmul" ];
}

subgraph variable {
    node [
        fontsize = "16"
        shape = "Mrecord"
        style = filled
        fillcolor = white
    ];
    "x" [ label = "<f0> x | <f1> creator: None" ];
    "label" [ label = "<f0> label | <f1> creator: None" ];
    "W_1" [ label = "<f0> W_1 | <f1> creator: None" ];
    "W_2" [ label = "<f0> W_2 | <f1> creator: None" ];
    "h" [ label = "<f0> h | <f1> creator: matmul" ];
    "pred" [ label = "<f0> pred | <f1> creator: matmul" ];
    "loss" [ label = "<f0> loss | <f1> creator: softmax" ];
}

subgraph data_flow {
    "x":f0 -> "matmul0":f0;
    "W_1":f0 -> "matmul0":f0;
    "matmul0":f0 -> "h":f0;

    "h":f0 -> "matmul1":f0;
    "W_2":f0 -> "matmul1":f0;
    "matmul1":f0 -> "pred":f0;

    "pred":f0 -> "softmax":f0;
    "label":f0 -> "softmax":f0;
    "softmax":f0 -> "loss":f0;
}

subgraph prev_func {
    edge [color="red", arrowsize="0.6", penwidth="1", constraint=false];
    "matmul1":f0 -> "matmul0":f0;
    "softmax":f0 -> "matmul1":f0;
    label = "prev_func";
}

}
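The same idea can be shown as a runnable sketch of the node-graph approach. This is not PyTorch's code. It assumes scalar values, and `Variable`, `Function`, `Mul`, `Square` and `backward` are hypothetical names. Each Variable keeps its `creator`, each Function keeps its `prev_funcs`, and `backward` topologically sorts the functions reachable from the loss before running the grad ops.

```python
class Variable:
    """A value plus the Function that created it (None for leaves)."""

    def __init__(self, value, creator=None):
        self.value = value
        self.creator = creator
        self.grad = 0.0


class Function:
    """Hypothetical base class: subclasses define forward/backward on floats."""

    def __call__(self, *inputs):
        self.inputs = inputs
        # prev_funcs: the creators of our inputs (leaves have none)
        self.prev_funcs = [v.creator for v in inputs if v.creator is not None]
        self.output = Variable(self.forward(*[v.value for v in inputs]), creator=self)
        return self.output


class Mul(Function):
    def forward(self, a, b):
        return a * b

    def backward(self, g):
        a, b = self.inputs
        return (g * b.value, g * a.value)


class Square(Function):
    def forward(self, a):
        return a * a

    def backward(self, g):
        (a,) = self.inputs
        return (g * 2 * a.value,)


def backward(loss):
    # Topologically sort the functions reachable from loss via prev_funcs.
    order, seen = [], set()

    def visit(f):
        if id(f) in seen:
            return
        seen.add(id(f))
        for p in f.prev_funcs:
            visit(p)
        order.append(f)  # post-order: each function after its prev_funcs

    visit(loss.creator)
    loss.grad = 1.0
    # Reversed post-order runs every consumer before the function it consumes.
    for f in reversed(order):
        for v, g in zip(f.inputs, f.backward(f.output.grad)):
            v.grad += g


x, W_1, W_2 = Variable(3.0), Variable(2.0), Variable(5.0)
h = Mul()(W_1, x)
pred = Mul()(W_2, h)
loss = Square()(pred)
backward(loss)
print(W_1.grad, W_2.grad, x.grad)  # 900.0 360.0 600.0
```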


Chainer and Autograd use similar techniques to record the forward pass. For details, please refer to the appendix.

Design choices

1) DyNet's List vs PyTorch's Node Graph

What's good about List:

  1. It avoids a topological sort. One only needs to traverse the list of operators in reverse and call the corresponding backward operator.
  2. It enables efficient data-parallel implementations. While constructing the list, one can count how many times each variable is used. During playback, one then knows exactly when the gradient of a variable has been fully computed, so its communication (e.g. gradient synchronization across devices) can overlap with the remaining backward computation.
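The use-counting idea in the second point can be sketched as follows. Everything here is hypothetical: the tape contents, `play_back`, and the framing of "ready" as the moment an all-reduce could start. The gradient math itself is omitted so that only the bookkeeping is shown.

```python
from collections import Counter

# Hypothetical tape of (op, inputs, output) in forward order.
# W_1 is used by two ops, so its gradient receives two contributions.
tape = [
    ("mul", ("W_1", "x"), "h"),
    ("mul", ("W_2", "h"), "p"),
    ("mul", ("W_1", "p"), "loss"),
]

# While the list is built, count how many ops consume each variable.
uses = Counter(name for _, inputs, _ in tape for name in inputs)


def play_back(tape, uses, params):
    """Return the order in which parameter gradients become final."""
    pending = dict(uses)
    ready = []
    for op, inputs, out in reversed(tape):
        # (the gradient math of each op is omitted in this sketch)
        for name in inputs:
            pending[name] -= 1
            if pending[name] == 0 and name in params:
                # Every contribution has arrived, so this gradient could be
                # sent to other workers now, while backward keeps running.
                ready.append(name)
    return ready


print(play_back(tape, uses, {"W_1", "W_2"}))  # ['W_2', 'W_1']
```

W_2's gradient is final after the second replayed op. Its communication can therefore start while W_1 is still waiting for its last contribution.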

What's good about Node Graph:

  1. More flexibility. PyTorch users can mix and match independent graphs however they like, in whatever threads they like (without explicit synchronization). An added benefit of structuring graphs this way is that when a portion of the graph becomes dead, it is automatically freed. [2] In the following example, PyTorch runs backward only through SmallNet, while DyNet runs it through both BigNet and SmallNet:
result = BigNet(data)
loss = SmallNet(data)
loss.backward()
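The difference can be modeled with a toy trace. This sketch only mirrors the behavior described above; it is not either framework's code. The op names and `prev_funcs` table are hypothetical. List playback covers every recorded op, while a graph walk from the loss's creator reaches only SmallNet's ops.

```python
# Hypothetical trace of:  result = BigNet(data); loss = SmallNet(data)
# Each recorded op maps to its prev_funcs (the ops that produced its inputs).
prev_funcs = {
    "big_fc1": [],
    "big_fc2": ["big_fc1"],
    "small_fc": [],
}
tape = ["big_fc1", "big_fc2", "small_fc"]  # recording order


def list_backward_ops(tape):
    # List-based: everything recorded on the tape is played back.
    return list(reversed(tape))


def graph_backward_ops(loss_creator, prev_funcs):
    # Graph-based: only ops reachable from the loss's creator are visited.
    visited, stack = [], [loss_creator]
    while stack:
        op = stack.pop()
        if op not in visited:
            visited.append(op)
            stack.extend(prev_funcs[op])
    return visited


print(list_backward_ops(tape))                     # ['small_fc', 'big_fc2', 'big_fc1']
print(graph_backward_ops("small_fc", prev_funcs))  # ['small_fc']
```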

2) DyNet's Lazy evaluation vs PyTorch's Immediate evaluation

DyNet builds the list in a symbolic manner. Consider the following example:

for epoch in range(num_epochs):
    for in_words, out_label in training_data:
        dy.renew_cg()
        W = dy.parameter(W_p)
        b = dy.parameter(b_p)
        # unnormalized scores; pickneglogsoftmax applies the (log-)softmax itself
        score_sym = W*dy.concatenate([E[in_words[0]],E[in_words[1]]])+b
        loss_sym = dy.pickneglogsoftmax(score_sym, out_label)
        loss_val = loss_sym.value()
        loss_sym.backward()

The computation of lookup, concat, matmul and softmax does not happen until loss_sym.value() is called. This deferred execution is useful because it makes graph-level optimizations, such as kernel fusion, possible.
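Deferred evaluation can be sketched with a tiny expression class. `Expr` and `const` are hypothetical names, not DyNet's API. Building `W * x + b` only records nodes, and arithmetic runs when `value()` is called. Because the whole expression exists before anything executes, a real engine could inspect it first (for example, to fuse the mul and add into one kernel). This sketch does not attempt that.

```python
class Expr:
    """Hypothetical lazy expression node: records an op, computes on value()."""

    evaluated = 0  # how many non-constant ops have actually been computed

    def __init__(self, op, *args):
        self.op, self.args = op, args
        self._cache = None

    def __add__(self, other):
        return Expr("add", self, other)

    def __mul__(self, other):
        return Expr("mul", self, other)

    def value(self):
        if self._cache is None:
            if self.op == "const":
                self._cache = self.args[0]
            else:
                a, b = (arg.value() for arg in self.args)
                self._cache = a + b if self.op == "add" else a * b
                Expr.evaluated += 1
        return self._cache


def const(v):
    return Expr("const", v)


W, x, b = const(2.0), const(3.0), const(1.0)
score = W * x + b      # only records the graph; nothing is computed yet
print(Expr.evaluated)  # 0
print(score.value())   # 7.0, computed on demand
print(Expr.evaluated)  # 2
```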

PyTorch chooses immediate evaluation. It never materializes a "forward graph"/"tape" (there is no need to explicitly call dy.renew_cg() to reset the list) and records only what is necessary to differentiate the computation, i.e. creator and prev_func.

What can Fluid learn from them?

Please refer to paddle/contrib/dynamic/.

Appendix

Overview

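| Framework | Has Tape | Core in C++ | First Release Date |
|-----------|----------|-------------|--------------------|
| Autograd  | No       | No          | Mar 5, 2015        |
| Chainer   | No       | No          | Jun 5, 2015        |
| Pytorch   | No       | Yes         | Aug 31, 2016       |
| Dynet     | Yes      | Yes         | Oct 12, 2016       |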

Source Code

Autograd

Backward code: in the forward pass, a graph of VJPNodes is constructed, and the backward pass walks it in topological order.

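# User API (simplified from autograd's make_vjp)
def make_grad(fun, x):
    start_node = VJPNode.new_root()
    end_value, end_node = trace(start_node, fun, x)
    # return a function that back-propagates an output gradient g
    def vjp(g):
        return backward_pass(g, end_node)
    return vjp, end_value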

# trace the forward pass by creating VJPNodes
def trace(start_node, fun, x):
    with trace_stack.new_trace() as t:
        start_box = new_box(x, t, start_node)
        end_box = fun(start_box)
        return end_box._value, end_box._node

def backward_pass(g, end_node):
    outgrads = {end_node : (g, False)}
    for node in toposort(end_node):
        outgrad = outgrads.pop(node)
        ingrads = node.vjp(outgrad[0])
        for parent, ingrad in zip(node.parents, ingrads):
            outgrads[parent] = add_outgrads(outgrads.get(parent), ingrad)
    return outgrad[0]

# Every VJPNode corresponds to a op_grad
class VJPNode(Node):
    __slots__ = ['parents', 'vjp']
    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        self.parents = parents
        vjpmaker = primitive_vjps[fun]
        self.vjp = vjpmaker(parent_argnums, value, args, kwargs)
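The excerpt depends on autograd internals, so here is a self-contained, scalar-only sketch of the same design. It is a simplified mirror, not autograd's actual code. `primitive_vjps`, `Box`, `Node`, `toposort`, `backward_pass` and `grad` are hypothetical stand-ins that only echo autograd's names. Each primitive registers one VJP closure per argument. Tracing wraps values in boxes that carry a node, and backward visits nodes so that each node comes before all of its parents.

```python
import math

# Each primitive has a forward function and a "VJP maker" that returns one
# closure per argument (ans is unused here, kept to mirror autograd's shape).
primitive_funs = {"mul": lambda a, b: a * b, "sin": math.sin}
primitive_vjps = {
    "mul": lambda ans, a, b: [lambda g: g * b, lambda g: g * a],
    "sin": lambda ans, a: [lambda g: g * math.cos(a)],
}


class Node:
    def __init__(self, parents=(), vjps=()):
        self.parents, self.vjps = list(parents), list(vjps)


class Box:
    """Wraps a traced value together with its Node."""

    def __init__(self, value, node):
        self.value, self.node = value, node


def apply(name, *boxes):
    args = [b.value for b in boxes]
    ans = primitive_funs[name](*args)
    node = Node([b.node for b in boxes], primitive_vjps[name](ans, *args))
    return Box(ans, node)


def mul(a, b):
    return apply("mul", a, b)


def sin(a):
    return apply("sin", a)


def toposort(end_node):
    # Count how many edges reach each node, then release a node only after
    # all of its children have been yielded.
    child_counts, stack = {}, [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(node.parents)
    ready = [end_node]
    while ready:
        node = ready.pop()
        yield node
        for parent in node.parents:
            if child_counts[parent] == 1:
                ready.append(parent)
            else:
                child_counts[parent] -= 1


def backward_pass(g, end_node):
    outgrads = {end_node: g}
    for node in toposort(end_node):
        outgrad = outgrads.pop(node)
        for parent, vjp in zip(node.parents, node.vjps):
            outgrads[parent] = outgrads.get(parent, 0.0) + vjp(outgrad)
    return outgrad  # the last node yielded is the input's root node


def grad(fun):
    def gradfun(x):
        end_box = fun(Box(x, Node()))
        return backward_pass(1.0, end_box.node)
    return gradfun


def f(x):
    return mul(sin(x), x)  # f(x) = x * sin(x), so f'(x) = sin(x) + x * cos(x)


print(grad(f)(1.0))  # sin(1) + cos(1)
```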

Chainer

Example Code

# (1) Function Set definition, creates FunctionNode
model = FunctionSet(
    l1=F.Linear(784, 100),
    l2=F.Linear(100, 100),
    l3=F.Linear(100, 10)).to_gpu()

# (2) Optimizer Setup
opt = optimizers.SGD()
opt.setup(model)

# (3) Forward computation
def forward(x, t):
    h1 = F.relu(model.l1(x))
    h2 = F.relu(model.l2(h1))
    y = model.l3(h2)
    return F.softmax_cross_entropy(y, t)

# (4) Training loop
for epoch in range(n_epoch):
    for i in range(0, N, b_size):
        x = Variable(to_gpu(...))
        t = Variable(to_gpu(...))
        opt.zero_grads()
        loss = forward(x, t)
        loss.backward()
        opt.update()

In forward(x, t), a graph of VariableNode and FunctionNode is constructed. Every output's VariableNode.creator is pointed to the FunctionNode.

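from chainer import variable

class FunctionNode(object):
    ...
    def apply(self, inputs):
        input_vars = inputs
        outputs = self.forward(tuple(x.data for x in input_vars))
        # outputs need gradients if any input does
        requires_grad = any(x.requires_grad for x in input_vars)
        ret = tuple([variable.Variable(y, requires_grad=requires_grad)
                     for y in outputs])
        # Topological ordering: this node's rank is the max rank of its inputs
        self.rank = max([x.rank for x in input_vars]) if input_vars else 0
        # Add backward edges: every output points back to its creator
        for y in ret:
            y.creator_node = self
        self.inputs = tuple([x.node for x in input_vars])
        self.outputs = tuple([y.node for y in ret])

        return ret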

loss.backward() computes the accumulated gradients of all variables. The backward methods of the FunctionNodes are called in topological order, which is derived from each FunctionNode's rank.

import heapq
import numpy

class VariableNode(object):
    ...
    def backward(self, retain_grad, loss_scale):
        if self.creator_node is None:
            return

        cand_funcs = []
        seen_set = set()
        grads = {}

        # Initialize error by 1, if this is a loss variable
        if self.data.size == 1 and self.grad is None:
            self.grad = numpy.ones_like(self.data)
        grads[self] = self.grad

        def add_cand(cand):
            if cand not in seen_set:
                # Negate rank since heapq is a min-heap; len(seen_set) breaks ties
                heapq.heappush(cand_funcs, (-cand.rank, len(seen_set), cand))
                seen_set.add(cand)

        add_cand(self.creator_node)

        while cand_funcs:
            # the function with the highest rank (closest to the loss) comes first
            _, _, func = heapq.heappop(cand_funcs)
            gys = tuple(grads.get(y) for y in func.outputs)
            gxs = func.backward_accumulate(func.inputs, func.outputs, gys)

            for x, gx in zip(func.inputs, gxs):
                # accumulate gradients of inputs used by several functions
                if x in grads:
                    grads[x] += gx
                else:
                    grads[x] = gx

                if x.creator_node is not None:
                    add_cand(x.creator_node)
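The rank-ordered heap is what lets Chainer avoid an explicit topological sort. The sketch below isolates that ordering. `Func` and `backward_order` are hypothetical names, and the rank rule follows the excerpt's idea rather than Chainer's exact code. In the example graph, a plain first-in-first-out traversal would visit `a` right after `e`, before `b` and `c` have contributed to `a`'s output gradient. Popping the highest rank first puts `a` last.

```python
import heapq


class Func:
    """Hypothetical function node; prev holds the functions that produced its inputs."""

    def __init__(self, name, prev):
        self.name, self.prev = name, prev
        # A function ranks above every function whose output it consumes.
        self.rank = max((p.rank + 1 for p in prev), default=0)


def backward_order(loss_creator):
    cand_funcs, seen, order = [], set(), []

    def add_cand(f):
        if f not in seen:
            # heapq is a min-heap, so push -rank to pop the highest rank first;
            # len(seen) is a unique counter, so Func objects are never compared.
            heapq.heappush(cand_funcs, (-f.rank, len(seen), f))
            seen.add(f)

    add_cand(loss_creator)
    while cand_funcs:
        _, _, f = heapq.heappop(cand_funcs)
        order.append(f.name)  # f's backward would run here
        for p in f.prev:
            add_cand(p)
    return order


a = Func("a", [])
b, c = Func("b", [a]), Func("c", [a])
d = Func("d", [b, c])
e = Func("e", [a, d])     # a feeds both d's branches and e directly
print(backward_order(e))  # ['e', 'd', 'b', 'c', 'a']
```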

PyTorch

Example Code

x = Variable(torch.ones(5, 5))
y = Variable(torch.ones(5, 5) * 4)
z = x ** 2 + x * 2 + x * y + y
z.backward(torch.ones(5, 5))
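As a sanity check, the gradients this example should produce can be worked out by hand. This assumes both x and y are created with requires_grad=True, as in the simplified Variable class below. Every entry of x is 1 and every entry of y is 4, and all operations are elementwise:

$$
\begin{aligned}
z &= x^2 + 2x + xy + y,\\
\frac{\partial z}{\partial x} &= 2x + 2 + y = 2 + 2 + 4 = 8,\\
\frac{\partial z}{\partial y} &= x + 1 = 2.
\end{aligned}
$$

The upstream gradient torch.ones(5, 5) is all ones, so x.grad should be a 5×5 tensor filled with 8 and y.grad one filled with 2.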

The trace is done by Variable.creator and Function.previous_functions.

class Variable(object):
    def __init__(self, tensor, creator=None, requires_grad=True):
        if creator is None:
            creator = Leaf(self, requires_grad)
        self.data = tensor
        self.creator = creator
        self._grad = None

    def backward(self, gradient=None):
        if gradient is None:
            if self.data.numel() != 1:
                raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
            gradient = self.data.new(1).fill_(1)
        self._execution_engine.run_backward(self, gradient)

class Function(object):
    # ...
    def _do_forward(self, *input):
        unpacked_input = tuple(arg.data for arg in input)
        raw_output = self.forward(*unpacked_input)

        # mark output.creator = self for backward trace
        output = tuple(Variable(tensor, self) for tensor in raw_output)

        self.previous_functions = [(arg.creator, id(arg)) for arg in input]
        self.output_ids = {id(var): i for i, var in enumerate(output)}
        return output

    def _do_backward(self, grad_output):
        return self.backward(grad_output)

The backward pass is similar to Autograd's.

DyNet

Example code

model = dy.ParameterCollection()
W_p = model.add_parameters((20, 100))
b_p = model.add_parameters(20)
E = model.add_lookup_parameters((20000, 50))
for epoch in range(num_epochs):
    for in_words, out_label in training_data:
        dy.renew_cg() # init tape
        W = dy.parameter(W_p)
        b = dy.parameter(b_p)
        # unnormalized scores; pickneglogsoftmax applies the (log-)softmax itself
        score_sym = W*dy.concatenate([E[in_words[0]],E[in_words[1]]])+b
        loss_sym = dy.pickneglogsoftmax(score_sym, out_label)
        loss_val = loss_sym.value()
        loss_sym.backward()

Source: forward, backward. The trace is recorded by creating a tape of expressions in every iteration, and the backward pass traverses the tape in reverse order.

void SimpleExecutionEngine::backward(VariableIndex from_where, bool full) {
  ...  
  for (int i = num_nodes - 1; i >= 0; --i) {
    // each node corresponds to an op
    node->backward(xs, node_fx, node_dEdfx, ai, node_dEdxai);
  }
  ...
}