@@ -5921,6 +5921,8 @@ def sampled_softmax_with_cross_entropy(logits,
     sampled_logits \
         = helper.create_variable_for_type_inference(dtype=logits.dtype)
     sampled_label = helper.create_variable_for_type_inference(dtype='int64')
+    sampled_softlabel = helper.create_variable_for_type_inference(
+        dtype=logits.dtype)

     helper.append_op(
         type='sample_logits',
@@ -5945,14 +5947,20 @@ def sampled_softmax_with_cross_entropy(logits,
         })
     loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
     softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
+    helper.append_op(
+        type='one_hot',
+        inputs={'X': sampled_label},
+        attrs={'depth': num_samples + 1},
+        outputs={'Out': sampled_softlabel})
+
     helper.append_op(
         type='softmax_with_cross_entropy',
         inputs={'Logits': sampled_logits,
-                'Label': sampled_label},
+                'Label': sampled_softlabel},
         outputs={'Softmax': softmax,
                  'Loss': loss},
         attrs={
-            'soft_label': False,
+            'soft_label': True,
             'ignore_index': False,
             'numeric_stable_mode': False
         })
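The second hunk converts the sampled integer labels to one-hot vectors via a `one_hot` op (depth `num_samples + 1`) and passes them to `softmax_with_cross_entropy` as soft labels. The NumPy sketch below is my own illustration, not code from the patch; it shows why a one-hot soft label with `soft_label=True` produces the same loss value as the original sparse label with `soft_label=False`.

# Minimal NumPy sketch (illustration only, not PaddlePaddle's implementation):
# the soft-label cross-entropy of a one-hot vector reduces to the hard-label
# cross-entropy at the labelled index.
import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
num_samples = 5                       # sampled classes per row, as in the diff
sampled_logits = rng.normal(size=(3, num_samples + 1))
sampled_label = rng.integers(0, num_samples + 1, size=(3,))

# one_hot op analogue: depth matches attrs={'depth': num_samples + 1}
sampled_softlabel = np.eye(num_samples + 1)[sampled_label]

probs = softmax(sampled_logits)
hard_loss = -np.log(probs[np.arange(3), sampled_label])        # soft_label=False
soft_loss = -(sampled_softlabel * np.log(probs)).sum(axis=-1)  # soft_label=True
assert np.allclose(hard_loss, soft_loss)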