fix bug of issue #21259 (#21287)

pass the argument `allow_out_of_range` of one_hot op to c++ back end.
revert-21172-masked_select_api
Yi Liu 7 years ago committed by GitHub
parent 319d2ba925
commit 0fd1281ef8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5244,16 +5244,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode():
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
if not isinstance(depth, Variable):
# user attribute
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {}
attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op(
type="one_hot",
inputs=inputs,

Loading…
Cancel
Save