Merge pull request #10049 from typhoonzero/fix_not_trainable_transpiler

Skip updating non-trainable parameters in the distribute transpiler
Wu Yi committed 7 years ago via GitHub
commit 879b7c5601

@@ -18,7 +18,7 @@ import math
 import distributed_splitter as splitter
 import framework
-from framework import Program, default_main_program, Variable
+from framework import Program, default_main_program, Variable, Parameter
 from . import core
 LOOKUP_TABLE_TYPE = "lookup_table"
@@ -222,8 +222,14 @@ class DistributeTranspiler:
         # step1: For large parameters and gradients, split them into smaller
         # blocks.
-        param_list = [pg[0] for pg in params_grads]
-        grad_list = [pg[1] for pg in params_grads]
+        param_list = []
+        grad_list = []
+        for p, g in params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            param_list.append(p)
+            grad_list.append(g)
         if self.has_distributed_lookup_table:
             param_list = [
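
The new loop drops any (param, grad) pair whose parameter is marked trainable=False before the transpiler splits parameters and gradients into blocks for the parameter servers; a frozen parameter receives no updates, so distributing its gradient would be wasted work. Below is a minimal, self-contained sketch of the filtering idea; FakeParam and filter_trainable are hypothetical stand-ins for illustration, not PaddlePaddle APIs:

# Hypothetical sketch of the "skip non-trainable parameters" filter.
class FakeParam:
    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable

def filter_trainable(params_grads):
    """Keep only (param, grad) pairs whose parameter is trainable."""
    param_list, grad_list = [], []
    for p, g in params_grads:
        # A parameter frozen with trainable=False is never updated,
        # so its gradient is not split or sent to parameter servers.
        if not getattr(p, "trainable", True):
            continue
        param_list.append(p)
        grad_list.append(g)
    return param_list, grad_list

pairs = [(FakeParam("fc_w"), "fc_w@GRAD"),
         (FakeParam("embedding", trainable=False), "embedding@GRAD")]
params, grads = filter_trainable(pairs)
assert [p.name for p in params] == ["fc_w"]
assert grads == ["fc_w@GRAD"]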
