fix hook operator compare issue

pull/2767/head
kingfo 5 years ago
parent 7b5b4837ff
commit 5d301092f6

@ -106,28 +106,23 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
}
FuncGraphPtr bprop_fg = nullptr;
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
auto iter = bprop_registry_.find(prim);
if (iter != bprop_registry_.end()) {
bprop_fg = iter->second;
}
if (bprop_fg == nullptr) {
bool is_faked_bprop = false;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
bprop_fg = BpropCut(value_node, resources);
} else {
if (bprop_fg == nullptr) {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
if (bprop_fg != nullptr) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
} else {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}
}
// To support primitives with variable params, do not cache faked bprop
if (!is_faked_bprop) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
}
}
auto expanded_fg = BpropToK(prim, bprop_fg);

@ -109,6 +109,9 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__neg__')(self)
return out
def __pos__(self):
return self
def __iadd__(self, other):
return self.__add__(other)

Loading…
Cancel
Save