@@ -719,7 +719,7 @@ PartialGradTask::PartialGradTask(
     auto grad_accumulator_iter = grad_accumulators_.find(mapped_out_grad_var);
     if (grad_accumulator_iter == grad_accumulators_.end()) {
       ready_grad_vars_.Set(mapped_out_grad_var,
-                           std::make_shared<VarBase>(false, out_grad_var));
+                           std::make_shared<VarBase>(out_grad_var));
       VLOG(10) << "Fill 1.0f or user-provided gradient as ready var "
                << out_grad_var->Name();
     } else {
@@ -783,7 +783,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
     if (!input_pair.second.IsGrad()) {
       for (auto &fwd_var : input_pair.second) {
         if (fwd_var) {
-          new_inputs.emplace_back(new VarBase(true, fwd_var));
+          new_inputs.emplace_back(new VarBase(fwd_var));
           VLOG(10) << "Unpacked forward var " << fwd_var->Name()
                    << ", grad ops: " << GradOpTypes(*new_inputs.back());
         } else {
@@ -813,7 +813,7 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
       for (auto &fwd_var : output_pair.second) {
         // unpack forward var
         if (fwd_var) {
-          new_outputs.emplace_back(new VarBase(true, fwd_var));
+          new_outputs.emplace_back(new VarBase(fwd_var));
           VLOG(10) << "Unpacked forward var " << fwd_var->Name();
         } else {
           new_outputs.emplace_back();
@@ -878,44 +878,43 @@ void PartialGradTask::RunEachOp(const OpBase *op) {
     auto partial_grad_grads = accumulator_info->SumGradient(
         std::move(grad_var), op->id(), &is_finished);
 
-    if (!partial_grad_grads.empty()) {
-      auto sum_grad_var_grad =
-          accumulator_info->GradVarBase()->MutableGradVarBase();
-      sum_grad_var_grad->SetOverridedStopGradient(false);
-
-      auto assign_node = std::make_shared<GradOpNode>();
-      sum_grad_var_grad->SetGradNode(assign_node);
-
-      VLOG(10) << "Add " << partial_grad_grads.size() << " assign op for "
-               << sum_grad_var_grad->Name();
-
-      for (auto &grad_grad : partial_grad_grads) {
-        auto *assign_op = &(assign_node->emplace_back());
-        assign_op->SetType("assign");  // Can use "scale" as static graph mode
-        assign_op->SetInput("X", {sum_grad_var_grad->SharedVar()}, true);
-        assign_op->SetOutput("Out", {grad_grad}, true);
-        assign_op->CheckAttrs();
-        assign_op->SetId(OpBase::GenerateUniqueId());
-        assign_op->SetPlace(op->place());
-
-        if (auto grad_pending_node = grad_grad->GetGradNode()) {
-          assign_node->InsertGradPendingNode(std::move(grad_pending_node));
-        }
-      }
-      VLOG(10) << "Pending ops of assign is "
-               << GradPendingOpTypes(*assign_node);
-      double_grad_nodes_.emplace_back(assign_node);
-    }
-
     if (is_finished) {
       VLOG(10) << "Sum has finished for "
                << accumulator_info->MappedGradVar()->Name() << " "
                << accumulator_info->GradVarBase();
       ready_grad_vars_.Set(accumulator_info->MappedGradVar(),
                            accumulator_info->GradVarBase());
-      grad_accumulators_.erase(accumulator_info->MappedGradVar());
     }
-
+    if (partial_grad_grads.empty()) {
+      continue;
+    }
+
+    auto sum_grad_var_grad =
+        accumulator_info->GradVarBase()->MutableGradVarBase();
+    sum_grad_var_grad->SetOverridedStopGradient(false);
+
+    auto assign_node = std::make_shared<GradOpNode>();
+    sum_grad_var_grad->SetGradNode(assign_node);
+
+    VLOG(10) << "Add " << partial_grad_grads.size() << " assign op for "
+             << sum_grad_var_grad->Name();
+
+    for (auto &grad_grad : partial_grad_grads) {
+      auto *assign_op = &(assign_node->emplace_back());
+      assign_op->SetType("assign");  // Can use "scale" as static graph mode
+      assign_op->SetInput("X", {sum_grad_var_grad->SharedVar()}, true);
+      assign_op->SetOutput("Out", {grad_grad}, true);
+      assign_op->CheckAttrs();
+      assign_op->SetId(OpBase::GenerateUniqueId());
+      assign_op->SetPlace(op->place());
+
+      if (auto grad_pending_node = grad_grad->GetGradNode()) {
+        assign_node->InsertGradPendingNode(std::move(grad_pending_node));
+      }
+    }
+    VLOG(10) << "Pending ops of assign is " << GradPendingOpTypes(*assign_node);
+    grad_accumulators_.erase(accumulator_info->MappedGradVar());
+    double_grad_nodes_.emplace_back(assign_node);
   }
 
   grads_to_accumulate_.clear();