override comparison operators in Python for Variable

emailweixu-patch-1
qiaolongfei 8 years ago
parent 87b8c62071
commit df7c29e516

@ -151,7 +151,11 @@ def monkey_patch_variable():
("__div__", "elementwise_div", False),
("__rdiv__", "elementwise_div", True),
("__pow__", "elementwise_pow", False),
("__rpow__", "elementwise_pow", True)):
("__rpow__", "elementwise_pow", True),
# for logical compare
("__eq__", "equal", False),
("__lt__", "less_then", False),
("__le__", "less_equal", False), ):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))

@ -179,7 +179,7 @@ def polynomial_decay(learning_rate,
shape=[1], dtype='float32', value=1.0)
with layers.Switch() as switch:
with switch.case(layers.equal(x=global_step, y=zero_var)):
with switch.case(global_step == zero_var):
layers.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values):
shape=[1], dtype='float32', value=float(boundaries[i]))
value_var = layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(layers.less_than(global_step, boundary_val)):
with switch.case(global_step < boundary_val):
layers.assign(value_var, lr)
last_value_var = layers.fill_constant(
shape=[1],

Loading…
Cancel
Save