Complete rename

trainerSaveLoadParams
Yu Yang 7 years ago
parent 651157f6a1
commit c0ac0cd6b3

@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, bool skip_scale_loss)
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes) {
@ -53,7 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
skip_scale_loss_ = skip_scale_loss;
use_default_grad_scale_ = use_default_grad_scale;
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@ -126,8 +126,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if skip_scale_loss_
if (!skip_scale_loss_) {
// user can customize loss@grad if not use_default_grad_scale_
if (use_default_grad_scale_) {
CreateScaleLossGradOp(&result);
}
is_forwarding = false;

@ -41,7 +41,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
bool skip_scale_loss);
bool use_default_grad_scale);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@ -59,7 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
#endif
bool skip_scale_loss_;
bool use_default_grad_scale_;
bool IsScaleLossOp(const OpDesc &op) const;

Loading…
Cancel
Save