|
|
|
@ -38,10 +38,10 @@ struct ReAllocateVisitor {
|
|
|
|
|
framework::DDim dims_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct TensorSlicedCopyVisitor {
|
|
|
|
|
TensorSlicedCopyVisitor(const platform::Place& place, framework::Tensor* dst,
|
|
|
|
|
int64_t dst_offset, const framework::Tensor src,
|
|
|
|
|
int64_t src_offset, int64_t size)
|
|
|
|
|
struct TensorCopyVisitor {
|
|
|
|
|
TensorCopyVisitor(const platform::Place& place, framework::Tensor* dst,
|
|
|
|
|
int64_t dst_offset, const framework::Tensor src,
|
|
|
|
|
int64_t src_offset, int64_t size)
|
|
|
|
|
: place_(place),
|
|
|
|
|
dst_(dst),
|
|
|
|
|
dst_offset_(dst_offset),
|
|
|
|
@ -121,10 +121,27 @@ bool SelectedRows::HasKey(int64_t key) const {
|
|
|
|
|
: true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Tensor SelectedRows::Get(int64_t key) const {
|
|
|
|
|
bool SelectedRows::Get(int64_t key, framework::Tensor* value,
|
|
|
|
|
int64_t row) const {
|
|
|
|
|
int64_t index = Index(key);
|
|
|
|
|
PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table.");
|
|
|
|
|
return value_->Slice(index, index + 1);
|
|
|
|
|
PADDLE_ENFORCE(value->IsInitialized(),
|
|
|
|
|
"The value tensor should be initialized.");
|
|
|
|
|
|
|
|
|
|
int64_t value_width = value->numel() / value->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(value_width, value_->numel() / value_->dims()[0],
|
|
|
|
|
"output tensor should have the same shape with table "
|
|
|
|
|
"execpt the dims[0].");
|
|
|
|
|
|
|
|
|
|
// TODO(Yancey1989): support other place
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
|
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorCopyVisitor(cpu, value, row * value_width, *value_.get(),
|
|
|
|
|
index * value_width, value_width));
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
@ -143,7 +160,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
rows_.push_back(key);
|
|
|
|
|
index = rows_.size() - 1;
|
|
|
|
|
is_new_key = true;
|
|
|
|
|
// whether need to resize the value
|
|
|
|
|
// whether need to resize the table
|
|
|
|
|
if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
|
|
|
|
|
auto dims = value_->dims();
|
|
|
|
|
dims[0] = (dims[0] + 1) << 1;
|
|
|
|
@ -154,9 +171,9 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value.type()),
|
|
|
|
|
TensorSlicedCopyVisitor(cpu, value_.get(),
|
|
|
|
|
index * value_->numel() / value_->dims()[0],
|
|
|
|
|
value, static_cast<int64_t>(0), value.numel()));
|
|
|
|
|
TensorCopyVisitor(cpu, value_.get(),
|
|
|
|
|
index * value_->numel() / value_->dims()[0], value,
|
|
|
|
|
static_cast<int64_t>(0), value.numel()));
|
|
|
|
|
return is_new_key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|