|
|
@ -33,11 +33,6 @@ void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void FetchOpHandle::WaitAndMergeCPUTensors() const {
|
|
|
|
void FetchOpHandle::WaitAndMergeCPUTensors() const {
|
|
|
|
// Wait fetch stream done.
|
|
|
|
|
|
|
|
for (auto &ctx : dev_ctx_) {
|
|
|
|
|
|
|
|
ctx.second->Wait();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<const LoDTensor *> tensors_ptr;
|
|
|
|
std::vector<const LoDTensor *> tensors_ptr;
|
|
|
|
tensors_ptr.reserve(tensors_.size());
|
|
|
|
tensors_ptr.reserve(tensors_.size());
|
|
|
|
for (auto &t : tensors_) {
|
|
|
|
for (auto &t : tensors_) {
|
|
|
@ -72,6 +67,8 @@ void FetchOpHandle::RunImpl() {
|
|
|
|
tensors_[i].ShareDataWith(t);
|
|
|
|
tensors_[i].ShareDataWith(t);
|
|
|
|
tensors_[i].set_lod(t.lod());
|
|
|
|
tensors_[i].set_lod(t.lod());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->WaitAndMergeCPUTensors();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|