|
|
|
@ -13,7 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import os
|
|
|
|
|
from ..layers import collective
|
|
|
|
|
|
|
|
|
|
from ..framework import Parameter
|
|
|
|
|
__parallel_ctx__clz__ = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -39,5 +39,5 @@ def _init_parallel_ctx():
|
|
|
|
|
|
|
|
|
|
def _broadcast_parameters(parameters):
|
|
|
|
|
for param in parameters:
|
|
|
|
|
if param.trainable:
|
|
|
|
|
if isinstance(param, Parameter) and param.trainable:
|
|
|
|
|
collective._broadcast(param, 0, sync_mode=True)
|
|
|
|
|