|
|
|
@ -18,8 +18,8 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
struct ReAllocateVisitor {
|
|
|
|
|
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
|
|
|
|
|
: tensor_(tensor), dims_(dims) {}
|
|
|
|
|
ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor)
|
|
|
|
|
: dims_(dims), tensor_(tensor) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() const {
|
|
|
|
@ -34,8 +34,8 @@ struct ReAllocateVisitor {
|
|
|
|
|
tensor_->ShareDataWith(cpu_tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Tensor* tensor_;
|
|
|
|
|
framework::DDim dims_;
|
|
|
|
|
framework::Tensor* tensor_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct TensorCopyVisitor {
|
|
|
|
@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
|
|
|
|
|
"The first dim of value should be 1.");
|
|
|
|
|
std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
|
|
|
|
|
auto index = Index(key);
|
|
|
|
|
bool is_new_key = false;
|
|
|
|
|
if (index == -1) {
|
|
|
|
@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
auto dims = value_->dims();
|
|
|
|
|
dims[0] = (dims[0] + 1) << 1;
|
|
|
|
|
framework::VisitDataType(framework::ToDataType(value.type()),
|
|
|
|
|
ReAllocateVisitor(value_.get(), dims));
|
|
|
|
|
ReAllocateVisitor(dims, value_.get()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|