fix params with only 1 dim (#15828)

* fix params with only 1 dim
* test=develop
revert-15774-anakin_subgraph_engine
tangwei12 6 years ago committed by GitHub
parent fbb5404652
commit 971f3bc9b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -766,7 +766,10 @@ def _load_distributed_persistables(executor, dirname, main_program=None):
dtype=slice_var.dtype,
persistable=True)
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
dim1_flatten = 1
if len(slice.shape) >= 2:
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
start = int(offset / dim1_flatten)
end = int(offset / dim1_flatten + slice.shape[0])

@ -1020,7 +1020,11 @@ class DistributeTranspiler(object):
skip_dim0 = 0
slice_vars = self.param_var_mapping[orig_var_name]
orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:])
orig_dim1_flatten = 1
if len(slice_vars[0].shape) >= 2:
orig_dim1_flatten = reduce(lambda x, y: x * y,
slice_vars[0].shape[1:])
for slice_var in slice_vars[:block_idx]:
skip_dim0 += slice_var.shape[0]

Loading…
Cancel
Save