Merge pull request #11012 from jacquesqiao/add-auto_grown_mutex

add auto_grown_mutex for selected rows
release/0.13.0
Qiao Longfei 7 years ago committed by GitHub
commit 654f5d3c91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,8 +18,8 @@ namespace paddle {
namespace framework { namespace framework {
struct ReAllocateVisitor { struct ReAllocateVisitor {
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims) ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor)
: tensor_(tensor), dims_(dims) {} : dims_(dims), tensor_(tensor) {}
template <typename T> template <typename T>
void operator()() const { void operator()() const {
@ -34,8 +34,8 @@ struct ReAllocateVisitor {
tensor_->ShareDataWith(cpu_tensor); tensor_->ShareDataWith(cpu_tensor);
} }
framework::Tensor* tensor_;
framework::DDim dims_; framework::DDim dims_;
framework::Tensor* tensor_;
}; };
struct TensorCopyVisitor { struct TensorCopyVisitor {
@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
} }
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1), PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
"The first dim of value should be 1."); "The first dim of value should be 1.");
std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
auto index = Index(key); auto index = Index(key);
bool is_new_key = false; bool is_new_key = false;
if (index == -1) { if (index == -1) {
@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
auto dims = value_->dims(); auto dims = value_->dims();
dims[0] = (dims[0] + 1) << 1; dims[0] = (dims[0] + 1) << 1;
framework::VisitDataType(framework::ToDataType(value.type()), framework::VisitDataType(framework::ToDataType(value.type()),
ReAllocateVisitor(value_.get(), dims)); ReAllocateVisitor(dims, value_.get()));
} }
} }

@ -15,6 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <memory>
#include <mutex> // NOLINT
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -46,11 +48,13 @@ class SelectedRows {
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height) SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) { : rows_(rows), height_(height) {
value_.reset(new Tensor()); value_.reset(new Tensor());
auto_grown_mutex_.reset(new std::mutex);
} }
SelectedRows() { SelectedRows() {
height_ = 0; height_ = 0;
value_.reset(new Tensor()); value_.reset(new Tensor());
auto_grown_mutex_.reset(new std::mutex);
} }
platform::Place place() const { return value_->place(); } platform::Place place() const { return value_->place(); }
@ -125,6 +129,7 @@ class SelectedRows {
Vector<int64_t> rows_; Vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr}; std::unique_ptr<Tensor> value_{nullptr};
int64_t height_; int64_t height_;
std::unique_ptr<std::mutex> auto_grown_mutex_{nullptr};
}; };
/* /*

Loading…
Cancel
Save