|
|
|
@ -125,30 +125,6 @@ struct OpHandle {
|
|
|
|
|
virtual void RunImpl() = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
std::unique_ptr<OperatorBase> op_;
|
|
|
|
|
Scope *scope_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
|
|
|
|
|
platform::Place place)
|
|
|
|
|
: op_(framework::OpRegistry::CreateOp(op_desc)),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
place_(place) {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void RunImpl() override {
|
|
|
|
|
auto *cur_ctx = dev_ctx_[place_];
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
|
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
|
|
|
|
|
in->generated_op_->Wait(cur_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_->Run(*scope_, place_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ScaleLossGradOpHandle : public OpHandle {
|
|
|
|
|
float coeff_;
|
|
|
|
|
Scope *scope_;
|
|
|
|
@ -396,6 +372,36 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
std::unique_ptr<OperatorBase> op_;
|
|
|
|
|
Scope *scope_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
|
|
|
|
|
platform::Place place)
|
|
|
|
|
: op_(framework::OpRegistry::CreateOp(op_desc)),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
place_(place) {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void RunImpl() override {
|
|
|
|
|
auto *cur_ctx = dev_ctx_[place_];
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
|
bool need_wait =
|
|
|
|
|
in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx;
|
|
|
|
|
if (dynamic_cast<NCCLAllReduceOpHandle *>(in->generated_op_)) {
|
|
|
|
|
VLOG(3) << "Input is nccl all reduce, need to wait" << need_wait;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (need_wait) {
|
|
|
|
|
in->generated_op_->Wait(cur_ctx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_->Run(*scope_, place_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
ParallelExecutor::ParallelExecutor(
|
|
|
|
|
size_t num_threads, const std::vector<platform::Place> &places,
|
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
|