refine code

revert-15774-anakin_subgraph_engine
xuezhong 6 years ago
parent 4424021623
commit c5360a3f6b

@ -25,63 +25,64 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, default: Tensor<float>), The unscaled log probabilities " "(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, " "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number."); "and K is the class number.");
AddInput("Label", AddInput("Labels",
"(Tensor) The ground truth which is a 2-D tensor. Label is a " "(Tensor) The ground truth which is a 2-D tensor. Labels is a "
"Tensor<int64> with shape [N x NT], where NT is the number of" "Tensor<int64> with shape [N x NT], where NT is the number of"
"true labels for each example."); "true labels for each example.");
AddInput( AddInput("CustomizedSamples",
"CustomSamples", "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shaoe [N x " "NT + S],"
"S+NT]." " where N is the batch size, NT is the number of true labels "
"The customized sample labels with true labels at first. This tensor" "and S is the number of negtive sample for each example."
"is only use_custom_samples is true.") "The first NT elements of each row should be the same with true "
"labels, "
"followed by S custom negtive samples. This tensor"
"is only used when use_customized_samples is true.")
.AsDispensable(); .AsDispensable();
AddInput( AddInput(
"CustomProbabilities", "CustomizedProbabilities",
"(Tensor, default: Tensor<float>), A 2-D tensor with shaoe [N x S+NT]." "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
"The customized sample probabilities with true labels at first. This " "The tensor has the same shape with CustomSamples,"
"tensor is only use_custom_samples is true.") "and each element represents probability of element in CustomSamples. "
"This "
"tensor is only used when use_customized_samples is true.")
.AsDispensable(); .AsDispensable();
AddOutput( AddOutput("Samples",
"Samples", "(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N, "
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N x " "NT + S]."
"S+NT]." "The outputs value of sampler, including NT true lables and S "
"The outputs value of sampler by given the true label, where S is the " "negetive samples "
"number of negative sample for each example. So Samples includes NT " "for each example. This will be used in"
"true" "backward calculation.")
"labels and S negative labels for each example. This will be used in"
"backward calculation.")
.AsIntermediate(); .AsIntermediate();
AddOutput( AddOutput(
"Probabilities", "Probabilities",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x " "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
"S+NT]." "The probabilites of sampled positive and negtive labels.")
"The outputs value of progabilites of samples by given the true label, "
"where S is the "
"number of negative sample for each example. So Samples includes NT "
"true"
"labels and S negative labels for each example.")
.AsIntermediate(); .AsIntermediate();
AddOutput("SampledLogits", AddOutput("SampledLogits",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape" "(Tensor, default: Tensor<float>), A 2-D tensor with shape"
"[N x S+NT]. The outputs value of sample logits, which will be" "[N, NT + S]. The outputs value of sampled logits, which will be"
"used in backward calculation.") "used in backward propagation.")
.AsIntermediate(); .AsIntermediate();
AddOutput( AddOutput(
"SampledLabel", "SampledLabels",
"(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled label" "(Tensor, default: Tensor<int64>), A 2-D tensor. The sampled labels"
"with shape [N x S + NT]."); "with shape [N, NT]. The tonsor contains hard labels as input to "
" softmax op, that is 0, 1, …, NT-1 because of the first NT elements"
" of Sampels are positive lables.");
AddAttr<bool>( AddAttr<bool>(
"use_custom_samples", "use_customized_samples",
"An indicator whether to use custom samples with probabilities, if True" "An indicator whether to use customized samples with probabilities, if "
"the operator will use custom samples and custom probabilities" "True"
"the operator will use customized samples and customized probabilities"
"otherwise, the operator will generate them by itself.") "otherwise, the operator will generate them by itself.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>( AddAttr<bool>(
"uniq", "uniq",
"An indicator whether to sample non-repetitive negtive labels, if True" "An indicator whether to sample non-repetitive negtive labels, if True"
"the operator will sample negtive labels without replacement." "the operator will sample negtive labels without replacement."
"otherwise, the operator will sample negtive labels with replacement.") "Otherwise, the operator will sample negtive labels with replacement.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>( AddAttr<bool>(
"remove_accidental_hits", "remove_accidental_hits",
@ -95,8 +96,7 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
""" """
Computes sampled output training logits and labels suitable for implementing Computes sampled output training logits and labels suitable for implementing
sampled softmax. sampled softmax.
""" """
)DOC"); )DOC");
@ -110,7 +110,8 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"), PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should be not null."); "Input(Logits) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Samples"), PADDLE_ENFORCE(ctx->HasOutput("Samples"),
"Output(Samples) should be not null."); "Output(Samples) should be not null.");
@ -118,11 +119,11 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
"Output(Probabilities) should be not null."); "Output(Probabilities) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"), PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"),
"Output(SampledLogits) should be not null."); "Output(SampledLogits) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("SampledLabel"), PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"),
"Output(SampledLabel) should be not null."); "Output(SampledLabels) should be not null.");
auto logits_dims = ctx->GetInputDim("Logits"); auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label"); auto labels_dims = ctx->GetInputDim("Labels");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
logits_dims.size(), 2UL, logits_dims.size(), 2UL,
@ -135,7 +136,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLabel", {logits_dims[0], labels_dims[1]}); ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]});
} }
protected: protected:
@ -144,7 +145,6 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits")); auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
framework::OpKernelType kt = framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context()); framework::OpKernelType(data_type, ctx.device_context());
// kt.place_ = platform::CPUPlace();
return kt; return kt;
} }
}; };
@ -157,7 +157,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"), PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should not be null."); "Input(Logits) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Samples"), PADDLE_ENFORCE(ctx->HasInput("Samples"),
"Input(Samples) should be not null."); "Input(Samples) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("SampledLogits"), PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
@ -168,7 +169,7 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
"Output(Logits@Grad) should be not null."); "Output(Logits@Grad) should be not null.");
auto logit_dims = ctx->GetInputDim("Logits"); auto logit_dims = ctx->GetInputDim("Logits");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Labels");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The label should be a 2-D tensor."); "The label should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL, PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL,
@ -185,7 +186,6 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
ctx.InputVar(framework::GradVarName("SampledLogits"))); ctx.InputVar(framework::GradVarName("SampledLogits")));
framework::OpKernelType kt = framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context()); framework::OpKernelType(data_type, ctx.device_context());
// kt.place_ = platform::CPUPlace();
return kt; return kt;
} }
}; };
@ -200,7 +200,7 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
auto* grad_op = new framework::OpDesc(); auto* grad_op = new framework::OpDesc();
grad_op->SetType("sample_logits_grad"); grad_op->SetType("sample_logits_grad");
grad_op->SetInput("Logits", Input("Logits")); grad_op->SetInput("Logits", Input("Logits"));
grad_op->SetInput("Label", Input("Label")); grad_op->SetInput("Labels", Input("Labels"));
grad_op->SetInput("Samples", Output("Samples")); grad_op->SetInput("Samples", Output("Samples"));
grad_op->SetInput("SampledLogits", Output("SampledLogits")); grad_op->SetInput("SampledLogits", Output("SampledLogits"));
grad_op->SetInput(framework::GradVarName("SampledLogits"), grad_op->SetInput(framework::GradVarName("SampledLogits"),

@ -109,25 +109,26 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
// get necessary inputs // get necessary inputs
const Tensor* logits = context.Input<Tensor>("Logits"); const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* label = context.Input<Tensor>("Label"); const Tensor* labels = context.Input<Tensor>("Labels");
VLOG(3) << "Enter SampleLogitsCUDAKernel"; VLOG(3) << "Enter SampleLogitsCUDAKernel";
// get necessary outputs // get necessary outputs
Tensor* samples = context.Output<Tensor>("Samples"); Tensor* samples = context.Output<Tensor>("Samples");
Tensor* probabilities = context.Output<Tensor>("Probabilities"); Tensor* probabilities = context.Output<Tensor>("Probabilities");
Tensor* sampled_logits = context.Output<Tensor>("SampledLogits"); Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
Tensor* sampled_label = context.Output<Tensor>("SampledLabel"); Tensor* sampled_labels = context.Output<Tensor>("SampledLabels");
// shapes // shapes
const auto batch_size = logits->dims()[0]; const auto batch_size = logits->dims()[0];
const auto num_classes = logits->dims()[1]; const auto num_classes = logits->dims()[1];
const auto label_dim = label->dims(); const auto labels_dim = labels->dims();
const auto num_true = label_dim[1]; const auto num_true = labels_dim[1];
const auto samples_dim = samples->dims(); const auto samples_dim = samples->dims();
// attrs // attrs
const auto num_samples = context.Attr<int>("num_samples"); const auto num_samples = context.Attr<int>("num_samples");
const bool use_custom_samples = context.Attr<bool>("use_custom_samples"); const bool use_customized_samples =
context.Attr<bool>("use_customized_samples");
const bool uniq = context.Attr<bool>("uniq"); const bool uniq = context.Attr<bool>("uniq");
const bool remove_accidental_hits = const bool remove_accidental_hits =
context.Attr<bool>("remove_accidental_hits"); context.Attr<bool>("remove_accidental_hits");
@ -140,21 +141,22 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<platform::CUDADeviceContext, T> set_zero; math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, sampled_logits, static_cast<T>(0)); set_zero(dev_ctx, sampled_logits, static_cast<T>(0));
auto sampled_label_data = auto sampled_labels_data =
sampled_label->mutable_data<int64_t>(label_dim, context.GetPlace()); sampled_labels->mutable_data<int64_t>(labels_dim, context.GetPlace());
int threads = 512; int threads = 512;
size_t size = batch_size * num_true; size_t size = batch_size * num_true;
int grid = (size + threads - 1) / threads; int grid = (size + threads - 1) / threads;
GPUSetLabel< GPUSetLabel<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>( T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, num_true, sampled_label_data); size, num_true, sampled_labels_data);
if (use_custom_samples) { if (use_customized_samples) {
const Tensor* custom_samples = context.Input<Tensor>("CustomSamples"); const Tensor* customized_samples =
const Tensor* custom_probabilities = context.Input<Tensor>("CustomizedSamples");
context.Input<Tensor>("CustomProbabilities"); const Tensor* customized_probabilities =
samples->ShareDataWith(*custom_samples); context.Input<Tensor>("CustomizedProbabilities");
probabilities->ShareDataWith(*custom_probabilities); samples->ShareDataWith(*customized_samples);
probabilities->ShareDataWith(*customized_probabilities);
} else { } else {
samples->mutable_data<int64_t>(context.GetPlace()); samples->mutable_data<int64_t>(context.GetPlace());
probabilities->mutable_data<T>(samples_dim, context.GetPlace()); probabilities->mutable_data<T>(samples_dim, context.GetPlace());
@ -162,7 +164,7 @@ class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
const auto seed = context.Attr<int>("seed"); const auto seed = context.Attr<int>("seed");
auto sampler_with_prob = math::GPUSampleWithProb<T>(); auto sampler_with_prob = math::GPUSampleWithProb<T>();
sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq, sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq,
num_samples, label, samples, probabilities); num_samples, labels, samples, probabilities);
} }
// UNDERSTAND: gather sampled logits and remove accidental hits if needed // UNDERSTAND: gather sampled logits and remove accidental hits if needed

@ -150,24 +150,25 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
VLOG(3) << "Enter SampleLogitsKernel"; VLOG(3) << "Enter SampleLogitsKernel";
// get necessary inputs // get necessary inputs
const Tensor* logits = context.Input<Tensor>("Logits"); const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* label = context.Input<Tensor>("Label"); const Tensor* labels = context.Input<Tensor>("Labels");
// get necessary outputs // get necessary outputs
Tensor* samples = context.Output<Tensor>("Samples"); Tensor* samples = context.Output<Tensor>("Samples");
Tensor* probabilities = context.Output<Tensor>("Probabilities"); Tensor* probabilities = context.Output<Tensor>("Probabilities");
Tensor* sampled_logits = context.Output<Tensor>("SampledLogits"); Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
Tensor* sampled_label = context.Output<Tensor>("SampledLabel"); Tensor* sampled_labels = context.Output<Tensor>("SampledLabels");
// shapes // shapes
const auto batch_size = logits->dims()[0]; const auto batch_size = logits->dims()[0];
const auto num_classes = logits->dims()[1]; const auto num_classes = logits->dims()[1];
const auto label_dim = label->dims(); const auto labels_dim = labels->dims();
const auto num_true = label_dim[1]; const auto num_true = labels_dim[1];
const auto samples_dim = samples->dims(); const auto samples_dim = samples->dims();
// attrs // attrs
const auto num_samples = context.Attr<int>("num_samples"); const auto num_samples = context.Attr<int>("num_samples");
const bool use_custom_samples = context.Attr<bool>("use_custom_samples"); const bool use_customized_samples =
context.Attr<bool>("use_customized_samples");
const bool remove_accidental_hits = const bool remove_accidental_hits =
context.Attr<bool>("remove_accidental_hits"); context.Attr<bool>("remove_accidental_hits");
@ -177,18 +178,21 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
// UNDERSTAND: allocate memories for temporaries // UNDERSTAND: allocate memories for temporaries
sampled_logits->mutable_data<T>(samples_dim, context.GetPlace()); sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
auto sampled_label_data = auto sampled_labels_data =
sampled_label->mutable_data<int64_t>(label_dim, context.GetPlace()); sampled_labels->mutable_data<int64_t>(labels_dim, context.GetPlace());
for (int i = 0; i < batch_size; ++i) for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_true; ++j) for (int j = 0; j < num_true; ++j) {
sampled_label_data[i * num_true + j] = j; sampled_labels_data[i * num_true + j] = j;
}
if (use_custom_samples) { }
const Tensor* custom_samples = context.Input<Tensor>("CustomSamples");
const Tensor* custom_probabilities = if (use_customized_samples) {
context.Input<Tensor>("CustomProbabilities"); const Tensor* customized_samples =
samples->ShareDataWith(*custom_samples); context.Input<Tensor>("CustomizedSamples");
probabilities->ShareDataWith(*custom_probabilities); const Tensor* customized_probabilities =
context.Input<Tensor>("CustomizedProbabilities");
samples->ShareDataWith(*customized_samples);
probabilities->ShareDataWith(*customized_probabilities);
} else { } else {
samples->mutable_data<int64_t>(context.GetPlace()); samples->mutable_data<int64_t>(context.GetPlace());
probabilities->mutable_data<T>(samples_dim, context.GetPlace()); probabilities->mutable_data<T>(samples_dim, context.GetPlace());
@ -197,7 +201,7 @@ class SampleLogitsKernel : public framework::OpKernel<T> {
auto sampler_with_prob = auto sampler_with_prob =
math::SampleWithProb<platform::CPUDeviceContext, T>(); math::SampleWithProb<platform::CPUDeviceContext, T>();
sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed), sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed),
num_samples, label, samples, probabilities); num_samples, labels, samples, probabilities);
} }
// UNDERSTAND: gather sampled logits and remove accidental hits if needed // UNDERSTAND: gather sampled logits and remove accidental hits if needed

@ -5771,9 +5771,9 @@ def sampled_softmax_with_cross_entropy(logits,
num_samples, num_samples,
num_true=1, num_true=1,
remove_accidental_hits=True, remove_accidental_hits=True,
use_custom_samples=False, use_customized_samples=False,
custom_samples=None, customized_samples=None,
custom_probabilities=None, customized_probabilities=None,
seed=0): seed=0):
""" """
**Sampled Softmax With Cross Entropy Operator.** **Sampled Softmax With Cross Entropy Operator.**
@ -5789,7 +5789,7 @@ def sampled_softmax_with_cross_entropy(logits,
For examples with T true labels (T >= 1), we assume that each true label has For examples with T true labels (T >= 1), we assume that each true label has
a probability of 1/T. For each sample, S samples are generated using a a probability of 1/T. For each sample, S samples are generated using a
log uniform distribution. True labels are concatenated with hese samples to log uniform distribution. True labels are concatenated with these samples to
form T + S samples for each example. So, assume the shape of logits is form T + S samples for each example. So, assume the shape of logits is
[N x K], the shape for samples is [N x (T+S)]. For each sampled label, a [N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
probability is calculated, which corresponds to the Q(y|x) in probability is calculated, which corresponds to the Q(y|x) in
@ -5798,7 +5798,7 @@ def sampled_softmax_with_cross_entropy(logits,
Logits are sampled according to the sampled labels. Then if Logits are sampled according to the sampled labels. Then if
remove_accidental_hits is True, if a sample[i, j] accidentally hits true remove_accidental_hits is True, if a sample[i, j] accidentally hits true
labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to
make its softmax result close to zero. Then samled logits are subtracted by make its softmax result close to zero. Then sampled logits are subtracted by
logQ(y|x), these sampled logits and re-indexed labels are used to compute logQ(y|x), these sampled logits and re-indexed labels are used to compute
a softmax with cross entropy. a softmax with cross entropy.
@ -5816,14 +5816,16 @@ def sampled_softmax_with_cross_entropy(logits,
accidentally hits true labels, then the corresponding accidentally hits true labels, then the corresponding
sampled_logits[i, j] is minus by 1e20 to make its softmax result sampled_logits[i, j] is minus by 1e20 to make its softmax result
close to zero. Default is True. close to zero. Default is True.
use_custom_samples (bool): Whether to use custom samples and probabities to sample use_customized_samples (bool): Whether to use custom samples and probabities to sample
logits. logits.
custom_samples (Variable): User defined samples, which is a 1-D tensor with shape [S]. S is the num_samples. customized_samples (Variable): User defined samples, which is a 2-D tensor
custom_probabilities (Variable): User defined probabilities of samples, a 1-D tensor which has the same shape with custom_samples. with shape [N, T + S]. S is the num_samples, and T is the number of true
labels per example.
customized_probabilities (Variable): User defined probabilities of samples,
a 2-D tensor which has the same shape with customized_samples.
seed (int): The random seed for generating random number, which is used seed (int): The random seed for generating random number, which is used
in the process of sampling. Default is 0. in the process of sampling. Default is 0.
Returns: Returns:
Variable: Return the cross entropy loss which is a 2-D tensor with shape Variable: Return the cross entropy loss which is a 2-D tensor with shape
[N x 1]. [N x 1].
@ -5849,18 +5851,18 @@ def sampled_softmax_with_cross_entropy(logits,
type='sample_logits', type='sample_logits',
inputs={ inputs={
'Logits': logits, 'Logits': logits,
'Label': label, 'Labels': label,
'CustomSamples': custom_samples, 'CustomSamples': custom_samples,
'CustomProbabilities': custom_probabilities 'CustomProbabilities': custom_probabilities
}, },
outputs={ outputs={
'Samples': samples, 'Samples': samples,
'Probabilities': probabilities, 'Probabilities': probabilities,
'SampledLabel': sampled_label, 'SampledLabels': sampled_label,
'SampledLogits': sampled_logits 'SampledLogits': sampled_logits
}, },
attrs={ attrs={
'use_custom_samples': use_custom_samples, 'use_customized_samples': use_customized_samples,
'uniq': True, 'uniq': True,
'remove_accidental_hits': remove_accidental_hits, 'remove_accidental_hits': remove_accidental_hits,
'num_samples': num_samples, 'num_samples': num_samples,

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save