truncated_gaussian_random supported in distributed training, test=develop (#17091)

feature/fluid_trt_int8
tangwei12 6 years ago committed by GitHub
parent 794a195881
commit 7330cd639c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1017,7 +1017,8 @@ class DistributeTranspiler(object):
new_inputs = self._get_input_map_from_op(pserver_vars, op)
if op.type in [
"gaussian_random", "fill_constant", "uniform_random"
"gaussian_random", "fill_constant", "uniform_random",
"truncated_gaussian_random"
]:
op._set_attr("shape", list(new_outputs["Out"].shape))
s_prog.global_block().append_op(

Loading…
Cancel
Save