|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import os
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
import paddle.fluid.io as io
|
|
|
|
|
from paddle.fluid.communicator import Communicator
|
|
|
|
@ -53,7 +54,11 @@ class DistributedTranspiler(Fleet):
|
|
|
|
|
"""
|
|
|
|
|
if not self._transpile_config.sync_mode:
|
|
|
|
|
self._communicator = Communicator(self.main_program)
|
|
|
|
|
self._communicator.start()
|
|
|
|
|
|
|
|
|
|
if not self._communicator.is_running():
|
|
|
|
|
self._communicator.start()
|
|
|
|
|
else:
|
|
|
|
|
warnings.warn("communicator has been initialized, skip")
|
|
|
|
|
|
|
|
|
|
def init_server(self, model_dir=None):
|
|
|
|
|
"""
|
|
|
|
@ -104,7 +109,8 @@ class DistributedTranspiler(Fleet):
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
"""
|
|
|
|
|
if not self._transpile_config.sync_mode:
|
|
|
|
|
if not self._transpile_config.sync_mode and self._communicator.is_running(
|
|
|
|
|
):
|
|
|
|
|
self._communicator.stop()
|
|
|
|
|
self._executor.close()
|
|
|
|
|
|
|
|
|
|