|
|
|
@ -17,6 +17,53 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
struct ReAllocateVisitor {
|
|
|
|
|
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
|
|
|
|
|
: tensor_(tensor), dims_(dims) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() const {
|
|
|
|
|
framework::Tensor cpu_tensor;
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
|
T* ptr = cpu_tensor.mutable_data<T>(dims_, cpu);
|
|
|
|
|
const T* old_ptr =
|
|
|
|
|
tensor_->memory_size() == 0 ? nullptr : tensor_->data<T>();
|
|
|
|
|
if (old_ptr != nullptr) {
|
|
|
|
|
std::copy(old_ptr, old_ptr + tensor_->numel(), ptr);
|
|
|
|
|
}
|
|
|
|
|
tensor_->ShareDataWith(cpu_tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::Tensor* tensor_;
|
|
|
|
|
framework::DDim dims_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct TensorSlicedCopyVisitor {
|
|
|
|
|
TensorSlicedCopyVisitor(const platform::Place& place, framework::Tensor* dst,
|
|
|
|
|
int64_t dst_offset, const framework::Tensor src,
|
|
|
|
|
int64_t src_offset, int64_t size)
|
|
|
|
|
: place_(place),
|
|
|
|
|
dst_(dst),
|
|
|
|
|
dst_offset_(dst_offset),
|
|
|
|
|
src_(src),
|
|
|
|
|
src_offset_(src_offset),
|
|
|
|
|
size_(size) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() const {
|
|
|
|
|
std::copy(src_.data<T>() + src_offset_,
|
|
|
|
|
src_.data<T>() + src_offset_ + size_,
|
|
|
|
|
dst_->mutable_data<T>(place_) + dst_offset_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
framework::Tensor* dst_;
|
|
|
|
|
int64_t dst_offset_;
|
|
|
|
|
framework::Tensor src_;
|
|
|
|
|
int64_t src_offset_;
|
|
|
|
|
int64_t size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) {
|
|
|
|
|
{ // the 1st field, uint32_t version
|
|
|
|
@ -69,5 +116,49 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows,
|
|
|
|
|
TensorFromStream(is, selected_rows->mutable_value(), dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SelectedRows::HasKey(int64_t key) const {
|
|
|
|
|
return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? false
|
|
|
|
|
: true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Tensor SelectedRows::Get(int64_t key) const {
|
|
|
|
|
int64_t index = Index(key);
|
|
|
|
|
PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table.");
|
|
|
|
|
return value_->Slice(index, index + 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized.");
|
|
|
|
|
if (value_->IsInitialized()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
value.type(), value_->type(),
|
|
|
|
|
"The type of the value should be same with the original value");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
|
|
|
|
|
"The first dim of value should be 1.");
|
|
|
|
|
auto index = Index(key);
|
|
|
|
|
platform::Place cpu = platform::CPUPlace();
|
|
|
|
|
bool is_new_key = false;
|
|
|
|
|
if (index == -1) {
|
|
|
|
|
rows_.push_back(key);
|
|
|
|
|
index = rows_.size() - 1;
|
|
|
|
|
is_new_key = true;
|
|
|
|
|
// whether need to resize the value
|
|
|
|
|
if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
|
|
|
|
|
auto dims = value_->dims();
|
|
|
|
|
dims[0] = (dims[0] + 1) << 1;
|
|
|
|
|
framework::VisitDataType(framework::ToDataType(value.type()),
|
|
|
|
|
ReAllocateVisitor(value_.get(), dims));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value.type()),
|
|
|
|
|
TensorSlicedCopyVisitor(cpu, value_.get(),
|
|
|
|
|
index * value_->numel() / value_->dims()[0],
|
|
|
|
|
value, static_cast<int64_t>(0), value.numel()));
|
|
|
|
|
return is_new_key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|