fix_not_trainable_transpiler

wangkuiyi-patch-2
typhoonzero 7 years ago
parent 82b192a3fd
commit e6745be9ea

@ -18,7 +18,7 @@ import math
import distributed_splitter as splitter import distributed_splitter as splitter
import framework import framework
from framework import Program, default_main_program, Variable from framework import Program, default_main_program, Variable, Parameter
from . import core from . import core
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
@ -222,8 +222,14 @@ class DistributeTranspiler:
# step1: For large parameters and gradients, split them into smaller # step1: For large parameters and gradients, split them into smaller
# blocks. # blocks.
param_list = [pg[0] for pg in params_grads] param_list = []
grad_list = [pg[1] for pg in params_grads] grad_list = []
for p, g in params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
param_list = [ param_list = [

Loading…
Cancel
Save