Merge pull request #7384 from dzhwinter/feature/sync_wait

Feature/sync wait
add_depthwiseConv_op_gpu
ranqiu92 7 years ago committed by GitHub
commit 95c0c12641
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm> #include <algorithm>
@ -21,6 +22,10 @@ limitations under the License. */
#include "paddle/framework/shape_inference.h" #include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h" #include "paddle/framework/var_type.h"
DEFINE_bool(op_sync, false,
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode.");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -542,8 +547,14 @@ void OperatorWithKernel::Run(const Scope& scope,
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
kernel_iter->second->Compute(ExecutionContext( auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
*this, new_scope, *pool.Get(expected_kernel_key.place_))); kernel_iter->second->Compute(
ExecutionContext(*this, new_scope, *new_dev_ctx));
/*For profiling/benchmark only*/
if (FLAGS_op_sync) {
new_dev_ctx->Wait();
}
} }
proto::DataType OperatorWithKernel::IndicateDataType( proto::DataType OperatorWithKernel::IndicateDataType(

@ -58,7 +58,7 @@ def __bootstrap__():
read_env_flags = ['use_pinned_memory', 'check_nan_inf'] read_env_flags = ['use_pinned_memory', 'check_nan_inf']
if core.is_compile_gpu(): if core.is_compile_gpu():
read_env_flags.append('fraction_of_gpu_memory_to_use') read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync']
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)]) ["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0]) core.init_glog(sys.argv[0])

Loading…
Cancel
Save