|
|
|
@ -29,13 +29,8 @@ Tensor *GetTensorFromVar(Variable *in_var) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const platform::ContextMap &ctxs)
|
|
|
|
|
: local_scopes_(local_scopes), places_(places), ctxs_(ctxs) {
|
|
|
|
|
for (auto &p : places_) {
|
|
|
|
|
this->dev_ctxes_[p] = ctxs_.DevCtx(p);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
const std::vector<platform::Place> &places)
|
|
|
|
|
: local_scopes_(local_scopes), places_(places) {}
|
|
|
|
|
|
|
|
|
|
void BroadcastOpHandle::RunImpl() {
|
|
|
|
|
PADDLE_ENFORCE_EQ(this->inputs_.size(), 1);
|
|
|
|
@ -47,26 +42,18 @@ void BroadcastOpHandle::RunImpl() {
|
|
|
|
|
if (inputs_[0]->generated_op_)
|
|
|
|
|
inputs_[0]->generated_op_->Wait(dev_ctxes_[in_place]);
|
|
|
|
|
|
|
|
|
|
auto iter = std::find(places_.begin(), places_.end(), in_place);
|
|
|
|
|
if (iter == places_.end()) {
|
|
|
|
|
PADDLE_THROW("The input of BCast is not in the places_.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int offset = iter - places_.begin();
|
|
|
|
|
auto in_var = local_scopes_[offset]->FindVar(in_var_handle->name_);
|
|
|
|
|
auto in_scope_idx = in_var_handle->scope_idx_;
|
|
|
|
|
PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(), "");
|
|
|
|
|
auto in_var = local_scopes_[in_scope_idx]->FindVar(in_var_handle->name_);
|
|
|
|
|
|
|
|
|
|
Tensor *in_tensor = GetTensorFromVar(in_var);
|
|
|
|
|
for (auto *out : outputs_) {
|
|
|
|
|
auto out_handle = static_cast<VarHandle *>(out);
|
|
|
|
|
auto &out_p = out_handle->place_;
|
|
|
|
|
|
|
|
|
|
auto iter = std::find(places_.begin(), places_.end(), out_p);
|
|
|
|
|
if (iter == places_.end()) {
|
|
|
|
|
PADDLE_THROW("The output of BCast is not in the places_.");
|
|
|
|
|
}
|
|
|
|
|
int offset = iter - places_.begin();
|
|
|
|
|
|
|
|
|
|
auto *s = local_scopes_[offset];
|
|
|
|
|
auto out_scope_idx = out_handle->scope_idx_;
|
|
|
|
|
PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(), "");
|
|
|
|
|
auto *s = local_scopes_[out_scope_idx];
|
|
|
|
|
auto out_var = s->FindVar(out_handle->name_);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_var->Type(), in_var->Type(), "");
|
|
|
|
|