fix Tracer::NoGrad, test=develop (#23443)

revert-23830-2.0-beta
Zeng Jinle 5 years ago committed by GitHub
parent ebae6fb6b9
commit 0c23e3ff4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -70,7 +70,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
framework::AttributeMap attrs) {
TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_);
TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_);
}
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,

@ -86,9 +86,9 @@ class Tracer {
void SetExpectedPlace(platform::Place place) { expected_place_ = place; }
bool NoGrad() const { return no_grad_; }
bool HasGrad() const { return has_grad_; }
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
private:
std::unique_ptr<BasicEngine> basic_engine_;
@ -96,7 +96,7 @@ class Tracer {
bool enable_program_desc_tracing_{false};
std::unique_ptr<UniqueNameGenerator> generator_;
platform::Place expected_place_;
bool no_grad_{false};
bool has_grad_{true};
};
// To access static variable current_tracer

@ -695,8 +695,8 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_train_mode", &imperative::Tracer::NoGrad,
&imperative::Tracer::SetNoGrad)
.def_property("_train_mode", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetHasGrad)
.def_property(
"_expected_place",
[](const imperative::Tracer &self) -> py::object {

Loading…
Cancel
Save