follow comments

feature/design_of_v2_layer_converter
liaogang 9 years ago
parent 8cde2d119f
commit f27fd9dc28

@ -44,7 +44,6 @@ if(MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64)
message(STATUS "Found MKL (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") message(STATUS "Found MKL (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
set(CBLAS_FOUND ON) set(CBLAS_FOUND ON)
if(${MKL_LAPACK_INC_DIR}) if(${MKL_LAPACK_INC_DIR})
add_definitions(-DPADDLE_USE_LAPACK)
message(STATUS "Found lapack in MKL (include: ${MKL_LAPACK_INC_DIR})") message(STATUS "Found lapack in MKL (include: ${MKL_LAPACK_INC_DIR})")
endif() endif()
return() # return file. return() # return file.
@ -80,7 +79,6 @@ if(ATLAS_INC_DIR AND ATLAS_CBLAS_LIB AND ATLAS_LIB AND NOT CBLAS_FOUND)
message(STATUS "Found ATLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") message(STATUS "Found ATLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
set(CBLAS_FOUND ON) set(CBLAS_FOUND ON)
if(ATLAS_CLAPACK_INC_DIR) if(ATLAS_CLAPACK_INC_DIR)
add_definitions(-DPADDLE_USE_LAPACK)
message(STATUS "Found lapack in ATLAS (include: ${ATLAS_CLAPACK_INC_DIR})") message(STATUS "Found lapack in ATLAS (include: ${ATLAS_CLAPACK_INC_DIR})")
endif() endif()
return() return()
@ -114,7 +112,6 @@ if(OPENBLAS_INC_DIR AND OPENBLAS_LIB)
message(STATUS "Found OpenBLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") message(STATUS "Found OpenBLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
set(CBLAS_FOUND ON) set(CBLAS_FOUND ON)
if(OPENBLAS_LAPACKE_INC_DIR) if(OPENBLAS_LAPACKE_INC_DIR)
add_definitions(-DPADDLE_USE_LAPACK)
message(STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})") message(STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})")
endif() endif()
return() return()

@ -27,8 +27,6 @@ IF(NOT ${CBLAS_FOUND})
SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/libopenblas.a" CACHE FILEPATH "openblas library" FORCE) SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/libopenblas.a" CACHE FILEPATH "openblas library" FORCE)
ENDIF(WIN32) ENDIF(WIN32)
ADD_DEFINITIONS(-DPADDLE_USE_LAPACK)
ExternalProject_Add( ExternalProject_Add(
openblas openblas
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}

@ -40,18 +40,18 @@ public:
namespace gpu { namespace gpu {
static __device__ Active<real>::forward forward[] = HPPL_ACTIVE_FUNCTION; static __device__ Active<real>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<real>::backward backward[] = HPPL_ACTIVE_FUNCTION; static __device__ Active<real>::backward backward[] = HPPL_ACTIVE_FUNCTION;
} } // namespace gpu
#else #else
namespace cpu { namespace cpu {
static Active<real>::forward forward[] = HPPL_ACTIVE_FUNCTION; static Active<real>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<real>::backward backward[] = HPPL_ACTIVE_FUNCTION; static Active<real>::backward backward[] = HPPL_ACTIVE_FUNCTION;
} } // namespace cpu
#ifdef __AVX__ #ifdef __AVX__
namespace avx { namespace avx {
static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION; static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION; static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION;
} } // namespace avx
#endif #endif
#endif #endif

@ -273,23 +273,23 @@ extern void hl_bilinear_forward(const real* inData,
const real ratioW); const real ratioW);
/** /**
* @brief Bilinear interpolation backward. * @brief Bilinear interpolation backward.
* *
* @param[out] inGrad input gradient. * @param[out] inGrad input gradient.
* @param[in] inImgH input image height. * @param[in] inImgH input image height.
* @param[in] inImgW input image width. * @param[in] inImgW input image width.
* @param[in] inputH input batchSize. * @param[in] inputH input batchSize.
* @param[in] inputW input image data dim. * @param[in] inputW input image data dim.
* @param[in] outGrad output gradient. * @param[in] outGrad output gradient.
* @param[in] outImgH output image height. * @param[in] outImgH output image height.
* @param[in] outImgW output image width. * @param[in] outImgW output image width.
* @param[in] outputH output batchSize. * @param[in] outputH output batchSize.
* @param[in] outputW output image data dim. * @param[in] outputW output image data dim.
* @param[in] numChannels number of channels. * @param[in] numChannels number of channels.
* @param[in] ratioH inImgH / outImgH. * @param[in] ratioH inImgH / outImgH.
* @param[in] ratioW inImgW / outImgW. * @param[in] ratioW inImgW / outImgW.
* *
*/ */
extern void hl_bilinear_backward(real* inGrad, extern void hl_bilinear_backward(real* inGrad,
const size_t inImgH, const size_t inImgH,
const size_t inImgW, const size_t inImgW,

@ -16,7 +16,7 @@ limitations under the License. */
#include <sys/time.h> #include <sys/time.h>
#include "hl_cuda.h" #include "hl_cuda.h"
#include "hl_thread.ph" #include "hl_thread.ph"
#include "paddle/utils/DynamicLoad.h" #include "paddle/utils/DynamicLoader.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
namespace dynload { namespace dynload {

@ -17,7 +17,7 @@ limitations under the License. */
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include "hl_cuda_cudnn.ph" #include "hl_cuda_cudnn.ph"
#include "hl_thread.ph" #include "hl_thread.ph"
#include "paddle/utils/DynamicLoad.h" #include "paddle/utils/DynamicLoader.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
DEFINE_int32(cudnn_conv_workspace_limit_in_mb, DEFINE_int32(cudnn_conv_workspace_limit_in_mb,

@ -24,7 +24,7 @@ limitations under the License. */
#include "hl_cuda.ph" #include "hl_cuda.ph"
#include "hl_thread.ph" #include "hl_thread.ph"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/DynamicLoad.h" #include "paddle/utils/DynamicLoader.h"
// clang-format on // clang-format on
namespace dynload { namespace dynload {

@ -14,7 +14,7 @@ limitations under the License. */
#include "hl_warpctc_wrap.h" #include "hl_warpctc_wrap.h"
#include <mutex> #include <mutex>
#include "paddle/utils/DynamicLoad.h" #include "paddle/utils/DynamicLoader.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
namespace dynload { namespace dynload {

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "BufferArg.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "BufferArg.h"
#include "paddle/math/MemoryHandle.h" #include "paddle/math/MemoryHandle.h"
namespace paddle { namespace paddle {

@ -165,11 +165,11 @@ void CosSimBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad,
real reciprocal_square_sum_x = 1.0f / square_sum_x; real reciprocal_square_sum_x = 1.0f / square_sum_x;
real reciprocal_square_sum_y = 1.0f / square_sum_y; real reciprocal_square_sum_y = 1.0f / square_sum_y;
for (size_t j = 0; j < dim; ++j) { for (size_t j = 0; j < dim; ++j) {
prev_grad_x[j] += prev_grad_x[j] += out[i] * grad[i] *
out[i] * grad[i] * (prev_out_y[j] * reciprocal_xy - (prev_out_y[j] * reciprocal_xy -
prev_out_x[j] * reciprocal_square_sum_x); prev_out_x[j] * reciprocal_square_sum_x);
prev_grad_y[j] += prev_grad_y[j] += out[i] * grad[i] *
out[i] * grad[i] * (prev_out_x[j] * reciprocal_xy - (prev_out_x[j] * reciprocal_xy -
prev_out_y[j] * reciprocal_square_sum_y); prev_out_y[j] * reciprocal_square_sum_y);
} }
} }

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "Function.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "Function.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
namespace paddle { namespace paddle {

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "TensorShape.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "TensorShape.h"
namespace paddle { namespace paddle {

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "TensorType.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "TensorType.h"
namespace paddle { namespace paddle {

@ -194,8 +194,8 @@ void PyDataProvider::fillSlotsByStr(const std::string& samples) {
auto& slot = slots_[j]; auto& slot = slots_[j];
CHECK(SlotDef::INDEX >= slot.type || SlotDef::STRING == slot.type) CHECK(SlotDef::INDEX >= slot.type || SlotDef::STRING == slot.type)
<< " Slot type:" << slot.type << " is out of range."; << " Slot type:" << slot.type << " is out of range.";
CHECK_GE(slot.type, SlotDef::VECTOR_DENSE) << " Slot type:" << slot.type CHECK_GE(slot.type, SlotDef::VECTOR_DENSE)
<< " is out of range."; << " Slot type:" << slot.type << " is out of range.";
switch (slot.type) { switch (slot.type) {
case SlotDef::VECTOR_DENSE: case SlotDef::VECTOR_DENSE:
fillDenseSlot(slot, data, dataEnd); fillDenseSlot(slot, data, dataEnd);

@ -446,8 +446,8 @@ real AucEvaluator::evalImp(std::vector<Argument>& arguments) {
for (size_t i = 0; i < insNum; ++i) { for (size_t i = 0; i < insNum; ++i) {
real value = outputD[pos]; real value = outputD[pos];
uint32_t binIdx = static_cast<uint32_t>(value * kBinNum_); uint32_t binIdx = static_cast<uint32_t>(value * kBinNum_);
CHECK(binIdx <= kBinNum_) << "bin index [" << binIdx CHECK(binIdx <= kBinNum_)
<< "] out of range, predict value[" << value << "bin index [" << binIdx << "] out of range, predict value[" << value
<< "]"; << "]";
real w = supportWeight ? weightD[i] : 1.0; real w = supportWeight ? weightD[i] : 1.0;
if (labelD[i] == kNegativeLabel_) { if (labelD[i] == kNegativeLabel_) {

@ -21,7 +21,6 @@ limitations under the License. */
#include "MultiGradientMachine.h" #include "MultiGradientMachine.h"
#include "MultiNetwork.h" #include "MultiNetwork.h"
#include "NeuralNetwork.h" #include "NeuralNetwork.h"
#include "NeuralNetwork.h"
#include "ParallelNeuralNetwork.h" #include "ParallelNeuralNetwork.h"
#include "hl_gpu.h" #include "hl_gpu.h"

@ -637,7 +637,7 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() {
/* create scattered id infomation for all realLayer of inFrameLines one time. /* create scattered id infomation for all realLayer of inFrameLines one time.
* If hasSubseq, will also create scattered sequenceStartPositions infomation * If hasSubseq, will also create scattered sequenceStartPositions infomation
* for all realLayer of inFrameLines one time. * for all realLayer of inFrameLines one time.
*/ */
void RecurrentGradientMachine::createInFrameInfo(int inlinkId, void RecurrentGradientMachine::createInFrameInfo(int inlinkId,
const Argument& input, const Argument& input,

@ -263,8 +263,9 @@ void Layer::zeroGrad() {
} }
void Layer::initNeedFlags() { void Layer::initNeedFlags() {
auto initFlag = [this]( auto initFlag = [this](bool& flag,
bool& flag, bool (Layer::*flagQueryFunc)() const, ParameterType type) { bool (Layer::*flagQueryFunc)() const,
ParameterType type) {
flag = false; flag = false;
if (biasParameter_ && biasParameter_->hasType(type)) { if (biasParameter_ && biasParameter_->hasType(type)) {
flag = true; flag = true;

@ -29,7 +29,7 @@ namespace paddle {
* *
* The config file api is rotate_layer * The config file api is rotate_layer
* *
*/ */
class RotateLayer : public Layer { class RotateLayer : public Layer {
public: public:

@ -292,8 +292,8 @@ void checkRecurrentLayer(LayerConfig layerConfig,
TestRecurrentLayer<T> testGpu(layerConfig, true, gpuBatch); TestRecurrentLayer<T> testGpu(layerConfig, true, gpuBatch);
testCpu.init(batchSize); testCpu.init(batchSize);
testGpu.init(batchSize); testGpu.init(batchSize);
auto checkError = []( auto checkError =
MatrixPtr cpu, MatrixPtr gpu, int numSequences, const char* str) { [](MatrixPtr cpu, MatrixPtr gpu, int numSequences, const char* str) {
CpuMatrix check(gpu->getHeight(), gpu->getWidth()); CpuMatrix check(gpu->getHeight(), gpu->getWidth());
check.copyFrom(*gpu); check.copyFrom(*gpu);
int height = cpu->getHeight(); int height = cpu->getHeight();
@ -303,7 +303,8 @@ void checkRecurrentLayer(LayerConfig layerConfig,
int count = 0; int count = 0;
for (int i = 0; i < height; i++) { for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) { for (int j = 0; j < width; j++) {
if (fabs(data1[i * width + j] - data2[i * width + j]) / numSequences > if (fabs(data1[i * width + j] - data2[i * width + j]) /
numSequences >
1e-4) { 1e-4) {
count++; count++;
} }

@ -15,7 +15,7 @@ limitations under the License. */
#include "MathFunctions.h" #include "MathFunctions.h"
#include "hl_matrix_apply.cuh" #include "hl_matrix_apply.cuh"
#include "hl_matrix_ops.cuh" #include "hl_matrix_ops.cuh"
#include "paddle/utils/DynamicLoad.h" #include "paddle/utils/DynamicLoader.h"
namespace dynload { namespace dynload {
@ -32,7 +32,7 @@ void* lapack_dso_handle = nullptr;
#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \ #define DYNAMIC_LOAD_LAPACK_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args)->decltype(__name(args...)) { \ auto operator()(Args... args) -> decltype(__name(args...)) { \
using lapack_func = decltype(__name(args...)) (*)(Args...); \ using lapack_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \ std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \
void* p_##__name = dlsym(lapack_dso_handle, #__name); \ void* p_##__name = dlsym(lapack_dso_handle, #__name); \
@ -41,24 +41,27 @@ void* lapack_dso_handle = nullptr;
} __name; // struct DynLoad__##__name } __name; // struct DynLoad__##__name
// clang-format off // clang-format off
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS #ifdef PADDLE_USE_ATLAS
#define LAPACK_ROUTINE_EACH(__macro) \ #define PADDLE_SGETRF clapack_sgetrf
__macro(clapack_sgetrf) \ #define PADDLE_DGETRF clapack_dgetrf
__macro(clapack_dgetrf) \ #define PADDLE_SGETRI clapack_sgetri
__macro(clapack_sgetri) \ #define PADDLE_DGETRI clapack_dgetri
__macro(clapack_dgetri)
#else #else
#define LAPACK_ROUTINE_EACH(__macro) \ #define PADDLE_SGETRF LAPACKE_sgetrf
__macro(LAPACKE_sgetrf) \ #define PADDLE_DGETRF LAPACKE_dgetrf
__macro(LAPACKE_dgetrf) \ #define PADDLE_SGETRI LAPACKE_sgetri
__macro(LAPACKE_sgetri) \ #define PADDLE_DGETRI LAPACKE_dgetri
__macro(LAPACKE_dgetri)
#endif
LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP)
#endif #endif
#define LAPACK_ROUTINE_EACH(__macro) \
__macro(PADDLE_SGETRF) \
__macro(PADDLE_DGETRF) \
__macro(PADDLE_SGETRI) \
__macro(PADDLE_DGETRI)
// clang-format on // clang-format on
LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP)
} // namespace dynload } // namespace dynload
namespace paddle { namespace paddle {
@ -130,16 +133,7 @@ int getrf<float>(const CBLAS_ORDER order,
float* A, float* A,
const int lda, const int lda,
int* ipiv) { int* ipiv) {
#ifdef PADDLE_USE_LAPACK return dynload::PADDLE_SGETRF(order, M, N, A, lda, ipiv);
#ifdef PADDLE_USE_ATLAS
return dynload::clapack_sgetrf(order, M, N, A, lda, ipiv);
#else
return dynload::LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
#endif
#else
LOG(FATAL) << "Not implemented";
#endif
return 0;
} }
template <> template <>
@ -149,16 +143,7 @@ int getrf<double>(const CBLAS_ORDER order,
double* A, double* A,
const int lda, const int lda,
int* ipiv) { int* ipiv) {
#ifdef PADDLE_USE_LAPACK return dynload::PADDLE_DGETRF(order, M, N, A, lda, ipiv);
#ifdef PADDLE_USE_ATLAS
return dynload::clapack_dgetrf(order, M, N, A, lda, ipiv);
#else
return dynload::LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
#endif
#else
LOG(FATAL) << "Not implemented";
#endif
return 0;
} }
template <> template <>
@ -167,16 +152,7 @@ int getri<float>(const CBLAS_ORDER order,
float* A, float* A,
const int lda, const int lda,
const int* ipiv) { const int* ipiv) {
#ifdef PADDLE_USE_LAPACK return dynload::PADDLE_SGETRI(order, N, A, lda, ipiv);
#ifdef PADDLE_USE_ATLAS
return dynload::clapack_sgetri(order, N, A, lda, ipiv);
#else
return dynload::LAPACKE_sgetri(order, N, A, lda, ipiv);
#endif
#else
LOG(FATAL) << "Not implemented";
#endif
return 0;
} }
template <> template <>
@ -185,15 +161,7 @@ int getri<double>(const CBLAS_ORDER order,
double* A, double* A,
const int lda, const int lda,
const int* ipiv) { const int* ipiv) {
#ifdef PADDLE_USE_LAPACK return dynload::PADDLE_DGETRI(order, N, A, lda, ipiv);
#ifdef PADDLE_USE_ATLAS
return dynload::clapack_dgetri(order, N, A, lda, ipiv);
#else
return dynload::LAPACKE_dgetri(order, N, A, lda, ipiv);
#endif
#else
LOG(FATAL) << "Not implemented";
#endif
return 0; return 0;
} }

@ -17,14 +17,11 @@ limitations under the License. */
#ifdef PADDLE_USE_MKL #ifdef PADDLE_USE_MKL
#include <mkl.h> #include <mkl.h>
#ifdef PADDLE_USE_LAPACK
#include <mkl_lapacke.h> #include <mkl_lapacke.h>
#endif
#else #else
extern "C" { extern "C" {
#include <cblas.h> #include <cblas.h>
} }
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS #ifdef PADDLE_USE_ATLAS
extern "C" { extern "C" {
#include <clapack.h> #include <clapack.h>
@ -33,7 +30,6 @@ extern "C" {
#include <lapacke.h> #include <lapacke.h>
#endif #endif
#endif #endif
#endif
#include <cmath> #include <cmath>

@ -174,8 +174,10 @@ void CpuMatrix::mulByBitCode(size_t numClasses,
const IVector& codes, const IVector& codes,
const Matrix& weight, const Matrix& weight,
const Matrix& input) { const Matrix& input) {
auto op = []( auto op = [](real& t,
real& t, const real* weightRow, const real* inputRow, size_t inputDim) { const real* weightRow,
const real* inputRow,
size_t inputDim) {
real sum = 0; real sum = 0;
for (size_t k = 0; k < inputDim; ++k) { for (size_t k = 0; k < inputDim; ++k) {
sum += weightRow[k] * inputRow[k]; sum += weightRow[k] * inputRow[k];
@ -193,8 +195,8 @@ void CpuMatrix::mulByBitCodeBackwardWeight(size_t numClasses,
const IVector& codes, const IVector& codes,
Matrix& weight, Matrix& weight,
const Matrix& input) { const Matrix& input) {
auto op = []( auto op =
const real t, real* weightRow, const real* inputRow, size_t inputDim) { [](const real t, real* weightRow, const real* inputRow, size_t inputDim) {
for (size_t k = 0; k < inputDim; ++k) { for (size_t k = 0; k < inputDim; ++k) {
weightRow[k] += t * inputRow[k]; weightRow[k] += t * inputRow[k];
} }
@ -210,8 +212,8 @@ void CpuMatrix::mulByBitCodeBackwardError(size_t numClasses,
const IVector& codes, const IVector& codes,
const Matrix& weight, const Matrix& weight,
Matrix& input) { Matrix& input) {
auto op = []( auto op =
const real t, const real* weightRow, real* inputRow, size_t inputDim) { [](const real t, const real* weightRow, real* inputRow, size_t inputDim) {
for (size_t k = 0; k < inputDim; ++k) { for (size_t k = 0; k < inputDim; ++k) {
inputRow[k] += t * weightRow[k]; inputRow[k] += t * weightRow[k];
} }

@ -183,8 +183,8 @@ void TensorCheck(AssertEq compare,
template <typename AssertEq> template <typename AssertEq>
void TensorCheck(AssertEq compare, real args1, real args2) { void TensorCheck(AssertEq compare, real args1, real args2) {
EXPECT_EQ(compare(args1, args2), true) << "[Test error] args1 = " << args1 EXPECT_EQ(compare(args1, args2), true)
<< ", args2 = " << args2; << "[Test error] args1 = " << args1 << ", args2 = " << args2;
} }
template <typename AssertEq> template <typename AssertEq>

@ -37,7 +37,7 @@ limitations under the License. */
* *
* AutoCompare test; * AutoCompare test;
* test.cmpWithoutArg<I...>(function, height, width) * test.cmpWithoutArg<I...>(function, height, width)
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "TensorCheck.h" #include "TensorCheck.h"

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save