[bugfix] SyncDeviceToHost failed when device address size is zero

pull/4987/head
lizhenyu 5 years ago
parent 7098b5c5d5
commit 1becddf3a4

@ -65,6 +65,8 @@ void GPUSession::StartKernelRT() const {
void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>(); auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
@ -73,9 +75,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); if (context_ptr->execution_mode() != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
}
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();

@ -32,6 +32,10 @@ namespace device {
namespace gpu { namespace gpu {
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const { bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr); MS_EXCEPTION_IF_NULL(host_ptr);
bool need_sync = (size != 0) && (size_ != 0);
if (!need_sync) {
return true;
}
auto &stream = GPUDeviceManager::GetInstance().default_stream(); auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream); MS_EXCEPTION_IF_NULL(stream);
auto ret = GPUDeviceManager::GetInstance().SyncStream(stream); auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
@ -48,6 +52,10 @@ bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, T
bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t size, TypeId, const void *host_ptr) const { bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t size, TypeId, const void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr); MS_EXCEPTION_IF_NULL(host_ptr);
bool need_sync = (size != 0) && (size_ != 0);
if (!need_sync) {
return true;
}
auto &stream = GPUDeviceManager::GetInstance().default_stream(); auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream); MS_EXCEPTION_IF_NULL(stream);
if (size != size_) { if (size != size_) {

Loading…
Cancel
Save