Add Center Loss Op Support (#18681)
* support center loss
* change the tensor copy API to the high-level API TensorCopy
* test=develop rewrite the center_loss CUDA kernel to make it faster; add documentation for the center loss API; also update the test function
* test=document_preview test=develop update the documentation of center loss
* test=document_preview test=develop modify API.spec; modify the test code; remove an unused const_cast (padding_in_crf)
parent d21c391447
commit 24f8543106
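
Per the commit message, this change also exposes a Python API for center loss. A minimal usage sketch, assuming the layer is published as fluid.layers.center_loss with arguments mirroring the operator's inputs and attributes below (the exact signature may differ):

import paddle.fluid as fluid

# deep features from the last hidden layer, plus one class id per sample
feat = fluid.layers.data(name='feat', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

loss = fluid.layers.center_loss(
    input=feat,           # maps to Input(X), the deep feature
    label=label,          # maps to Input(Label)
    num_classes=1000,     # maps to Attr(cluster_num)
    alpha=0.1,            # center update rate, maps to Input(CenterUpdateRate)
    param_attr=fluid.ParamAttr(
        initializer=fluid.initializer.Xavier(uniform=False)),  # Centers init
    update_center=True)   # maps to Attr(need_update)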
paddle/fluid/operators/center_loss_op.cc
@@ -0,0 +1,157 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/center_loss_op.h"
#include <memory>
#include <string>

namespace paddle {
namespace operators {

class CenterLossOp : public framework::OperatorWithKernel {
 public:
  CenterLossOp(const std::string &type,
               const framework::VariableNameMap &inputs,
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of CenterLoss should not be null.");
    auto x_dims = ctx->GetInputDim("X");

    PADDLE_ENFORCE(ctx->HasInput("CenterUpdateRate"),
                   "Input(CenterUpdateRate) of CenterLoss should not be null.");

    PADDLE_ENFORCE(ctx->HasInput("Label"),
                   "Input(Label) of CenterLoss should not be null.");

    PADDLE_ENFORCE(ctx->HasInput("Centers"),
                   "Input(Centers) of CenterLoss should not be null.");

    PADDLE_ENFORCE(
        ctx->HasOutput("SampleCenterDiff"),
        "Output(SampleCenterDiff) of CenterLoss should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of CenterLoss should not be null.");

    PADDLE_ENFORCE(
        ctx->HasOutput("CentersOut"),
        "Output(CentersOut) of CenterLoss should not be null; it shares "
        "data with Centers.");

    ctx->SetOutputDim("SampleCenterDiff",
                      {x_dims[0], product(x_dims) / x_dims[0]});
    ctx->SetOutputDim("CentersOut", ctx->GetInputDim("Centers"));
    ctx->SetOutputDim("Loss", {x_dims[0], 1});
    ctx->ShareLoD("X", /*->*/ "Loss");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context());
  }
};

class CenterLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of center_loss operator.");
    AddInput("Label", "(Tensor) Input tensor of center_loss operator.");
    AddInput("Centers", "(Tensor) Input tensor of center_loss operator.");
    AddInput("CenterUpdateRate",
             "(Tensor) Input tensor of center_loss operator.");

    AddOutput("CentersOut", "(Tensor) Output tensor of center_loss operator.");
    AddOutput("SampleCenterDiff",
              "(Tensor) Output tensor of center_loss operator.");
    AddOutput("Loss", "(Tensor) Output tensor of center_loss operator.");

    AddAttr<int>("cluster_num",
                 "The number of cluster centers of the center_loss operator.");
    AddAttr<bool>("need_update", "Whether to update the center info.");
    AddComment(R"DOC(
**CenterLoss operator**

Implementation of the center loss function from the paper "A Discriminative
Feature Learning Approach for Deep Face Recognition". The equation used in
this implementation is loss = 1/2 * (x - y)^2, where x (X) is the deep
feature (the output of the last hidden layer) and y is the class center
selected by the target label (Label).
)DOC");
  }
};

class CenterLossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("SampleCenterDiff"),
                   "Input(SampleCenterDiff) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@GRAD) should not be null");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null");

    auto x_dims = ctx->GetInputDim("X");
    auto x_grad_name = framework::GradVarName("X");

    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        ctx.Input<Tensor>("SampleCenterDiff")->type(), ctx.device_context());
  }
};

class CenterLossOpGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
    retv->SetType("center_loss_grad");
    retv->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
    retv->SetInput("SampleCenterDiff", Output("SampleCenterDiff"));
    retv->SetInput("X", Input("X"));
    retv->SetOutput(framework::GradVarName("X"), InputGrad("X"));

    retv->SetAttrMap(Attrs());
    return retv;
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
                  ops::CenterLossOpGradMaker);

REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp);

REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>,
                       ops::CenterLossKernel<CPUCtx, double>);

REGISTER_OP_CPU_KERNEL(center_loss_grad,
                       ops::CenterLossGradKernel<CPUCtx, float>,
                       ops::CenterLossGradKernel<CPUCtx, double>);
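
For reference, the forward loss and the center update this operator implements can be stated compactly (a restatement of the kernels in LaTeX; $c_{y_i}$ is the center row selected by label $y_i$, $\alpha$ is CenterUpdateRate, and $n_j$ is the number of batch samples with label $j$):

L_i = \frac{1}{2}\,\lVert x_i - c_{y_i} \rVert_2^2,
\qquad
c_j \leftarrow c_j + \frac{\alpha \sum_{i:\,y_i=j} (x_i - c_j)}{1 + n_j}
\quad \text{(only when } \mathit{need\_update} \text{ is set)}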
paddle/fluid/operators/center_loss_op.cu
@@ -0,0 +1,147 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/center_loss_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void ComputeDifferent(T *centers_diff, const T *X, const T *centers,
                                 const int64_t *ids, const int64_t N,
                                 const int64_t K, const int64_t D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
    PADDLE_ASSERT_MSG(id < N, "received id:", id);
    T *out = centers_diff + idy * D;
    const T *x = X + idy * D;
    const T *cent = centers + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      out[i] = x[i] - cent[i];
    }
    idy += BlockDimY * GridDimX;
  }
}

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void UpdateCenters(T *centers, T *centers_diff, const int64_t *ids,
                              const int64_t N, const int64_t K, const int64_t D,
                              const T *alpha) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * GridDimX;
  while (idy < K) {
    int count = 1;
    int64_t id = ids[idy];
    PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
    PADDLE_ASSERT_MSG(id < N, "received id:", id);

    for (int i = 0; i < K; i++) {
      if (ids[i] == id) {
        count++;
      }
    }
    const T *diff = centers_diff + idy * D;
    T *cent = centers + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&cent[i], alpha[0] * diff[i] / count);
    }
    idy += BlockDimY * GridDimX;
  }
}

template <typename DeviceContext, typename T>
class CenterLossCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &device_context = ctx.template device_context<DeviceContext>();
    auto stream = device_context.stream();
    auto *X = ctx.Input<Tensor>("X");  // deep feature
    auto *labels = ctx.Input<Tensor>("Label");
    auto *centers = ctx.Input<Tensor>("Centers");
    auto *update_rate = ctx.Input<Tensor>("CenterUpdateRate");
    auto *lr_center = update_rate->data<T>();
    bool need_update = ctx.Attr<bool>("need_update");

    auto x_data = X->data<T>();
    auto label_data = labels->data<int64_t>();

    auto *centers_diff = ctx.Output<Tensor>("SampleCenterDiff");
    auto centers_diff_data = centers_diff->mutable_data<T>(ctx.GetPlace());

    auto centers_data = centers->data<T>();
    auto *out_loss = ctx.Output<Tensor>("Loss");
    auto loss_data = out_loss->mutable_data<T>(ctx.GetPlace());

    auto *centers_out = ctx.Output<Tensor>("CentersOut");
    auto *centers_out_data = centers_out->mutable_data<T>(ctx.GetPlace());

    auto ctx_place = ctx.GetPlace();
    if (centers != centers_out) {
      framework::TensorCopy(
          *static_cast<const framework::Tensor *>(centers), ctx_place,
          *platform::DeviceContextPool::Instance().Get(ctx_place),
          static_cast<framework::Tensor *>(centers_out));
    }

    size_t N = centers->dims()[0];
    size_t D = centers->dims()[1];
    size_t K = labels->numel();

    dim3 threads(128, 8);
    dim3 grids(8, 1);

    ComputeDifferent<T, 128, 8, 8><<<grids, threads, 0, stream>>>(
        centers_diff_data, x_data, centers_data, label_data, N, K, D);

    auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto sub_result = EigenMatrix<T>::From(*centers_diff);

    auto sub_res_pow2 = (sub_result * sub_result) / T(2.0);
    auto z = EigenVector<T>::Flatten(*out_loss);
    z.device(place) = sub_res_pow2.sum(Eigen::array<int, 1>({{1}}));
    if (need_update) {
      UpdateCenters<T, 128, 8, 8><<<grids, threads, 0, stream>>>(
          centers_out_data, centers_diff_data, label_data, N, K, D, lr_center);
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using GPUCtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(center_loss, ops::CenterLossCUDAKernel<GPUCtx, float>,
                        ops::CenterLossCUDAKernel<GPUCtx, double>);

REGISTER_OP_CUDA_KERNEL(center_loss_grad,
                        ops::CenterLossGradKernel<GPUCtx, float>,
                        ops::CenterLossGradKernel<GPUCtx, double>);
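
As a readability aid for the two kernels above, the same computation in NumPy (a hypothetical reference helper, not part of the PR; note the 1 + n_j denominator matches UpdateCenters initializing its count to 1):

import numpy as np

def center_loss_forward_ref(x, labels, centers, alpha, need_update):
    # ComputeDifferent: diff[i] = x[i] - centers[labels[i]]
    diff = x - centers[labels]
    # Eigen reduction: per-sample loss = 0.5 * ||diff||^2
    loss = 0.5 * np.square(diff).sum(axis=1, keepdims=True)
    centers_out = centers.copy()
    if need_update:
        # UpdateCenters: move each used center by alpha * sum(diff) / (1 + n_j)
        for j in np.unique(labels):
            mask = labels == j
            centers_out[j] += alpha * diff[mask].sum(axis=0) / (1 + mask.sum())
    return diff, loss, centers_out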
paddle/fluid/operators/center_loss_op.h
@@ -0,0 +1,155 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cstring>
#include <limits>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
struct SubFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};

template <typename DeviceContext, typename T>
class CenterLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *X = ctx.Input<Tensor>("X");  // deep feature
    auto *labels = ctx.Input<Tensor>("Label");
    auto *centers = ctx.Input<Tensor>("Centers");
    auto *update_rate = ctx.Input<Tensor>("CenterUpdateRate");
    int cluster_num = ctx.Attr<int>("cluster_num");
    auto *lr_center = update_rate->data<T>();
    T alpha = lr_center[0];
    bool need_update = ctx.Attr<bool>("need_update");

    auto x_data = X->data<T>();
    auto label_data = labels->data<int64_t>();

    auto centers_dim = centers->dims();
    auto centers_data = centers->data<T>();

    auto x_dims = X->dims();
    int batch_size = x_dims[0];
    int deep_feat_dim = x_dims[1];

    auto centers_diff = ctx.Output<Tensor>("SampleCenterDiff");
    auto centers_diff_data = centers_diff->mutable_data<T>(ctx.GetPlace());
    auto *out_loss = ctx.Output<Tensor>("Loss");

    auto *centers_out = ctx.Output<Tensor>("CentersOut");
    auto *centers_out_data = centers_out->mutable_data<T>(ctx.GetPlace());

    if (centers_out_data != centers_data) {
      int size = centers_out->numel() * sizeof(T);
      memcpy(centers_out_data, centers_data, size);
    }

    std::vector<int> center_update_count(cluster_num, 1);
    auto &dev_ctx = ctx.template device_context<DeviceContext>();

    auto loss_data = out_loss->mutable_data<T>(ctx.GetPlace());

    Tensor centers_diffacc;  // used to accumulate all diff
    auto centers_diffacc_data =
        centers_diffacc.mutable_data<T>(centers_dim, ctx.GetPlace());
    int numel = centers_diffacc.numel();
    std::memset(centers_diffacc_data, 0, sizeof(T) * numel);

    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    int tLabel;

    const T *x_index;
    const T *center_index;
    T *center_out_index;
    T *center_loss_diff_index;
    T *acc_index;
    platform::Transform<DeviceContext> trans;

    for (int i = 0; i < batch_size; ++i) {
      tLabel = label_data[i];
      center_update_count[tLabel]++;
      x_index = x_data + i * deep_feat_dim;                  // xi index
      center_index = centers_data + tLabel * deep_feat_dim;  // center index
      center_loss_diff_index = centers_diff_data + i * deep_feat_dim;
      trans(dev_ctx, x_index, x_index + deep_feat_dim, center_index,
            center_loss_diff_index, SubFunctor<T>());

      acc_index = centers_diffacc_data + tLabel * deep_feat_dim;
      blas.VADD(deep_feat_dim, center_loss_diff_index, acc_index,
                acc_index);  // accumulate
      loss_data[i] = blas.DOT(deep_feat_dim, center_loss_diff_index,
                              center_loss_diff_index) /
                     T(2.0);
    }

    // update centers data
    if (need_update) {
      for (int i = 0; i < cluster_num; i++) {
        acc_index = centers_diffacc_data + i * deep_feat_dim;
        center_out_index = centers_out_data + i * deep_feat_dim;
        T scale = alpha / center_update_count[i];
        blas.SCAL(deep_feat_dim, scale, acc_index);
        blas.VADD(deep_feat_dim, acc_index, center_out_index, center_out_index);
      }
    }
  }
};

template <typename DeviceContext, typename T>
class CenterLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in0 = context.Input<Tensor>("SampleCenterDiff");
    auto *in1 = context.Input<Tensor>(framework::GradVarName("Loss"));
    auto *x_g = context.Output<Tensor>(framework::GradVarName("X"));
    auto sub_result = EigenMatrix<T>::From(*in0);
    auto out_grad = EigenMatrix<T>::From(*in1);

    auto x_dims = x_g->dims();
    int cols = x_g->numel() / x_dims[0];
    // calculate gradient
    auto grad_mat =
        (out_grad.broadcast(Eigen::array<int, 2>({{1, cols}}))) * sub_result;

    // propagate back to input
    auto &eigen_place =
        *context.template device_context<DeviceContext>().eigen_device();
    x_g->mutable_data<T>(context.GetPlace());
    // eigen matrix
    auto x_grad =
        EigenMatrix<T>::From(*x_g, framework::make_ddim({x_dims[0], cols}));
    x_grad.device(eigen_place) = grad_mat;
  }
};

}  // namespace operators
}  // namespace paddle
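
The gradient kernel reduces to a broadcasted product; equivalently, in NumPy (hypothetical reference helper):

import numpy as np

def center_loss_backward_ref(sample_center_diff, loss_grad):
    # x_grad[i] = dLoss[i] * (x[i] - c_{y_i}): the derivative of
    # 0.5 * ||x - c||^2 w.r.t. x, chained with the upstream gradient.
    return loss_grad.reshape(-1, 1) * sample_center_diff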
python/paddle/fluid/tests/unittests/test_center_loss.py
@@ -0,0 +1,95 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core


class TestCenterLossOp(OpTest):
    def setUp(self):
        self.op_type = "center_loss"
        self.dtype = np.float32
        self.init_dtype_type()
        batch_size = 6
        feet_dim = 10
        cluster_num = 8
        self.attrs = {}
        self.attrs['cluster_num'] = cluster_num
        self.attrs['lambda'] = 0.1
        self.config()
        self.attrs['need_update'] = self.need_update
        labels = np.random.randint(cluster_num, size=batch_size, dtype='int64')
        feat = np.random.random((batch_size, feet_dim)).astype(np.float32)
        centers = np.random.random((cluster_num, feet_dim)).astype(np.float32)
        var_sum = np.zeros((cluster_num, feet_dim), dtype=np.float32)
        centers_select = centers[labels]
        output = feat - centers_select
        diff_square = np.square(output).reshape(batch_size, feet_dim)
        loss = 0.5 * np.sum(diff_square, axis=1).reshape(batch_size, 1)
        cout = [0] * cluster_num
        for i in range(batch_size):
            cout[labels[i]] += 1
            var_sum[labels[i]] += output[i]
        for i in range(cluster_num):
            var_sum[i] /= (1 + cout[i])
        var_sum *= 0.1
        result = centers + var_sum
        rate = np.array([0.1]).astype(np.float32)

        self.inputs = {
            'X': feat,
            'Label': labels,
            'Centers': centers,
            'CenterUpdateRate': rate
        }

        if self.need_update:
            self.outputs = {
                'SampleCenterDiff': output,
                'Loss': loss,
                'CentersOut': result
            }
        else:
            self.outputs = {
                'SampleCenterDiff': output,
                'Loss': loss,
                'CentersOut': centers
            }

    def config(self):
        self.need_update = True

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Loss')


class TestCenterLossOpNoUpdate(TestCenterLossOp):
    def config(self):
        self.need_update = False


if __name__ == "__main__":
    unittest.main()