Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_cudnn_lstm
commit
1ffe41d722
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,73 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
namespace jitkernel {
|
||||
|
||||
#define SIGMOID_THRESHOLD_MIN -40.0
|
||||
#define SIGMOID_THRESHOLD_MAX 13.0
|
||||
#define EXP_MAX_INPUT 40.0
|
||||
#define XMM_FLOAT_BLOCK 4
|
||||
#define YMM_FLOAT_BLOCK 8
|
||||
#define ZMM_FLOAT_BLOCK 16
|
||||
|
||||
typedef struct {
|
||||
void* gates; // gates: W_ch, W_ih, W_fh, W_oh
|
||||
const void* ct_1;
|
||||
void* ct;
|
||||
void* ht;
|
||||
/* weight_peephole and checked data are only used in peephole*/
|
||||
const void* wp{nullptr};
|
||||
void* checked{nullptr};
|
||||
} lstm_t;
|
||||
|
||||
typedef struct {
|
||||
void* gates; // gates: {W_update, W_reset; W_state}
|
||||
const void* ht_1;
|
||||
void* ht;
|
||||
} gru_t;
|
||||
|
||||
struct rnn_attr_s {
|
||||
int d;
|
||||
std::string act_gate, act_cand;
|
||||
rnn_attr_s() = default;
|
||||
rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
|
||||
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
|
||||
};
|
||||
|
||||
struct lstm_attr_s : public rnn_attr_s {
|
||||
bool use_peephole;
|
||||
std::string act_cell;
|
||||
lstm_attr_s() = default;
|
||||
lstm_attr_s(int _d, const std::string& _act_gate,
|
||||
const std::string& _act_cand, const std::string& _act_cell,
|
||||
bool _use_peephole = false)
|
||||
: rnn_attr_s(_d, _act_gate, _act_cand),
|
||||
use_peephole(_use_peephole),
|
||||
act_cell(_act_cell) {}
|
||||
};
|
||||
|
||||
typedef struct rnn_attr_s gru_attr_t;
|
||||
typedef struct lstm_attr_s lstm_attr_t;
|
||||
|
||||
} // namespace jitkernel
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,238 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
namespace jitkernel {
|
||||
namespace refer {
|
||||
/* Refer code only focus on correctness */
|
||||
|
||||
template <typename T>
|
||||
void VMul(const T* x, const T* y, T* z, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = x[i] * y[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VAdd(const T* x, const T* y, T* z, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = x[i] + y[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VAddRelu(const T* x, const T* y, T* z, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
z[i] = x[i] + y[i];
|
||||
z[i] = z[i] > 0 ? z[i] : 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VScal(const T* a, const T* x, T* y, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = a[0] * x[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VAddBias(const T* a, const T* x, T* y, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = a[0] + x[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VRelu(const T* x, T* y, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = x[i] > 0 ? x[i] : 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void VIdentity(const T* x, T* y, int n) {}
|
||||
|
||||
template <typename T>
|
||||
void VExp(const T* x, T* y, int n) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = std::exp(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VSigmoid(const T* x, T* y, int n) {
|
||||
// y = 1 / (1 + e^-x)
|
||||
const T min = SIGMOID_THRESHOLD_MIN;
|
||||
const T max = SIGMOID_THRESHOLD_MAX;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
|
||||
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VTanh(const T* x, T* y, int n) {
|
||||
// y = 2 * sigmoid(2x) - 1
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = static_cast<T>(2) * x[i];
|
||||
}
|
||||
VSigmoid(y, y, n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
|
||||
if (type == "sigmoid") {
|
||||
return VSigmoid<T>;
|
||||
} else if (type == "relu") {
|
||||
return VRelu<T>;
|
||||
} else if (type == "tanh") {
|
||||
return VTanh<T>;
|
||||
} else if (type == "identity" || type == "") {
|
||||
return VIdentity<T>;
|
||||
}
|
||||
PADDLE_THROW("Not support type: %s", type);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// compute ct and ht
|
||||
template <typename T>
|
||||
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
|
||||
T* gates = reinterpret_cast<T*>(step->gates);
|
||||
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
|
||||
T* ct = reinterpret_cast<T*>(step->ct);
|
||||
T* ht = reinterpret_cast<T*>(step->ht);
|
||||
const T* wp = reinterpret_cast<const T*>(step->wp);
|
||||
T* checked = reinterpret_cast<T*>(step->checked);
|
||||
auto act_gate = getActFunc<T>(attr->act_gate);
|
||||
auto act_cand = getActFunc<T>(attr->act_cand);
|
||||
auto act_cell = getActFunc<T>(attr->act_cell);
|
||||
int d = attr->d;
|
||||
int d2 = d * 2;
|
||||
int d3 = d * 3;
|
||||
// gates: W_ch, W_ih, W_fh, W_oh
|
||||
if (attr->use_peephole) {
|
||||
VMul(wp, ct_1, checked, d);
|
||||
VMul(wp + d, ct_1, checked + d, d);
|
||||
VAdd(checked, gates + d, gates + d, d2);
|
||||
act_gate(gates + d, gates + d, d2);
|
||||
} else {
|
||||
act_gate(gates + d, gates + d, d3);
|
||||
}
|
||||
|
||||
// C_t = C_t-1 * fgated + cand_gated * igated
|
||||
act_cand(gates, gates, d);
|
||||
VMul(gates, gates + d, gates + d, d);
|
||||
VMul(ct_1, gates + d2, gates + d2, d);
|
||||
VAdd(gates + d, gates + d2, ct, d);
|
||||
|
||||
if (attr->use_peephole) {
|
||||
// get ogated
|
||||
VMul(wp + d2, ct, gates + d, d);
|
||||
VAdd(gates + d, gates + d3, gates + d3, d);
|
||||
act_gate(gates + d3, gates + d3, d);
|
||||
}
|
||||
// H_t = act_cell(C_t) * ogated
|
||||
act_cell(ct, gates + d2, d);
|
||||
VMul(gates + d2, gates + d3, ht, d);
|
||||
}
|
||||
|
||||
// compute c1 and h1 without c0 or h0
|
||||
template <typename T>
|
||||
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
|
||||
T* gates = reinterpret_cast<T*>(step->gates);
|
||||
T* ct = reinterpret_cast<T*>(step->ct);
|
||||
T* ht = reinterpret_cast<T*>(step->ht);
|
||||
auto act_gate = getActFunc<T>(attr->act_gate);
|
||||
auto act_cand = getActFunc<T>(attr->act_cand);
|
||||
auto act_cell = getActFunc<T>(attr->act_cell);
|
||||
int d = attr->d;
|
||||
int d2 = d * 2;
|
||||
int d3 = d * 3;
|
||||
/* C_t = igated * cgated*/
|
||||
act_gate(gates + d, gates + d, d);
|
||||
act_cand(gates, gates, d);
|
||||
VMul(gates, gates + d, ct, d);
|
||||
if (attr->use_peephole) {
|
||||
// get outgated, put W_oc * C_t on igated
|
||||
const T* wp = reinterpret_cast<const T*>(step->wp);
|
||||
VMul(wp + d2, ct, gates + d, d);
|
||||
VAdd(gates + d, gates + d3, gates + d3, d);
|
||||
}
|
||||
/* H_t = act_cell(C_t) * ogated */
|
||||
act_gate(gates + d3, gates + d3, d);
|
||||
act_cell(ct, gates + d2, d);
|
||||
VMul(gates + d2, gates + d3, ht, d);
|
||||
}
|
||||
|
||||
// compute h1 without h0
|
||||
template <typename T>
|
||||
void GRUH1(gru_t* step, const gru_attr_t* attr) {
|
||||
T* gates = reinterpret_cast<T*>(step->gates);
|
||||
T* ht = reinterpret_cast<T*>(step->ht);
|
||||
auto act_gate = getActFunc<T>(attr->act_gate);
|
||||
auto act_cand = getActFunc<T>(attr->act_cand);
|
||||
int d = attr->d;
|
||||
int d2 = d * 2;
|
||||
act_gate(gates, gates, d);
|
||||
act_cand(gates + d2, gates + d2, d);
|
||||
VMul(gates, gates + d2, ht, d);
|
||||
}
|
||||
|
||||
// compute the first part of GRU: ht = act_gate(r) * ht_1
|
||||
template <typename T>
|
||||
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
|
||||
// W: {W_update, W_reset; W_state}
|
||||
T* gates = reinterpret_cast<T*>(step->gates);
|
||||
T* ht = reinterpret_cast<T*>(step->ht);
|
||||
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
|
||||
auto act_gate = getActFunc<T>(attr->act_gate);
|
||||
act_gate(gates + attr->d, gates + attr->d, attr->d);
|
||||
VMul(ht_1, gates + attr->d, ht, attr->d);
|
||||
}
|
||||
|
||||
// compute the second part of GRU:
|
||||
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
|
||||
template <typename T>
|
||||
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
|
||||
T* gates = reinterpret_cast<T*>(step->gates);
|
||||
T* ht = reinterpret_cast<T*>(step->ht);
|
||||
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
|
||||
auto act_gate = getActFunc<T>(attr->act_gate);
|
||||
auto act_cand = getActFunc<T>(attr->act_cand);
|
||||
int d = attr->d;
|
||||
T* y = gates + d * 2;
|
||||
act_gate(gates, gates, d);
|
||||
act_cand(y, y, d);
|
||||
// out = zt*ht~ + (1-zt)*ht_1
|
||||
for (int i = 0; i < d; ++i) {
|
||||
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace refer
|
||||
} // namespace jitkernel
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,197 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid.core as core
|
||||
|
||||
|
||||
def nearest_neighbor_interp_np(X,
|
||||
out_h,
|
||||
out_w,
|
||||
out_size=None,
|
||||
actual_shape=None):
|
||||
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
|
||||
if out_size is not None:
|
||||
out_h = out_size[0]
|
||||
out_w = out_size[1]
|
||||
if actual_shape is not None:
|
||||
out_h = actual_shape[0]
|
||||
out_w = actual_shape[1]
|
||||
n, c, in_h, in_w = X.shape
|
||||
|
||||
ratio_h = ratio_w = 0.0
|
||||
if out_h > 1:
|
||||
ratio_h = (in_h - 1.0) / (out_h - 1.0)
|
||||
if out_w > 1:
|
||||
ratio_w = (in_w - 1.0) / (out_w - 1.0)
|
||||
|
||||
out = np.zeros((n, c, out_h, out_w))
|
||||
for i in range(out_h):
|
||||
in_i = int(ratio_h * i + 0.5)
|
||||
for j in range(out_w):
|
||||
in_j = int(ratio_w * j + 0.5)
|
||||
out[:, :, i, j] = X[:, :, in_i, in_j]
|
||||
|
||||
return out.astype(X.dtype)
|
||||
|
||||
|
||||
class TestNearestInterpOp(OpTest):
|
||||
def setUp(self):
|
||||
self.out_size = None
|
||||
self.actual_shape = None
|
||||
self.init_test_case()
|
||||
self.op_type = "nearest_interp"
|
||||
input_np = np.random.random(self.input_shape).astype("float32")
|
||||
|
||||
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
|
||||
self.out_size, self.actual_shape)
|
||||
self.inputs = {'X': input_np}
|
||||
if self.out_size is not None:
|
||||
self.inputs['OutSize'] = self.out_size
|
||||
if self.actual_shape is not None:
|
||||
self.inputs['OutSize'] = self.actual_shape
|
||||
self.attrs = {
|
||||
'out_h': self.out_h,
|
||||
'out_w': self.out_w,
|
||||
'interp_method': self.interp_method
|
||||
}
|
||||
self.outputs = {'Out': output_np}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], 'Out', in_place=True)
|
||||
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [2, 3, 4, 4]
|
||||
self.out_h = 2
|
||||
self.out_w = 2
|
||||
self.out_size = np.array([3, 3]).astype("int32")
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase1(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [4, 1, 7, 8]
|
||||
self.out_h = 1
|
||||
self.out_w = 1
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase2(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [3, 3, 9, 6]
|
||||
self.out_h = 12
|
||||
self.out_w = 12
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase3(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [1, 1, 128, 64]
|
||||
self.out_h = 64
|
||||
self.out_w = 128
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase4(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [4, 1, 7, 8]
|
||||
self.out_h = 1
|
||||
self.out_w = 1
|
||||
self.out_size = np.array([2, 2]).astype("int32")
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase5(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [3, 3, 9, 6]
|
||||
self.out_h = 12
|
||||
self.out_w = 12
|
||||
self.out_size = np.array([11, 11]).astype("int32")
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase6(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [1, 1, 128, 64]
|
||||
self.out_h = 64
|
||||
self.out_w = 128
|
||||
self.out_size = np.array([65, 129]).astype("int32")
|
||||
|
||||
|
||||
class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [3, 2, 32, 16]
|
||||
self.out_h = 64
|
||||
self.out_w = 32
|
||||
self.out_size = np.array([66, 40]).astype("int32")
|
||||
|
||||
|
||||
class TestNearestInterpOpUint8(OpTest):
|
||||
def setUp(self):
|
||||
self.out_size = None
|
||||
self.actual_shape = None
|
||||
self.init_test_case()
|
||||
self.op_type = "nearest_interp"
|
||||
input_np = np.random.randint(
|
||||
low=0, high=256, size=self.input_shape).astype("uint8")
|
||||
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
|
||||
self.out_size, self.actual_shape)
|
||||
self.inputs = {'X': input_np}
|
||||
if self.out_size is not None:
|
||||
self.inputs['OutSize'] = self.out_size
|
||||
self.attrs = {
|
||||
'out_h': self.out_h,
|
||||
'out_w': self.out_w,
|
||||
'interp_method': self.interp_method
|
||||
}
|
||||
self.outputs = {'Out': output_np}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output_with_place(place=core.CPUPlace(), atol=1)
|
||||
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [1, 3, 9, 6]
|
||||
self.out_h = 10
|
||||
self.out_w = 9
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [2, 3, 128, 64]
|
||||
self.out_h = 120
|
||||
self.out_w = 50
|
||||
|
||||
|
||||
class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8):
|
||||
def init_test_case(self):
|
||||
self.interp_method = 'nearest'
|
||||
self.input_shape = [4, 1, 7, 8]
|
||||
self.out_h = 5
|
||||
self.out_w = 13
|
||||
self.out_size = np.array([6, 15]).astype("int32")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue