|
|
@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
|
|
|
|
: local_scopes_(local_scopes), places_(places) {}
|
|
|
|
: local_scopes_(local_scopes), places_(places) {}
|
|
|
|
|
|
|
|
|
|
|
|
void GatherOpHandle::RunImpl() {
|
|
|
|
void GatherOpHandle::RunImpl() {
|
|
|
|
// the input may have dummy var.
|
|
|
|
// the input and output may have dummy var.
|
|
|
|
std::vector<VarHandle *> in_var_handles;
|
|
|
|
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
|
|
|
|
auto *in_handle = dynamic_cast<VarHandle *>(in);
|
|
|
|
|
|
|
|
if (in_handle) {
|
|
|
|
|
|
|
|
in_var_handles.push_back(in_handle);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
in_var_handles.size(), places_.size(),
|
|
|
|
in_var_handles.size(), places_.size(),
|
|
|
|
"The number of output should equal to the number of places.");
|
|
|
|
"The number of output should equal to the number of places.");
|
|
|
|
|
|
|
|
|
|
|
|
// the output may have dummy var.
|
|
|
|
|
|
|
|
std::vector<VarHandle *> out_var_handles;
|
|
|
|
|
|
|
|
for (auto *out : outputs_) {
|
|
|
|
|
|
|
|
auto *out_handle = dynamic_cast<VarHandle *>(out);
|
|
|
|
|
|
|
|
if (out_handle) {
|
|
|
|
|
|
|
|
out_var_handles.push_back(out_handle);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
|
|
|
|
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
|
|
|
|
"The number of output should be one.");
|
|
|
|
"The number of output should be one.");
|
|
|
|
|
|
|
|
|
|
|
@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
"The place of input and output should be the same.");
|
|
|
|
"The place of input and output should be the same.");
|
|
|
|
|
|
|
|
|
|
|
|
// Wait input done, this Wait is asynchronous operation
|
|
|
|
// Wait input done, this Wait is asynchronous operation
|
|
|
|
for (auto *in : in_var_handles) {
|
|
|
|
WaitEvents(in_var_handles);
|
|
|
|
if (in->generated_op_) {
|
|
|
|
|
|
|
|
in->generated_op_->Wait(dev_ctxes_[in->place_]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_rows;
|
|
|
|
std::vector<int64_t> out_rows;
|
|
|
|
std::vector<Tensor> in_tensors;
|
|
|
|
std::vector<Tensor> in_tensors;
|
|
|
@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
|
|
|
|
|
|
|
|
// copy
|
|
|
|
// copy
|
|
|
|
auto dev_ctx = dev_ctxes_[out_place];
|
|
|
|
auto dev_ctx = dev_ctxes_[out_place];
|
|
|
|
RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] {
|
|
|
|
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
|
|
|
|
int s = 0, e = 0;
|
|
|
|
int s = 0, e = 0;
|
|
|
|
for (size_t j = 0; j < in_tensors.size(); ++j) {
|
|
|
|
for (size_t j = 0; j < in_tensors.size(); ++j) {
|
|
|
|
e += in_tensors[j].dims()[0];
|
|
|
|
e += in_tensors[j].dims()[0];
|
|
|
@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() {
|
|
|
|
});
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void GatherOpHandle::WaitEvents(
|
|
|
|
|
|
|
|
const std::vector<VarHandle *> &in_var_handles) {
|
|
|
|
|
|
|
|
for (auto *in : in_var_handles) {
|
|
|
|
|
|
|
|
if (in->generated_op_) {
|
|
|
|
|
|
|
|
in->generated_op_->Wait(dev_ctxes_[in->place_]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
|
|
|
|
|
|
|
|
const std::vector<VarHandleBase *> &inputs) {
|
|
|
|
|
|
|
|
std::vector<VarHandle *> in_var_handles;
|
|
|
|
|
|
|
|
for (auto *in : inputs) {
|
|
|
|
|
|
|
|
auto *in_handle = dynamic_cast<VarHandle *>(in);
|
|
|
|
|
|
|
|
if (in_handle) {
|
|
|
|
|
|
|
|
in_var_handles.push_back(in_handle);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return in_var_handles;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Returns the fixed name string of this op handle, "gather".
std::string GatherOpHandle::Name() const {
  return "gather";
}
|
|
|
|
// Returns the fixed name string of this op handle, "gather".
std::string GatherOpHandle::Name() const {
  return "gather";
}
|
|
|
|
} // namespace details
|
|
|
|
} // namespace details
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|