|
|
|
@ -18,23 +18,16 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
static Tensor *GetTensorFromVar(Variable *in_var) {
|
|
|
|
|
if (in_var->IsType<LoDTensor>()) {
|
|
|
|
|
return in_var->GetMutable<LoDTensor>();
|
|
|
|
|
} else if (in_var->IsType<SelectedRows>()) {
|
|
|
|
|
return in_var->GetMutable<SelectedRows>()->mutable_value();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const std::vector<platform::Place> &places)
|
|
|
|
|
: local_scopes_(local_scopes), places_(places) {}
|
|
|
|
|
|
|
|
|
|
void GatherOpHandle::RunImpl() {
|
|
|
|
|
PADDLE_ENFORCE_EQ(this->inputs_.size(), places_.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(this->outputs_.size(), 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
this->inputs_.size(), places_.size(),
|
|
|
|
|
"The number of inputs should be equal to the number of place.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(this->outputs_.size(), 1,
|
|
|
|
|
"The number of output should be one.");
|
|
|
|
|
|
|
|
|
|
// Wait input done, this Wait is asynchronous operation
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
@ -46,6 +39,7 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
|
auto in_0_handle = static_cast<VarHandle *>(inputs_[0]);
|
|
|
|
|
auto pre_in_var =
|
|
|
|
|
local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
|
|
|
|
|
auto pre_place = in_0_handle->place_;
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_rows;
|
|
|
|
|
std::vector<Tensor *> in_tensors;
|
|
|
|
@ -58,7 +52,8 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
|
in_places.push_back(in_p);
|
|
|
|
|
PADDLE_ENFORCE_LT(in_handle->scope_idx_, local_scopes_.size(),
|
|
|
|
|
"%s is not the the local_scopes ", in_handle->name_);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
|
|
|
|
|
"The place of input should be the same.");
|
|
|
|
|
auto *s = local_scopes_[in_handle->scope_idx_];
|
|
|
|
|
auto in_var = s->FindVar(in_handle->name_);
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_var->Type(), pre_in_var->Type(),
|
|
|
|
@ -69,13 +64,17 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
|
auto &in_sr = in_var->Get<framework::SelectedRows>();
|
|
|
|
|
auto in_sr_rows = in_sr.rows();
|
|
|
|
|
out_rows.insert(out_rows.begin(), in_sr_rows.begin(), in_sr_rows.end());
|
|
|
|
|
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(), "");
|
|
|
|
|
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), "");
|
|
|
|
|
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
|
|
|
|
|
"The height of inputs is not consistent.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
|
|
|
|
|
"The dims of inputs is not consistent.");
|
|
|
|
|
} else if (in_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
auto &pre_in = pre_in_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto &in_lodtensor = in_var->Get<framework::LoDTensor>();
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_lodtensor.lod(), pre_in.lod());
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_lodtensor.dims(), pre_in.dims());
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_lodtensor.lod(), pre_in.lod(),
|
|
|
|
|
"The lod of inputs is not consistent.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_lodtensor.dims(), pre_in.dims(),
|
|
|
|
|
"The dims of inputs is not consistent.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
|
|
|
|
|
}
|
|
|
|
@ -88,7 +87,8 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
|
auto &out_place = out_handle->place_;
|
|
|
|
|
auto out_scope_idx = out_handle->scope_idx_;
|
|
|
|
|
auto out_var = local_scopes_[out_scope_idx]->FindVar(out_handle->name_);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_place.which(), pre_place.which(),
|
|
|
|
|
"The place of input and output should be the same.");
|
|
|
|
|
if (pre_in_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
|
|
|
|
|
auto out = out_var->GetMutable<framework::SelectedRows>();
|
|
|
|
@ -110,12 +110,13 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
|
s = e;
|
|
|
|
|
}
|
|
|
|
|
} else if (pre_in_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
// gather LoDTensor ???
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GatherOpHandle::Name() const { return "broadcast"; }
|
|
|
|
|
std::string GatherOpHandle::Name() const { return "gather"; }
|
|
|
|
|
} // namespace details
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|