@@ -18,7 +18,7 @@ import math
 
 import distributed_splitter as splitter
 import framework
-from framework import Program, default_main_program, Variable
+from framework import Program, default_main_program, Variable, Parameter
 from . import core
 
 LOOKUP_TABLE_TYPE = "lookup_table"
@@ -222,8 +222,14 @@ class DistributeTranspiler:
         # step1: For large parameters and gradients, split them into smaller
         # blocks.
-        param_list = [pg[0] for pg in params_grads]
-        grad_list = [pg[1] for pg in params_grads]
+        param_list = []
+        grad_list = []
+        for p, g in params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            param_list.append(p)
+            grad_list.append(g)
 
         if self.has_distributed_lookup_table:
             param_list = [
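The loop added in the second hunk replaces the two list comprehensions so that
parameters marked not trainable, together with their paired gradients, are
dropped before the splitting step. Below is a minimal standalone sketch of the
same filtering, assuming params_grads is a list of (parameter, gradient) pairs
and that Parameter instances carry a boolean trainable attribute as the diff
implies; the helper name filter_trainable is hypothetical, and
isinstance(p, Parameter) with "not p.trainable" is simply the more idiomatic
spelling of the check added above:

    from framework import Parameter

    def filter_trainable(params_grads):
        # Sketch only: keep trainable parameters and their gradients.
        param_list = []
        grad_list = []
        for p, g in params_grads:
            # Skip parameters explicitly marked not trainable.
            if isinstance(p, Parameter) and not p.trainable:
                continue
            param_list.append(p)
            grad_list.append(g)
        return param_list, grad_list

Filtering both lists in a single pass keeps param_list and grad_list
index-aligned, the same pairing the replaced comprehensions preserved by
projecting pg[0] and pg[1] from the same sequence.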