|
|
|
@ -4780,7 +4780,7 @@ class RecomputeOptimizer(Optimizer):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def _insert_async_memcpy_op(self, insert_idx, src_varname, dst_varname,
|
|
|
|
|
op_role, kind):
|
|
|
|
|
op_role, dst_place_type):
|
|
|
|
|
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
|
|
|
|
|
self.block._insert_op_without_sync(
|
|
|
|
|
insert_idx,
|
|
|
|
@ -4789,8 +4789,10 @@ class RecomputeOptimizer(Optimizer):
|
|
|
|
|
outputs={
|
|
|
|
|
'Out': [self._main_program.global_block().var(dst_varname)]
|
|
|
|
|
},
|
|
|
|
|
attrs={"dst_place_type": int(kind),
|
|
|
|
|
OP_ROLE_KEY: op_role})
|
|
|
|
|
attrs={
|
|
|
|
|
"dst_place_type": int(dst_place_type),
|
|
|
|
|
OP_ROLE_KEY: op_role
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
def _insert_fetch_op(self, idx, varname):
|
|
|
|
|
assert varname in self.checkpoint_name2pinned_name, "Try to fetch {} from Pinned Memory, but it is NOT a checkpoint".format(
|
|
|
|
@ -4798,13 +4800,13 @@ class RecomputeOptimizer(Optimizer):
|
|
|
|
|
|
|
|
|
|
pinned_varname = self.checkpoint_name2pinned_name[varname]
|
|
|
|
|
fetch_varname = self.checkpoint_name2fetch_name[varname]
|
|
|
|
|
self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 2)
|
|
|
|
|
self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 1)
|
|
|
|
|
|
|
|
|
|
def _insert_offload_op(self, idx, varname):
|
|
|
|
|
assert varname in self.checkpoint_name2pinned_name, "Try to offload {} to Pinned Memory, but it is NOT a checkpoint".format(
|
|
|
|
|
varname)
|
|
|
|
|
pinned_varname = self.checkpoint_name2pinned_name[varname]
|
|
|
|
|
self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 3)
|
|
|
|
|
self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 2)
|
|
|
|
|
|
|
|
|
|
def _insert_sync_op(self, op_idx, checkpoint_name):
|
|
|
|
|
# single stream offload no need sync
|
|
|
|
|