Fix wait stream two times bug

test=develop
revert-15207-remove_op_handle_lock_and_fix_var
minqiyang 7 years ago
parent 6fabbd8fb8
commit 8149a07a41

@ -25,7 +25,7 @@ struct ExecutionStrategy {
size_t num_threads_{0}; size_t num_threads_{0};
bool use_cuda_{true}; bool use_cuda_{true};
bool allow_op_delay_{false}; bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{1}; size_t num_iteration_per_drop_scope_{100};
ExecutorType type_{kDefault}; ExecutorType type_{kDefault};
bool dry_run_{false}; bool dry_run_{false};
}; };

@ -66,17 +66,15 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr);
++drop_scope_counter_; ++drop_scope_counter_;
bool stream_end = false;
if (!fetch_tensors.empty()) { if (!fetch_tensors.empty()) {
// Wait All computational streams WaitComputationalStreams();
for (auto p : places_) { stream_end = true;
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
} }
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
// Wait All computational streams if (!stream_end) {
for (auto p : places_) { WaitComputationalStreams();
platform::DeviceContextPool::Instance().Get(p)->Wait();
} }
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {

@ -47,6 +47,14 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override; FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
inline void WaitComputationalStreams() {
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
private: private:
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};

@ -815,7 +815,7 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is INT, num_iteration_per_drop_scope indicates how R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to clean up the temp variables which many iterations to clean up the temp variables which
is generated during execution. It may make the execution faster, is generated during execution. It may make the execution faster,
because the temp variable's shape maybe the same between two iterations. Default 1. because the temp variable's shape maybe the same between two iterations. Default 100.
NOTES: NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor 1. If you fetch data when calling the 'run', the ParallelExecutor

Loading…
Cancel
Save