|
|
|
@ -39,11 +39,10 @@ struct ReAllocateVisitor {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct TensorCopyVisitor {
|
|
|
|
|
TensorCopyVisitor(const platform::Place& place, framework::Tensor* dst,
|
|
|
|
|
int64_t dst_offset, const framework::Tensor src,
|
|
|
|
|
int64_t src_offset, int64_t size)
|
|
|
|
|
: place_(place),
|
|
|
|
|
dst_(dst),
|
|
|
|
|
TensorCopyVisitor(framework::Tensor* dst, int64_t dst_offset,
|
|
|
|
|
const framework::Tensor src, int64_t src_offset,
|
|
|
|
|
int64_t size)
|
|
|
|
|
: dst_(dst),
|
|
|
|
|
dst_offset_(dst_offset),
|
|
|
|
|
src_(src),
|
|
|
|
|
src_offset_(src_offset),
|
|
|
|
@ -51,12 +50,12 @@ struct TensorCopyVisitor {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() const {
|
|
|
|
|
std::copy(src_.data<T>() + src_offset_,
|
|
|
|
|
src_.data<T>() + src_offset_ + size_,
|
|
|
|
|
dst_->mutable_data<T>(place_) + dst_offset_);
|
|
|
|
|
// TODO(Yancey1989): support other place
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
|
memory::Copy(cpu, dst_->mutable_data<T>(cpu) + dst_offset_, cpu,
|
|
|
|
|
src_.data<T>() + src_offset_, size_ * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
framework::Tensor* dst_;
|
|
|
|
|
int64_t dst_offset_;
|
|
|
|
|
framework::Tensor src_;
|
|
|
|
@ -125,16 +124,12 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
|
|
|
|
|
framework::Tensor* value) const {
|
|
|
|
|
PADDLE_ENFORCE(value->IsInitialized(),
|
|
|
|
|
"The value tensor should be initialized.");
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> non_keys;
|
|
|
|
|
int64_t value_width = value_->numel() / value_->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
|
|
|
|
|
"output tensor should have the same shape with table "
|
|
|
|
|
"execpt the dims[0].");
|
|
|
|
|
|
|
|
|
|
// TODO(Yancey1989): support other place
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < keys.size(); ++i) {
|
|
|
|
|
int64_t index = Index(keys[i]);
|
|
|
|
|
if (index == -1) {
|
|
|
|
@ -142,7 +137,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
|
|
|
|
|
} else {
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorCopyVisitor(cpu, value, i * value_width, *value_.get(),
|
|
|
|
|
TensorCopyVisitor(value, i * value_width, *value_.get(),
|
|
|
|
|
index * value_width, value_width));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -159,7 +154,6 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
|
|
|
|
|
"The first dim of value should be 1.");
|
|
|
|
|
auto index = Index(key);
|
|
|
|
|
platform::Place cpu = platform::CPUPlace();
|
|
|
|
|
bool is_new_key = false;
|
|
|
|
|
if (index == -1) {
|
|
|
|
|
rows_.push_back(key);
|
|
|
|
@ -176,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value.type()),
|
|
|
|
|
TensorCopyVisitor(cpu, value_.get(),
|
|
|
|
|
TensorCopyVisitor(value_.get(),
|
|
|
|
|
index * value_->numel() / value_->dims()[0], value,
|
|
|
|
|
static_cast<int64_t>(0), value.numel()));
|
|
|
|
|
return is_new_key;
|
|
|
|
|