Merge pull request #15893 from xuezhong/add_sample_logits_op

fix bug for sampled softmax
revert-15774-anakin_subgraph_engine
xuezhong 6 years ago committed by GitHub
commit f857e07987
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5921,6 +5921,8 @@ def sampled_softmax_with_cross_entropy(logits,
sampled_logits \
= helper.create_variable_for_type_inference(dtype=logits.dtype)
sampled_label = helper.create_variable_for_type_inference(dtype='int64')
sampled_softlabel = helper.create_variable_for_type_inference(
dtype=logits.dtype)
helper.append_op(
type='sample_logits',
@ -5945,14 +5947,20 @@ def sampled_softmax_with_cross_entropy(logits,
})
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='one_hot',
inputs={'X': sampled_label},
attrs={'depth': num_samples + 1},
outputs={'Out': sampled_softlabel})
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': sampled_logits,
'Label': sampled_label},
'Label': sampled_softlabel},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={
'soft_label': False,
'soft_label': True,
'ignore_index': False,
'numeric_stable_mode': False
})

Loading…
Cancel
Save