Add match_matrix_tensor op (#18525)
* add match_matrix_tensor op  test=develop
* fix: ignore unittest if with_mkl=off  test=develop
* clean code and rm is_test param  test=develop
* modify API.spec  test=develop
* rm useless code in search_compute.h  test=develop
* modify api.spec  test=develop
* modify default_grad.spec  test=develop
* Add API test code  test=develop
* clean code in search_compute.h
* modify PADDLE_ENFORCE and clean search_compute.h  test=develop
* fix code style  test=develop
parent 5b6673c44d
commit 78a3d837f8
File diff suppressed because it is too large
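For orientation before the diffs: the Python API test added at the end of this commit drives the new layer roughly as in the sketch below. This is a minimal sketch distilled from that test; the layer name fluid.layers.match_matrix_tensor, the channel_num argument, and the LoD feed format are all taken from the test code, not from any additional documentation.

import numpy as np
import paddle.fluid as fluid

# Two LoD (variable-length sequence) inputs with feature width 10.
x = fluid.layers.data(name='x', shape=[10], lod_level=1)
y = fluid.layers.data(name='y', shape=[10], lod_level=1)
# For every sequence pair the op produces channel_num similarity maps:
# out[t, i, j] = x_i^T * W_t * y_j; tmp holds the intermediate x * W product.
out, tmp = fluid.layers.match_matrix_tensor(x=x, y=y, channel_num=3)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

x_t = fluid.create_lod_tensor(
    np.random.rand(7, 10).astype('float32'), [[2, 5]], place)
y_t = fluid.create_lod_tensor(
    np.random.rand(9, 10).astype('float32'), [[3, 6]], place)
res = exe.run(feed={'x': x_t, 'y': y_t}, fetch_list=[out], return_numpy=False)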
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MatchMatrixTensorOP : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
};

class MatchMatrixTensorOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
};

class MatchMatrixTensorOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,138 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <immintrin.h>
#include <cfloat>
#include <cmath>
#include <cstring>

#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/mklml.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

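// call_gemm / call_gemm_with_lda / call_gemm_batched are thin wrappers around
// Blas::GEMM for row-major matrices: C (M x N) = alpha * op(A) * op(B) + beta * C,
// where op() applies the CBLAS transpose flags and the leading dimensions are
// derived from those flags (or passed in explicitly as lda).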
template <typename DeviceContext, typename T>
void call_gemm(const math::BlasT<DeviceContext, T>& blas,
               const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
               const int M, const int N, const int K, const T alpha, const T* A,
               const T* B, const T beta, T* C) {
  int lda = (TransA == CblasNoTrans) ? K : M;
  int ldb = (TransB == CblasNoTrans) ? N : K;
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
}

template <typename T>
void call_gemm(const framework::ExecutionContext& ctx,
               const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
               const int M, const int N, const int K, const T alpha, const T* A,
               const T* B, const T beta, T* C) {
  int lda = (TransA == CblasNoTrans) ? K : M;
  int ldb = (TransB == CblasNoTrans) ? N : K;
  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
}

template <typename DeviceContext, typename T>
void call_gemm_with_lda(const math::BlasT<DeviceContext, T>& blas,
                        const CBLAS_TRANSPOSE TransA,
                        const CBLAS_TRANSPOSE TransB, const int M, const int N,
                        const int K, const T alpha, const T* A, const T* B,
                        const T beta, T* C, int lda) {
  int ldb = (TransB == CblasNoTrans) ? N : K;

  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
}

template <typename T>
void call_gemm_batched(const framework::ExecutionContext& ctx,
                       const CBLAS_TRANSPOSE TransA,
                       const CBLAS_TRANSPOSE TransB, const int M, const int N,
                       const int K, const T alpha, const T** A, const T** B,
                       const T beta, T** C, const int batch) {
  for (int i = 0; i < batch; ++i) {
    call_gemm(ctx, TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i]);
  }
}

#ifndef TYPE_USE_FLOAT
#define TYPE_USE_FLOAT
#endif
#ifndef USE_SSE
#define USE_SSE
#endif

#if defined(TYPE_USE_FLOAT)

#define __m256x __m256
#define __m128x __m128

static const unsigned int AVX_STEP_SIZE = 8;
static const unsigned int SSE_STEP_SIZE = 4;
static const unsigned int AVX_CUT_LEN_MASK = 7U;
static const unsigned int SSE_CUT_LEN_MASK = 3U;

#define _mm256_mul_px _mm256_mul_ps
#define _mm256_add_px _mm256_add_ps
#define _mm256_load_px _mm256_loadu_ps
#define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss

#define _mm_add_px _mm_add_ps
#define _mm_mul_px _mm_mul_ps
#define _mm_load_px _mm_loadu_ps
#define _mm_store_px _mm_storeu_ps
#define _mm_load1_px _mm_load1_ps

#endif

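// sse_axpy computes y[i] += alpha * x[i] for i in [0, len). The bulk of the
// range is processed with AVX (8 floats per step) or SSE (4 floats per step)
// intrinsics, depending on which macro is defined; the scalar tail loop
// handles the remainder.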
template <typename T>
inline void sse_axpy(const T* x, T* y, size_t len, const T alpha) {
  unsigned int jjj, lll;
  jjj = lll = 0;

#if defined(USE_AVX)
  lll = len & ~AVX_CUT_LEN_MASK;
  __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
    _mm256_store_px(
        y + jjj,
        _mm256_add_px(_mm256_load_px(y + jjj),
                      _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))));
  }

#elif defined(USE_SSE)
  lll = len & ~SSE_CUT_LEN_MASK;
  __m128x mm_alpha = _mm_load1_px(&alpha);
  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
    _mm_store_px(y + jjj,
                 _mm_add_px(_mm_load_px(y + jjj),
                            _mm_mul_px(mm_alpha, _mm_load_px(x + jjj))));
  }

#endif
  for (; jjj < len; jjj++) {
    y[jjj] += alpha * x[jjj];
  }
}

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,132 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid


class TestMatchMatrixTensorOp(OpTest):
    def setUp(self):
        self.init_op_type()
        self.set_data()
        self.compute()

    def init_op_type(self):
        self.op_type = "match_matrix_tensor"

    def set_data(self):
        ix, iy, h, dim_t = [5, 8, 3, 4]
        x_lod = [[1, 2, 2]]
        y_lod = [[3, 1, 4]]
        self.init_data(ix, x_lod, iy, y_lod, h, dim_t)

    def init_data(self, ix, x_lod, iy, y_lod, h, dim_t):
        x_data = np.random.random((ix, h)).astype('float32')
        y_data = np.random.random((iy, h)).astype('float32')
        w_data = np.random.random((h, dim_t, h)).astype('float32')
        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod), 'W': w_data}
        self.attrs = {'dim_t': dim_t}

    def compute(self):
        x_data, x_lod = self.inputs['X']
        y_data, y_lod = self.inputs['Y']
        # [k, dim_t, k] -> [dim_t, k, k]
        w_data = self.inputs['W'].transpose(1, 0, 2)
        out = np.zeros((0, 1), dtype=x_data.dtype)
        # for x*w
        tmp = np.zeros((0, 1), dtype=x_data.dtype)
        out_lod = [[]]
        tmp_lod = [[]]

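        # Reference computation, done per pair of LoD segments (x_sub, y_sub):
        #   tmp_sub[i, t, :] = x_sub[i, :] . W_t                  shape (x_len, dim_t, h)
        #   out_sub[t, i, j] = x_sub[i, :] . W_t . y_sub[j, :]^T  shape (dim_t, x_len, y_len)
        # Both are flattened to column vectors and stacked across segments;
        # out_lod records the flattened size of each segment's output.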
        x_offset, y_offset = 0, 0
        for idx in range(len(x_lod[0])):
            x_len = x_lod[0][idx]
            y_len = y_lod[0][idx]
            x_sub = x_data[x_offset:(x_offset + x_len), :]
            y_sub = y_data[y_offset:(y_offset + y_len), :]
            tmp_sub = np.dot(x_sub, w_data)
            tmp = np.vstack((tmp, tmp_sub.reshape(tmp_sub.size, 1)))

            out_sub = np.dot(tmp_sub, y_sub.T).transpose(1, 0, 2)
            out_lod[0].append(out_sub.size)
            out = np.vstack((out, out_sub.reshape(out_sub.size, 1)))

            x_offset += x_len
            y_offset += y_len
        self.outputs = {'Out': (out, out_lod), 'Tmp': tmp}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005)


class TestMatchMatrixTensorOpCase1(TestMatchMatrixTensorOp):
    def set_data(self):
        ix, iy, h, dim_t = [5, 8, 16, 4]
        x_lod = [[5]]
        y_lod = [[8]]
        self.init_data(ix, x_lod, iy, y_lod, h, dim_t)


class TestMatchMatrixTensorOpCase2(TestMatchMatrixTensorOp):
    def set_data(self):
        ix, iy, h, dim_t = [7, 8, 1, 4]
        x_lod = [[2, 3, 2]]
        y_lod = [[3, 1, 4]]
        self.init_data(ix, x_lod, iy, y_lod, h, dim_t)


class TestMatchMatrixTensorOpCase3(TestMatchMatrixTensorOp):
    def set_data(self):
        ix, iy, h, dim_t = [5, 9, 32, 1]
        x_lod = [[1, 2, 2]]
        y_lod = [[3, 2, 4]]
        self.init_data(ix, x_lod, iy, y_lod, h, dim_t)


class TestMatchMatrixTensorOpCase4(TestMatchMatrixTensorOp):
    def set_data(self):
        ix, iy, h, dim_t = [8, 12, 16, 5]
        x_lod = [[1, 2, 3, 1, 1]]
        y_lod = [[3, 2, 4, 1, 2]]
        self.init_data(ix, x_lod, iy, y_lod, h, dim_t)

    def test_api(self):
        x_lod_tensor = fluid.layers.data(name='x', shape=[10], lod_level=1)
        y_lod_tensor = fluid.layers.data(name='y', shape=[10], lod_level=1)
        out, out_tmp = fluid.layers.match_matrix_tensor(
            x=x_lod_tensor, y=y_lod_tensor, channel_num=3)

        place = fluid.CPUPlace()
        x_data = np.random.rand(7, 10).astype('float32')
        y_data = np.random.rand(9, 10).astype('float32')
        x = fluid.create_lod_tensor(x_data, [[2, 5]], place)
        y = fluid.create_lod_tensor(y_data, [[3, 6]], place)

        exe = fluid.Executor(place=place)
        exe.run(fluid.default_startup_program())
        ret = exe.run(feed={'x': x,
                            'y': y},
                      fetch_list=[out],
                      return_numpy=False)


if __name__ == '__main__':
    unittest.main()