Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_bn_eq
commit 90e05a4b8c
@@ -0,0 +1,188 @@
if(NOT WITH_GPU)
  return()
endif()

set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61")

######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
#   detect_installed_gpus(out_variable)
function(detect_installed_gpus out_variable)
  if(NOT CUDA_gpu_detect_output)
    set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)

    file(WRITE ${cufile} ""
      "#include <cstdio>\n"
      "int main() {\n"
      "  int count = 0;\n"
      "  if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
      "  if (count == 0) return -1;\n"
      "  for (int device = 0; device < count; ++device) {\n"
      "    cudaDeviceProp prop;\n"
      "    if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
      "      std::printf(\"%d.%d \", prop.major, prop.minor);\n"
      "  }\n"
      "  return 0;\n"
      "}\n")

    execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "-ccbin=${CUDA_HOST_COMPILER}"
                    "--run" "${cufile}"
                    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
                    RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)

    if(nvcc_res EQUAL 0)
      # Only keep the last line of nvcc_out.
      string(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
      string(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
      list(GET nvcc_out -1 nvcc_out)
      string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
      set(CUDA_gpu_detect_output ${nvcc_out} CACHE INTERNAL "Returned GPU architectures from detect_installed_gpus tool" FORCE)
    endif()
  endif()

  if(NOT CUDA_gpu_detect_output)
    message(STATUS "Automatic GPU detection failed. Building for all known architectures.")
    set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE)
  else()
    set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE)
  endif()
endfunction()


########################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
#   select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
  # List of arch names
  set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "All" "Manual")
  set(archs_name_default "All")
  if(NOT CMAKE_CROSSCOMPILING)
    list(APPEND archs_names "Auto")
  endif()

  # Set CUDA_ARCH_NAME strings (so it will be seen as a dropdown in cmake-gui).
  set(CUDA_ARCH_NAME ${archs_name_default} CACHE STRING "Select target NVIDIA GPU architecture.")
  set_property(CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${archs_names})
  mark_as_advanced(CUDA_ARCH_NAME)

  # Verify the CUDA_ARCH_NAME value.
  if(NOT ";${archs_names};" MATCHES ";${CUDA_ARCH_NAME};")
    string(REPLACE ";" ", " archs_names "${archs_names}")
    message(FATAL_ERROR "Only ${archs_names} architecture names are supported.")
  endif()

  if(${CUDA_ARCH_NAME} STREQUAL "Manual")
    set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
    set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
    mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
  else()
    unset(CUDA_ARCH_BIN CACHE)
    unset(CUDA_ARCH_PTX CACHE)
  endif()

  if(${CUDA_ARCH_NAME} STREQUAL "Kepler")
    set(cuda_arch_bin "30 35")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
    set(cuda_arch_bin "50")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
    set(cuda_arch_bin "60 61")
  elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
    set(cuda_arch_bin "70")
  elseif(${CUDA_ARCH_NAME} STREQUAL "All")
    set(cuda_arch_bin ${paddle_known_gpu_archs})
  elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
    detect_installed_gpus(cuda_arch_bin)
  else()  # (${CUDA_ARCH_NAME} STREQUAL "Manual")
    set(cuda_arch_bin ${CUDA_ARCH_BIN})
  endif()

  # Remove dots and convert to lists.
  string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX REPLACE "\\." "" cuda_arch_ptx "${CUDA_ARCH_PTX}")
  string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
  string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
  list(REMOVE_DUPLICATES cuda_arch_bin)
  list(REMOVE_DUPLICATES cuda_arch_ptx)

  set(nvcc_flags "")
  set(nvcc_archs_readable "")

  # Tell NVCC to add binaries for the specified GPUs.
  foreach(arch ${cuda_arch_bin})
    if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
      # User explicitly specified PTX for the concrete BIN.
      list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
      list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
    else()
      # User didn't explicitly specify PTX for the concrete BIN, so assume PTX=BIN.
      list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
      list(APPEND nvcc_archs_readable sm_${arch})
    endif()
  endforeach()

  # Tell NVCC to add PTX intermediate code for the specified architectures.
  foreach(arch ${cuda_arch_ptx})
    list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
    list(APPEND nvcc_archs_readable compute_${arch})
  endforeach()

  string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
  set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
  set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()


message(STATUS "CUDA detected: " ${CUDA_VERSION})
if(${CUDA_VERSION} LESS 7.0)
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
elseif(${CUDA_VERSION} LESS 8.0)  # CUDA 7.x
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
elseif(${CUDA_VERSION} LESS 9.0)  # CUDA 8.x
  set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
  list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
  list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
  # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
  # warning for now.
  list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
endif()

include_directories(${CUDA_INCLUDE_DIRS})
list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO)
  list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY})
endif(NOT WITH_DSO)

# Set nvcc arch flags.
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}")

# Set C++11 support.
set(CUDA_PROPAGATE_HOST_FLAGS OFF)

# Release/Debug flags such as -O3, -g, and -DNDEBUG are set by cmake per build
# type and appended below, so don't set them here.
list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
# Set --expt-relaxed-constexpr to suppress Eigen warnings.
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")

if(CMAKE_BUILD_TYPE STREQUAL "Debug")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
  list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()

mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
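
For illustration, a minimal usage sketch of the arch-selection helper above when a GPU family is chosen explicitly; the variable name pascal_nvcc_flags is ours, not from the commit, and the expected values follow from the -gencode logic in select_nvcc_arch_flags:

set(CUDA_ARCH_NAME "Pascal" CACHE STRING "Select target NVIDIA GPU architecture." FORCE)
select_nvcc_arch_flags(pascal_nvcc_flags)
# pascal_nvcc_flags now holds the list
#   -gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61
# and pascal_nvcc_flags_readable reads "sm_60 sm_61". No compute_XX PTX entries
# appear, because CUDA_ARCH_PTX is only defined in "Manual" mode.
message(STATUS "Selected NVCC arch flags: ${pascal_nvcc_flags}")
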
File diff suppressed because it is too large
@@ -1,179 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "DataFormat.pb.h"
#include "paddle/utils/Stat.h"

#include "DataProvider.h"
#include "ProtoReader.h"

namespace paddle {

/**
 * @brief Provides data from a protobuf data file, with each sample
 *        specified by a proto message.
 *
 *        DataSample is defined in DataFormat.proto.
 *
 *        The file format is:
 *
 *        header
 *        sample1
 *        sample2
 *        ...
 *        sampleN
 *
 * @note In the data file, each message is prefixed with its length.
 *       The read/write of the protobuf is implemented in ProtoReader.h.
 */
class ProtoDataProvider : public DataProvider {
public:
  ProtoDataProvider(const DataConfig& config,
                    bool useGpu,
                    bool loadDataAll = true);
  virtual void reset();

  /**
   * @note this size includes the sequences which are skipped because they
   *       are longer than the batch size.
   */
  virtual int64_t getSize() {
    int64_t size = sampleNums_;
    if (usageRatio_ < 1.0f) {
      size = static_cast<int64_t>(size * usageRatio_);
    }
    return size;
  }
  virtual void shuffle();

  void loadData(const std::vector<std::string>& fileList);

  virtual int64_t getNextBatchInternal(int64_t size, DataBatch* batch);

protected:
  /**
   * @brief load protobuf data from a list of files
   * @param[in] fileName name of a file which contains a list of file names
   */
  void loadData(const std::string& fileName);

  /**
   * @brief load protobuf data from a file
   * @param[in] fileName data file name
   */
  void loadDataFile(const std::string& fileName);

  /**
   * @brief check the data header of each data sample
   * @param[in] header data header read from the protobuf data
   */
  void checkDataHeader(const DataHeader& header);

  /**
   * @brief fill protobuf data into slots_,
   *        slots_ is a vector of ProtoSlot in memory.
   * @param[in] sample data sample read from the protobuf data
   */
  void fillSlots(const DataSample& sample);

  /**
   * @brief return true if each sample is one sequence, i.e., independent
   *        of other samples.
   */
  inline bool iidData() const { return sequenceStartPositions_.empty(); }

  /**
   * @brief check that the sample is consistent with header_
   */
  void checkSample(const DataSample& sample);

  template <class Op>
  int64_t sequenceLoop(Op op, int64_t size);

  template <class Op>
  int64_t sampleLoop(Op op, int64_t size);

  template <class Op>
  int64_t subSampleLoop(Op op, int64_t size, int slot);

  void showDataStats();

protected:
  struct ProtoVarSlot {
    std::vector<real> data;
    std::vector<int> dims;
  };

  struct ProtoSlot {
    SlotDef::SlotType type;
    int dim;
    std::vector<int> indexData;
    std::vector<real> denseData;
    std::vector<sparse_non_value_t> sparseNonValueData;
    std::vector<sparse_float_value_t> sparseFloatValueData;
    std::vector<int64_t> indices;
    std::vector<int64_t> subIndices;

    std::vector<ProtoVarSlot> varDenseData;
    std::vector<std::vector<int>> varIndices;
    std::vector<std::string> strData;
  };
  DataHeader header_;
  int numVecSlots_;

  std::vector<ProtoSlot> slots_;
  size_t sampleNums_;

  /**
   * The starting position of each sequence in samples.
   * The last element should be the number of samples.
   * If empty, each sample is one sequence.
   */
  std::vector<size_t> sequenceStartPositions_;

  int64_t currentSequenceIndex_;

  // The size should be the number of sequences.
  std::vector<size_t> shuffledSequenceIds_;

  ThreadLocalD<DataBatch> cpuBatch_;
  ThreadLocalD<DataBatch> gpuBatch_;

  RWLock lock_;
  std::vector<StatPtr> nnzStats_;  // stats for the number of non-zero entries
};

/**
 * @brief Special use for proto data: instances should contain sparse-non-value
 *        slots and a label.
 *
 * @note ProtoSequenceDataProvider treats each SPARSE SLOT as a SEQUENCE.
 */
class ProtoSequenceDataProvider : public ProtoDataProvider {
public:
  ProtoSequenceDataProvider(const DataConfig& config,
                            bool useGpu,
                            bool loadDataAll = true);
  ~ProtoSequenceDataProvider() {}
  virtual int64_t getNextBatchInternal(int64_t size, DataBatch* batch);
};

}  // namespace paddle
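
The removed header above documents a file layout of one header message followed by length-prefixed samples. As a hedged illustration only, here is a generic C++ sketch of reading a length-prefixed protobuf message; the varint framing is an assumption, since the actual encoding lives in ProtoReader.h and is not shown in this diff:

// Generic length-prefixed protobuf read (framing assumed, not Paddle's ProtoReader).
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/message.h>

bool readLengthPrefixed(google::protobuf::io::CodedInputStream* in,
                        google::protobuf::Message* msg) {
  google::protobuf::uint32 size = 0;
  if (!in->ReadVarint32(&size)) return false;   // end of stream or read error
  const auto limit = in->PushLimit(size);       // parse exactly `size` bytes
  const bool ok = msg->ParseFromCodedStream(in) && in->ConsumedEntireMessage();
  in->PopLimit(limit);
  return ok;
}
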
@@ -0,0 +1,97 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

/**
 * @brief A layer for computing the dot product of two vectors.
 * Input1: vector (batchSize * dim)
 * Input2: vector (batchSize * dim)
 * Output: a matrix of size (batchSize * 1)
 */

class DotProdLayer : public Layer {
public:
  explicit DotProdLayer(const LayerConfig& config) : Layer(config) {}

  ~DotProdLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;
};

REGISTER_LAYER(dot_prod, DotProdLayer);

bool DotProdLayer::init(const LayerMap& layerMap,
                        const ParameterMap& parameterMap) {
  Layer::init(layerMap, parameterMap);

  CHECK_EQ(inputLayers_.size(), 2U);
  CHECK_EQ(1UL, getSize())
      << "The output dimensionality of this layer should be fixed to 1.";

  return true;
}

void DotProdLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);

  size_t batchSize = inV0->getHeight();
  CHECK_EQ(inV1->getHeight(), batchSize);
  CHECK_EQ(inV0->getWidth(), inV1->getWidth());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    reserveOutput(batchSize, 1);
  }

  MatrixPtr outV = getOutputValue();
  {
    REGISTER_TIMER_INFO("FwDotProdTimer", getName().c_str());
    outV->sumOfProducts(*inV0, *inV1, 1, 0);
  }
}

void DotProdLayer::backward(const UpdateCallback& callback) {
  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr outG = getOutputGrad();
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);

  {
    REGISTER_TIMER_INFO("BwDotProdTimer", getName().c_str());

    if (inG0) {
      inG0->addRowScale(0, *inV1, *outG);
    }

    if (inG1) {
      inG1->addRowScale(0, *inV0, *outG);
    }
  }
}

}  // namespace paddle
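
To make the forward/backward math of the layer above easier to follow, here is a plain C++ sketch (std::vector instead of Paddle's Matrix API; the function and variable names are ours): the forward pass is a per-row dot product, and each input's gradient is the other input scaled row-wise by the scalar output gradient, which is what the two addRowScale calls compute.

#include <cstddef>
#include <vector>

using Batch = std::vector<std::vector<float>>;

// out[row] = sum_i a[row][i] * b[row][i]
void dotProdForward(const Batch& a, const Batch& b, std::vector<float>* out) {
  out->assign(a.size(), 0.0f);
  for (std::size_t row = 0; row < a.size(); ++row)
    for (std::size_t i = 0; i < a[row].size(); ++i)
      (*out)[row] += a[row][i] * b[row][i];
}

void dotProdBackward(const Batch& a, const Batch& b,
                     const std::vector<float>& outGrad,
                     Batch* gradA, Batch* gradB) {
  for (std::size_t row = 0; row < a.size(); ++row)
    for (std::size_t i = 0; i < a[row].size(); ++i) {
      (*gradA)[row][i] += outGrad[row] * b[row][i];  // mirrors inG0->addRowScale(0, *inV1, *outG)
      (*gradB)[row][i] += outGrad[row] * a[row][i];  // mirrors inG1->addRowScale(0, *inV0, *outG)
    }
}
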
@@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "L2DistanceLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

REGISTER_LAYER(l2_distance, L2DistanceLayer);

bool L2DistanceLayer::init(const LayerMap& layerMap,
                           const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  Layer::init(layerMap, parameterMap);

  CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and "
                                     << "only two inputs.";
  CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer "
                           << "is fixed to be 1.";

  return true;
}

void L2DistanceLayer::forward(PassType passType) {
  Layer::forward(passType);

  const auto inV1 = getInputValue(0);
  const auto inV2 = getInputValue(1);

  CHECK(inV1 && inV2);
  CHECK_EQ(inV1->getHeight(), inV2->getHeight())
      << "The height of two inputs of this layer must be the same.";
  CHECK_EQ(inV1->getWidth(), inV2->getWidth())
      << "The width of two inputs of this layer must be the same.";

  int batchSize = inV1->getHeight();
  int output_dim = getSize();
  {
    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
    reserveOutput(batchSize, output_dim);
    auto outV = getOutputValue();
    CHECK(outV) << "The output matrix should not be null.";

    Matrix::resizeOrCreate(
        inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_);

    inputSub_->assign(*inV1);
    inputSub_->sub(*inV2);
    outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0);
    outV->sqrt2(*outV);
  }
}

void L2DistanceLayer::backward(const UpdateCallback& callback) {
  const auto outG = getOutputGrad();
  const auto outV = getOutputValue();
  CHECK(outG && outV);

  auto inGrad1 = getInputGrad(0);
  auto inGrad2 = getInputGrad(1);

  {
    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());

    if (inGrad1 || inGrad2) {
      outV->scalarDiv(*outV, 1.);
      outV->dotMul(*outG, *outV);
    }

    if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV);

    if (inGrad2) {
      inputSub_->mulScalar(-1.);
      inGrad2->addRowScale(0, *inputSub_, *outV);
    }
  }
}

}  // namespace paddle
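
The backward pass above follows from differentiating the distance; a short derivation of what scalarDiv, dotMul, and the two addRowScale calls compute together:

f(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_{i=1}^{D} (x_i - y_i)^2}, \qquad
\frac{\partial f}{\partial x_i} = \frac{x_i - y_i}{f(\mathbf{x}, \mathbf{y})}, \qquad
\frac{\partial f}{\partial y_i} = -\frac{x_i - y_i}{f(\mathbf{x}, \mathbf{y})}

By the chain rule, the gradient flowing to Input1 is the cached difference inputSub_ scaled row-wise by outG / f, and the gradient to Input2 is its negation, which matches the reciprocal (scalarDiv), the multiplication by the output gradient (dotMul), and the sign flip before the second addRowScale.
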
@@ -0,0 +1,52 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "paddle/math/Matrix.h"

namespace paddle {

/**
 * @brief The layer calculates the l2 distance between two input vectors.
 * \f[
 * f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)^2}
 * \f]
 *
 * - Input1: A vector (batchSize * dataDim)
 * - Input2: A vector (batchSize * dataDim)
 * - Output: A vector (batchSize * 1)
 *
 * The configuration api is: l2_distance_layer.
 */

class L2DistanceLayer : public Layer {
public:
  explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
  ~L2DistanceLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;

private:
  // Store the result of subtracting Input2 from Input1 in the forward pass,
  // to be reused in the backward pass.
  MatrixPtr inputSub_;
};

}  // namespace paddle
Some files were not shown because too many files have changed in this diff.