From 63320f722cc718e69ddaa4aa5921e7fd047097df Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 01:17:00 -0800
Subject: [PATCH 1/9] "add some interfaces"

---
 paddle/framework/lod_tensor.h   |  22 ++++++-
 paddle/framework/mixed_vector.h | 102 ++++++++++++++++++++------------
 paddle/memory/memory.h          |  18 ++++++
 3 files changed, 103 insertions(+), 39 deletions(-)

diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index d0ab640485..ab28924161 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -48,12 +48,26 @@ namespace framework {
  */
 struct LoD : public std::vector<Vector<size_t>> {
   using std::vector<Vector<size_t>>::vector;
 
+  platform::Place place() const {
+    if (this->size() == 0) {
+      // Not initialized yet.
+      return platform::CPUPlace();
+    } else {
+      return this->front().place();
+    }
+  }
+
   void CopyFromCUDA() {
     for (auto it = this->begin(); it != this->end(); ++it) {
       it->CopyFromCUDA();
     }
   }
+
+  void CopyToPeer(platform::Place place) {
+    for (auto it = this->begin(); it != this->end(); ++it) {
+      it->mutable_data(place);
+    }
+  }
 };
 
 std::ostream& operator<<(std::ostream& os, const LoD& lod);
@@ -115,7 +129,13 @@ class LoDTensor : public Tensor {
 
   explicit LoDTensor(const LoD& lod) : lod_(lod) {}
 
-  void set_lod(const LoD& lod) { lod_ = lod; }
+  void set_lod(const LoD& lod) {
+    lod_ = lod;
+    if (holder_ != nullptr &&
+        platform::is_same_place(holder_->place(), lod.place())) {
+      lod_.CopyToPeer(holder_->place());
+    }
+  }
 
   const LoD& lod() const { return lod_; }
 
diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h
index 85caac8dcd..d86899bc63 100644
--- a/paddle/framework/mixed_vector.h
+++ b/paddle/framework/mixed_vector.h
@@ -40,14 +40,15 @@ class Vector : public std::vector<T> {
   Vector() {}
   Vector(const std::vector<T> &v) : std::vector<T>(v) {}  // NOLINT
 
-  virtual ~Vector() {
-#ifdef PADDLE_WITH_CUDA
-    if (cuda_ptr_ != nullptr) {
-      memory::Free(place_, cuda_ptr_);
-    }
-#endif
-  }
+  inline platform::Place place() const { return place_; }
 
+  /*! Return a pointer to constant memory block. */
+  inline const T *data(platform::Place place) const;
+
+  /*! Return a pointer to mutable memory block. */
+  inline T *mutable_data(platform::Place place);
+
+  // TODO(dzhwinter): the interfaces below should be removed
   /* Get device vector */
   T *cuda_data() {
     CopyToCUDA();
@@ -68,25 +69,71 @@ class Vector : public std::vector<T> {
   void CopyToPeer(platform::Place);
 
  private:
-  void *cuda_ptr_ = nullptr;
+  std::shared_ptr<void> cuda_ptr_;
   size_t cuda_size_ = 0;  // device vector numel
   platform::CUDAPlace place_;
 };
 
 template <typename T>
-void Vector<T>::CopyToCUDA() {
+inline const T *Vector<T>::data(platform::Place place) const {
+  if (platform::is_cpu_place(place)) {
+    return std::vector<T>::data();
+  } else if (platform::is_gpu_place(place)) {
+    if (cuda_ptr_ == nullptr) {
+      return nullptr;
+    }
+    if (platform::is_same_place(place, place_)) {
+      return static_cast<const T *>(cuda_ptr_.get());
+    } else {
+      PADDLE_THROW(
+          "Unmatched place. Please use `mutable_data` to copy the lod to the "
+          "target place first.");
+    }
+  } else {
+    PADDLE_THROW("Unsupported place.");
+  }
+}
+
+template <typename T>
+inline T *Vector<T>::mutable_data(platform::Place place) {
+  if (platform::is_cpu_place(place)) {
+    return std::vector<T>::data();
+  } else if (platform::is_gpu_place(place)) {
+    if (!platform::is_same_place(place, place_)) {
+      place_ = boost::get<platform::CUDAPlace>(place);
+    }
 #ifdef PADDLE_WITH_CUDA
-  if (cuda_size_ < this->size()) {
-    if (cuda_ptr_ != nullptr) {
-      memory::Free(place_, cuda_ptr_);
+    if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
+      cuda_ptr_.reset(
+          memory::Alloc(place_, this->size() * sizeof(T)),
+          memory::PlainDeleter<void, platform::CUDAPlace>(place_));
     }
-    cuda_ptr_ =
-        memory::Alloc(place_, this->size() * sizeof(T));
+    cuda_size_ = this->size();
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto *ctx = pool.GetByPlace(place_);
+    memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
+                 static_cast<const void *>(this->data()),
+                 this->size() * sizeof(T), ctx->stream());
+    ctx->Wait();
+    return static_cast<T *>(cuda_ptr_.get());
+#endif
+  } else {
+    PADDLE_THROW("Unsupported place.");
+  }
+}
+
+template <typename T>
+void Vector<T>::CopyToCUDA() {
+#ifdef PADDLE_WITH_CUDA
+  if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
+    cuda_ptr_.reset(
+        memory::Alloc(this->size() * sizeof(T)),
+        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
   }
   cuda_size_ = this->size();
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto *ctx = pool.GetByPlace(place_);
-  memory::Copy(place_, cuda_ptr_, platform::CPUPlace(),
+  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
                static_cast<const void *>(this->data()),
                this->size() * sizeof(T), ctx->stream());
   ctx->Wait();
@@ -104,32 +151,11 @@ void Vector<T>::CopyFromCUDA() {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto *ctx = pool.GetByPlace(place_);
   memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
-               static_cast<const void *>(cuda_ptr_), this->size() * sizeof(T),
-               ctx->stream());
-  ctx->Wait();
-#endif
-}
-
-template <typename T>
-void Vector<T>::CopyToPeer(platform::Place peer_place) {
-#ifdef PADDLE_WITH_CUDA
-  auto *ctx = platform::DeviceContextPool::Instance().GetByPlace(place_);
-  void *peer_cuda_ptr = memory::Alloc(
-      boost::get<platform::CUDAPlace>(peer_place), this->size() * sizeof(T));
-  memory::Copy(boost::get<platform::CUDAPlace>(peer_place), peer_cuda_ptr,
-               place_, cuda_ptr_, this->size() * sizeof(T), ctx->stream());
+               static_cast<const void *>(cuda_ptr_.get()),
+               this->size() * sizeof(T), ctx->stream());
   ctx->Wait();
-
-  memory::Free(place_, cuda_ptr_);
-  place_ = boost::get<platform::CUDAPlace>(peer_place);
-  cuda_ptr_ = peer_cuda_ptr;
 #endif
 }
 
-template class Vector<int>;
-template class Vector<unsigned>;
-template class Vector<size_t>;
-template class Vector<int64_t>;
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h
index 7012b6d331..30ed68c6e0 100644
--- a/paddle/memory/memory.h
+++ b/paddle/memory/memory.h
@@ -81,5 +81,23 @@ class PODDeleter {
   Place place_;
 };
 
+/**
+ * \brief Free a memory block in the given place whose type does not meet POD.
+ *
+ * \note In some cases, a custom deleter is needed to deallocate the memory
+ *       automatically for std::unique_ptr<T> in tensor.h.
+ */
+template <typename T, typename Place>
+class PlainDeleter {
+ public:
+  explicit PlainDeleter(Place place) : place_(place) {}
+  void operator()(T* ptr) { Free(place_, reinterpret_cast<void*>(ptr)); }
+
+ private:
+  Place place_;
+};
+
 }  // namespace memory
 }  // namespace paddle
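
Note on [PATCH 1/9] — a minimal usage sketch of the new interface (illustrative only, not part of the patch; `vec` and `gpu` are example names):

    paddle::framework::Vector<size_t> vec({1, 2, 3});
    paddle::platform::CUDAPlace gpu(0);

    // Allocates device memory if needed, copies the host contents to `gpu`,
    // and returns a device pointer.
    size_t *d_ptr = vec.mutable_data(gpu);

    // Read-only access on the same place; returns nullptr if nothing has been
    // staged yet, and throws "Unmatched place" for a different GPU.
    const size_t *d_cptr = vec.data(gpu);

    // Host access still goes through the underlying std::vector storage.
    const size_t *h_ptr = vec.data(paddle::platform::CPUPlace());
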
From a402d2b39257ae58345998ed5edd6b87b09e9a1b Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 01:22:13 -0800
Subject: [PATCH 2/9] "fix condition"

---
 paddle/framework/lod_tensor.h    | 2 +-
 paddle/framework/selected_rows.h | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index ab28924161..3465e02c82 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -132,7 +132,7 @@ class LoDTensor : public Tensor {
   void set_lod(const LoD& lod) {
     lod_ = lod;
     if (holder_ != nullptr &&
-        platform::is_same_place(holder_->place(), lod.place())) {
+        !platform::is_same_place(holder_->place(), lod.place())) {
       lod_.CopyToPeer(holder_->place());
     }
   }
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index 30d3dfc1e8..1132344244 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -42,7 +42,13 @@ class SelectedRows {
 
   Vector<int64_t>* mutable_rows() { return &rows_; }
 
-  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
+  void set_rows(const Vector<int64_t>& rows) {
+    rows_ = rows;
+    if (value_ != nullptr &&
+        !platform::is_same_place(value_->place(), rows.place())) {
+      rows_.mutable_data(value_->place());
+    }
+  }
 
   DDim GetCompleteDims() const {
     std::vector<int64_t> dims = vectorize(value_->dims());

From 07dd3d25b39878b6ccc4736e189c015cfd2265d2 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 01:53:43 -0800
Subject: [PATCH 3/9] "fix const warning"

---
 paddle/framework/CMakeLists.txt       |  1 +
 paddle/framework/lod_tensor_test.cu   | 22 --------
 paddle/framework/mixed_vector_test.cu | 72 +++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 22 deletions(-)
 create mode 100644 paddle/framework/mixed_vector_test.cu

diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 8b71f73c36..7c4ba3afb9 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
 
 cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
 
+nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init)
 cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
 nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
index d4c9f00bd9..adea02e3b3 100644
--- a/paddle/framework/lod_tensor_test.cu
+++ b/paddle/framework/lod_tensor_test.cu
@@ -28,28 +28,6 @@ __global__ void test(size_t* a, int size) {
   }
 }
 
-TEST(Vector, Normal) {
-  using namespace paddle::framework;
-  using namespace paddle::platform;
-  using namespace paddle::memory;
-
-  paddle::framework::InitDevices();
-
-  paddle::framework::Vector<size_t> vec({1, 2, 3});
-  size_t* ptr = vec.data();
-  for (size_t i = 0; i < vec.size(); ++i) {
-    EXPECT_EQ(vec[i], *(ptr + i));
-  }
-
-  vec.clear();
-  vec.CopyFromCUDA();
-
-  std::vector<size_t> v = {1, 2, 3};
-  for (size_t i = 0; i < v.size(); ++i) {
-    EXPECT_EQ(v[i], vec[i]);
-  }
-}
-
 TEST(LoD, data) {
   paddle::framework::InitDevices();
 
diff --git a/paddle/framework/mixed_vector_test.cu b/paddle/framework/mixed_vector_test.cu
new file mode 100644
index 0000000000..7b571788ad
--- /dev/null
+++ b/paddle/framework/mixed_vector_test.cu
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "gtest/gtest.h"
+
+#include "paddle/framework/init.h"
+#include "paddle/framework/mixed_vector.h"
+
+using namespace paddle::framework;
+using namespace paddle::platform;
+using namespace paddle::memory;
+
+template <typename T>
+__global__ void test(T* data, int size) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
+       i += blockDim.x * gridDim.x) {
+    data[i] *= 2;
+  }
+}
+
+TEST(Vector, Normal) {
+  // fill the device context pool.
+  InitDevices();
+
+  Vector<size_t> vec({1, 2, 3});
+  size_t* ptr = vec.data();
+  for (size_t i = 0; i < vec.size(); ++i) {
+    EXPECT_EQ(vec[i], *(ptr + i));
+  }
+
+  vec.clear();
+  vec.CopyFromCUDA();
+
+  std::vector<size_t> v = {1, 2, 3};
+  for (size_t i = 0; i < v.size(); ++i) {
+    EXPECT_EQ(v[i], vec[i]);
+  }
+}
+
+TEST(Vector, MultipleCopy) {
+  InitDevices();
+  Vector<size_t> vec({1, 2, 3});
+  CUDAPlace place(0);
+  vec.mutable_data(place);
+  auto vec2 = Vector<size_t>(vec);
+  {
+    const size_t* ptr = vec2.data(CPUPlace());
+    for (size_t i = 0; i < vec2.size(); ++i) {
+      EXPECT_EQ(*(ptr + i), vec[i]);
+    }
+  }
+  test<<<3, 3>>>(vec2.mutable_data(place), vec2.size());
+  vec2.CopyFromCUDA();
+  {
+    const size_t* ptr = vec2.data(CPUPlace());
+    for (size_t i = 0; i < vec2.size(); ++i) {
+      EXPECT_EQ(*(ptr + i), vec[i] * 2);
+    }
+  }
+}
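
Note on [PATCH 3/9] — the new test target should be runnable on its own after a build. This assumes the repository's usual cmake/ctest flow and that nv_test registers the binary with ctest, as the neighboring targets do:

    cd build
    make mixed_vector_test
    ctest -R mixed_vector_test --output-on-failure
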
From 239fafb0d31618a1aee2ac814ed662f18c48cc9c Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 02:37:52 -0800
Subject: [PATCH 4/9] "test on parallel do op"

---
 paddle/operators/parallel_do_op.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index 67f9854c02..d662878592 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -79,6 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
     } else {
       Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
     }
+    dst->set_lod(src.lod());
   } else if (src.IsType<SelectedRows>()) {
     auto &src_sr = src.Get<SelectedRows>();
     auto *dst_sr = dst->GetMutable<SelectedRows>();
@@ -89,6 +90,7 @@ inline void CopyOrShare(const framework::Variable &src,
     } else {
       Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
     }
+    dst_sr->set_rows(src_sr.rows());
   } else {
     PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
   }
@@ -145,6 +147,7 @@ class ParallelDoOp : public framework::OperatorBase {
         auto *sub_scope = sub_scopes[i];
         auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
         framework::Copy(src, place, dst);
+        dst->set_lod(src.lod());
       }
     }
     WaitOnPlaces(places);

From f18f3826dc5d59f49908f2c232ff81b15c0abd9a Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 03:04:39 -0800
Subject: [PATCH 5/9] "parallel op set lod after copy"

---
 paddle/framework/mixed_vector.h    | 4 ++--
 paddle/operators/parallel_do_op.cc | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h
index d86899bc63..aade7d8391 100644
--- a/paddle/framework/mixed_vector.h
+++ b/paddle/framework/mixed_vector.h
@@ -54,7 +54,7 @@ class Vector : public std::vector<T> {
     CopyToCUDA();
     PADDLE_ENFORCE_NOT_NULL(
         cuda_ptr_, "No data or insufficient CUDA memory for allocation");
-    return static_cast<T *>(cuda_ptr_);
+    return static_cast<T *>(cuda_ptr_.get());
   }
 
   /* Get host vector */
@@ -127,7 +127,7 @@ void Vector<T>::CopyToCUDA() {
 #ifdef PADDLE_WITH_CUDA
   if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
     cuda_ptr_.reset(
-        memory::Alloc(this->size() * sizeof(T)),
+        memory::Alloc(place_, this->size() * sizeof(T)),
         memory::PlainDeleter<void, platform::CUDAPlace>(place_));
   }
   cuda_size_ = this->size();
diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index d662878592..87678decde 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
     } else {
       Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
     }
-    dst->set_lod(src.lod());
+    dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
   } else if (src.IsType<SelectedRows>()) {
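
Note on [PATCH 5/9] — the one-line fix is needed because `dst` in CopyOrShare is a `framework::Variable*`, not a tensor, so the LoD must be set on the LoDTensor held inside the variable. The corrected pattern, condensed (illustrative):

    auto &src_t = src.Get<framework::LoDTensor>();
    auto *dst_t = dst->GetMutable<framework::LoDTensor>();
    Copy(src_t, dst_place, dst_t);  // copies tensor data only
    dst_t->set_lod(src_t.lod());    // the LoD must be carried over explicitly
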
From 17b1c369b1f2dadff102ec283b847ea064593dec Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 23:08:12 -0800
Subject: [PATCH 6/9] "fix ci"

---
 paddle/framework/mixed_vector.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h
index aade7d8391..1fc7622e9b 100644
--- a/paddle/framework/mixed_vector.h
+++ b/paddle/framework/mixed_vector.h
@@ -116,6 +116,8 @@ inline T *Vector<T>::mutable_data(platform::Place place) {
                  this->size() * sizeof(T), ctx->stream());
     ctx->Wait();
     return static_cast<T *>(cuda_ptr_.get());
+#else
+    return nullptr;
 #endif
   } else {
     PADDLE_THROW("Unsupported place.");
   }
 }

From 709c157a2ff4d51846c373b465d021be93033363 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 5 Feb 2018 23:59:41 -0800
Subject: [PATCH 7/9] "fix ci"

---
 paddle/framework/lod_tensor.h      |  8 +-------
 paddle/framework/selected_rows.h   |  8 +-------
 paddle/operators/parallel_do_op.cc | 11 ++++++++---
 3 files changed, 10 insertions(+), 17 deletions(-)

diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 3465e02c82..a773c1eb32 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -129,13 +129,7 @@ class LoDTensor : public Tensor {
 
   explicit LoDTensor(const LoD& lod) : lod_(lod) {}
 
-  void set_lod(const LoD& lod) {
-    lod_ = lod;
-    if (holder_ != nullptr &&
-        !platform::is_same_place(holder_->place(), lod.place())) {
-      lod_.CopyToPeer(holder_->place());
-    }
-  }
+  void set_lod(const LoD& lod) { lod_ = lod; }
 
   const LoD& lod() const { return lod_; }
 
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index 1132344244..30d3dfc1e8 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -42,13 +42,7 @@ class SelectedRows {
 
   Vector<int64_t>* mutable_rows() { return &rows_; }
 
-  void set_rows(const Vector<int64_t>& rows) {
-    rows_ = rows;
-    if (value_ != nullptr &&
-        !platform::is_same_place(value_->place(), rows.place())) {
-      rows_.mutable_data(value_->place());
-    }
-  }
+  void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
 
   DDim GetCompleteDims() const {
     std::vector<int64_t> dims = vectorize(value_->dims());
diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index 87678decde..0db2fb6238 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -76,21 +76,26 @@ inline void CopyOrShare(const framework::Variable &src,
   if (src.IsType<LoDTensor>()) {
     if (src.Get<LoDTensor>().place() == dst_place) {
       dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>());
+      dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
     } else {
       Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
+      LoD lod(src.Get<LoDTensor>().lod());
+      lod.CopyToPeer(dst_place);
+      dst->GetMutable<LoDTensor>()->set_lod(lod);
     }
-    dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
   } else if (src.IsType<SelectedRows>()) {
     auto &src_sr = src.Get<SelectedRows>();
     auto *dst_sr = dst->GetMutable<SelectedRows>();
-    dst_sr->set_rows(src_sr.rows());
     dst_sr->set_height(src_sr.height());
     if (src_sr.value().place() == dst_place) {
       dst_sr->mutable_value()->ShareDataWith(src_sr.value());
+      dst_sr->set_rows(src_sr.rows());
     } else {
       Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
+      LoD lod(src.Get<LoDTensor>().lod());
+      lod.CopyToPeer(dst_place);
+      dst_sr->set_rows(lod);
     }
   } else {
     PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
   }
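
Note on [PATCH 7/9] — syncing the LoD's placement inside set_lod/set_rows coupled the setters to the allocation state, so the transfer now happens explicitly at the call site. Since a LoD is a std::vector<Vector<size_t>>, copying it to a peer place just re-stages each level; a sketch of the semantics using the interfaces from [PATCH 1/9]:

    framework::LoD lod(src.Get<framework::LoDTensor>().lod());
    for (auto &level : lod) {
      level.mutable_data(dst_place);  // allocate on dst_place, copy the level over
    }
    dst->GetMutable<framework::LoDTensor>()->set_lod(lod);
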
From 179b78934a81c7935b3a3d6fa22f9596170a31dc Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Tue, 6 Feb 2018 00:24:13 -0800
Subject: [PATCH 8/9] "fix CopyToPeer"

---
 paddle/framework/lod_tensor.h      |  2 +-
 paddle/framework/mixed_vector.h    | 25 +++++++++++++++++++++++--
 paddle/operators/parallel_do_op.cc |  4 ++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index a773c1eb32..be2b301619 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -65,7 +65,7 @@ struct LoD : public std::vector<Vector<size_t>> {
 
   void CopyToPeer(platform::Place place) {
     for (auto it = this->begin(); it != this->end(); ++it) {
-      it->mutable_data(place);
+      it->CopyToPeer(place);
     }
   }
 };
diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h
index 1fc7622e9b..cdb968e3cb 100644
--- a/paddle/framework/mixed_vector.h
+++ b/paddle/framework/mixed_vector.h
@@ -82,7 +82,7 @@ inline const T *Vector<T>::data(platform::Place place) const {
     if (cuda_ptr_ == nullptr) {
       return nullptr;
     }
-    if (platform::is_same_place(place, place_)) {
+    if (boost::get<platform::CUDAPlace>(place) == place_) {
       return static_cast<const T *>(cuda_ptr_.get());
     } else {
       PADDLE_THROW(
@@ -99,7 +99,7 @@ inline T *Vector<T>::mutable_data(platform::Place place) {
   if (platform::is_cpu_place(place)) {
     return std::vector<T>::data();
   } else if (platform::is_gpu_place(place)) {
-    if (!platform::is_same_place(place, place_)) {
+    if (boost::get<platform::CUDAPlace>(place) != place_) {
       place_ = boost::get<platform::CUDAPlace>(place);
     }
 #ifdef PADDLE_WITH_CUDA
@@ -159,5 +159,26 @@ void Vector<T>::CopyFromCUDA() {
 #endif
 }
 
+template <typename T>
+void Vector<T>::CopyToPeer(platform::Place place) {
+#ifdef PADDLE_WITH_CUDA
+  if (boost::get<platform::CUDAPlace>(place) != place_) {
+    place_ = boost::get<platform::CUDAPlace>(place);
+  }
+  if (cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
+    cuda_ptr_.reset(
+        memory::Alloc(place_, this->size() * sizeof(T)),
+        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
+  }
+  cuda_size_ = this->size();
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto *ctx = pool.GetByPlace(place_);
+  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
+               static_cast<const void *>(this->data()),
+               this->size() * sizeof(T), ctx->stream());
+  ctx->Wait();
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index 0db2fb6238..eb6308d306 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
       dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
     } else {
       Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
-      LoD lod(src.Get<LoDTensor>().lod());
+      framework::LoD lod(src.Get<LoDTensor>().lod());
       lod.CopyToPeer(dst_place);
       dst->GetMutable<LoDTensor>()->set_lod(lod);
     }
@@ -92,7 +92,7 @@ inline void CopyOrShare(const framework::Variable &src,
       dst_sr->set_rows(src_sr.rows());
     } else {
       Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
-      LoD lod(src.Get<LoDTensor>().lod());
+      framework::Vector<int64_t> lod(src_sr.rows());
       lod.CopyToPeer(dst_place);
       dst_sr->set_rows(lod);
     }
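
Note on [PATCH 8/9] — a quick check of the repaired Vector<T>::CopyToPeer (illustrative; the device index is assumed):

    paddle::framework::Vector<size_t> rows({0, 2, 5});
    paddle::platform::CUDAPlace gpu(0);

    rows.CopyToPeer(gpu);                   // allocate on `gpu`, copy host contents
    const size_t *d_rows = rows.data(gpu);  // same place now, so no throw
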
From 1eb3d6cdb261bb41eff6b44b301e3da881b2fa26 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Tue, 6 Feb 2018 21:24:02 -0800
Subject: [PATCH 9/9] "rerun ci"

---
 paddle/operators/parallel_do_op.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index eb6308d306..6c85ca6cde 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -152,7 +152,9 @@ class ParallelDoOp : public framework::OperatorBase {
         auto *sub_scope = sub_scopes[i];
         auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
         framework::Copy(src, place, dst);
-        dst->set_lod(src.lod());
+        framework::LoD lod(src.lod());
+        lod.CopyToPeer(place);
+        dst->set_lod(lod);
       }
     }
     WaitOnPlaces(places);
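
With the whole series applied, the per-place parameter broadcast in ParallelDoOp ends up in the following shape (condensed from [PATCH 9/9] for reference; the loop header is paraphrased):

    for (size_t i = 0; i < sub_scopes.size(); ++i) {
      auto &place = places[i];
      auto *dst = sub_scopes[i]->Var(param)->GetMutable<framework::LoDTensor>();
      framework::Copy(src, place, dst);  // tensor data to the target place
      framework::LoD lod(src.lod());
      lod.CopyToPeer(place);             // LoD buffers follow the data
      dst->set_lod(lod);
    }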