|
|
|
@ -121,27 +121,32 @@ bool SelectedRows::HasKey(int64_t key) const {
|
|
|
|
|
: true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SelectedRows::Get(int64_t key, framework::Tensor* value,
|
|
|
|
|
int64_t offset) const {
|
|
|
|
|
int64_t index = Index(key);
|
|
|
|
|
PADDLE_ENFORCE_GE(index, 0, "The key should be exists in the Table.");
|
|
|
|
|
std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
|
|
|
|
|
framework::Tensor* value) const {
|
|
|
|
|
PADDLE_ENFORCE(value->IsInitialized(),
|
|
|
|
|
"The value tensor should be initialized.");
|
|
|
|
|
|
|
|
|
|
int64_t value_width = value->numel() / value->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(value_width, value_->numel() / value_->dims()[0],
|
|
|
|
|
std::vector<int64_t> non_keys;
|
|
|
|
|
int64_t value_width = value_->numel() / value_->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
|
|
|
|
|
"output tensor should have the same shape with table "
|
|
|
|
|
"execpt the dims[0].");
|
|
|
|
|
|
|
|
|
|
// TODO(Yancey1989): support other place
|
|
|
|
|
platform::CPUPlace cpu;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < keys.size(); ++i) {
|
|
|
|
|
int64_t index = Index(keys[i]);
|
|
|
|
|
if (index == -1) {
|
|
|
|
|
non_keys.push_back(keys[i]);
|
|
|
|
|
} else {
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
framework::ToDataType(value_->type()),
|
|
|
|
|
TensorCopyVisitor(cpu, value, offset * value_width, *value_.get(),
|
|
|
|
|
TensorCopyVisitor(cpu, value, i * value_width, *value_.get(),
|
|
|
|
|
index * value_width, value_width));
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return non_keys;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
|
|
|
|
|