remove scale loss and coll grads, test=document_fix (#27874)

Branch: my_2.0rc
Author: Chen Weihang, committed via GitHub 5 years ago
Parent: 6898746f1d
Commit: ed31dac6eb

@@ -630,9 +630,7 @@ class Fleet(object):
 
                 print("loss:", loss.numpy())
 
-                loss = dp_layer.scale_loss(loss)
                 loss.backward()
-                dp_layer.apply_collective_grads()
 
                 adam.step()
                 adam.clear_grad()
@@ -842,9 +840,7 @@ class Fleet(object):
 
                 print("loss:", loss.numpy())
 
-                loss = dp_layer.scale_loss(loss)
                 loss.backward()
-                dp_layer.apply_collective_grads()
 
                 adam.step()
                 adam.clear_grad()
@@ -903,9 +899,7 @@ class Fleet(object):
 
                 print("loss:", loss.numpy())
 
-                loss = dp_layer.scale_loss(loss)
                 loss.backward()
-                dp_layer.apply_collective_grads()
 
                 adam.step()
                 adam.clear_grad()
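The three Fleet hunks above make the same edit to the dygraph example in the Fleet docstrings: with paddle 2.0rc, gradient scaling and the collective allreduce happen inside loss.backward(), so the two removed calls have no replacement. For context, here is a minimal sketch of what the trimmed example looks like end to end. It assumes the 2.0rc collective fleet API; the Linear(10, 1) stand-in model and the tensor shapes are illustrative, not taken from this diff.

# Minimal sketch of the post-commit Fleet dygraph loop (paddle 2.0rc APIs assumed).
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
from paddle.distributed import fleet

fleet.init(is_collective=True)

layer = nn.Linear(10, 1)  # stand-in model for illustration
adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

# fleet wraps both the optimizer and the model for collective training
adam = fleet.distributed_optimizer(adam)
dp_layer = fleet.distributed_model(layer)

inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = nn.MSELoss()(outputs, labels)
print("loss:", loss.numpy())

loss.backward()   # gradient scaling + allreduce now happen here
adam.step()
adam.clear_grad()

A script like this is meant to be launched with one process per device, e.g. via python -m paddle.distributed.launch.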

@@ -92,9 +92,7 @@ def init_parallel_env():
             labels = paddle.randn([10, 1], 'float32')
             loss = loss_fn(outputs, labels)
 
-            loss = dp_layer.scale_loss(loss)
             loss.backward()
-            dp_layer.apply_collective_grads()
 
             adam.step()
             adam.clear_grad()

@@ -314,9 +314,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
             if print_result is True:
                 print("loss:", loss.numpy())
 
-            loss = dp_layer.scale_loss(loss)
             loss.backward()
-            dp_layer.apply_collective_grads()
 
             adam.step()
             adam.clear_grad()

@@ -397,9 +397,7 @@ class DataParallel(layers.Layer):
                 labels = paddle.randn([10, 1], 'float32')
                 loss = loss_fn(outputs, labels)
 
-                loss = dp_layer.scale_loss(loss)
                 loss.backward()
-                dp_layer.apply_collective_grads()
 
                 adam.step()
                 adam.clear_grad()
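The init_parallel_env, spawn, and DataParallel hunks trim the same two lines from the paddle.distributed examples. Below is a minimal self-contained sketch of the resulting example, again assuming the 2.0rc API surface; the LinearNet and train names echo the surrounding docstrings but are stand-ins here, and the nprocs value is illustrative.

# Sketch of the post-commit data-parallel example (paddle 2.0rc APIs assumed).
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear1 = nn.Linear(10, 10)
        self._linear2 = nn.Linear(10, 1)

    def forward(self, x):
        return self._linear2(self._linear1(x))

def train(print_result=False):
    dist.init_parallel_env()  # initialize the parallel environment

    layer = LinearNet()
    dp_layer = paddle.DataParallel(layer)

    loss_fn = nn.MSELoss()
    adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())

    inputs = paddle.randn([10, 10], 'float32')
    outputs = dp_layer(inputs)
    labels = paddle.randn([10, 1], 'float32')
    loss = loss_fn(outputs, labels)

    if print_result is True:
        print("loss:", loss.numpy())

    loss.backward()  # no scale_loss/apply_collective_grads needed anymore

    adam.step()
    adam.clear_grad()

if __name__ == '__main__':
    # spawn one worker process per device; nprocs=2 is an illustrative choice
    dist.spawn(train, args=(True,), nprocs=2)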
