Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into infershape
commit
55f572b2da
@ -0,0 +1,32 @@
|
||||
if(NOT WITH_GPU)
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(ANAKIN_ROOT "/usr" CACHE PATH "ANAKIN ROOT")
|
||||
find_path(ANAKIN_INCLUDE_DIR anakin_config.h
|
||||
PATHS ${ANAKIN_ROOT} ${ANAKIN_ROOT}/include
|
||||
$ENV{ANAKIN_ROOT} $ENV{ANAKIN_ROOT}/include
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
|
||||
find_library(ANAKIN_LIBRARY NAMES libanakin_saber_common.so libanakin.so
|
||||
PATHS ${ANAKIN_ROOT}
|
||||
$ENV{ANAKIN_ROOT} $ENV{ANAKIN_ROOT}/lib
|
||||
NO_DEFAULT_PATH
|
||||
DOC "Path to ANAKIN library.")
|
||||
|
||||
if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY)
|
||||
if(WITH_DSO)
|
||||
set(ANAKIN_FOUND ON)
|
||||
endif(WITH_DSO)
|
||||
else()
|
||||
set(ANAKIN_FOUND OFF)
|
||||
endif()
|
||||
|
||||
if(ANAKIN_FOUND)
|
||||
message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
|
||||
include_directories(${ANAKIN_ROOT}/include)
|
||||
include_directories(${ANAKIN_ROOT}/include/saber)
|
||||
link_directories(${ANAKIN_ROOT})
|
||||
add_definitions(-DPADDLE_WITH_ANAKIN)
|
||||
endif()
|
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc")
|
||||
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
|
||||
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
|
||||
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)
|
||||
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
|
||||
|
||||
ExternalProject_Add(
|
||||
extern_dgc
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet"
|
||||
GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc"
|
||||
SOURCE_DIR "${DGC_SOURCES_DIR}"
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND cd collective && make -j
|
||||
INSTALL_COMMAND mkdir -p ${DGC_INSTALL_DIR}/lib/ ${DGC_INCLUDE_DIR}/dgc
|
||||
&& cp ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES}
|
||||
&& cp ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/
|
||||
BUILD_IN_SOURCE 1
|
||||
)
|
||||
|
||||
ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
|
||||
ADD_DEPENDENCIES(dgc extern_dgc)
|
||||
|
||||
LIST(APPEND external_project_dependencies dgc)
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,157 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_feed.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Dataset is a abstract class, which defines user interfaces
|
||||
// Example Usage:
|
||||
// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
|
||||
// dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
|
||||
// dataset->SetThreadNum(1)
|
||||
// dataset->CreateReaders();
|
||||
// dataset->SetDataFeedDesc(your_data_feed_desc);
|
||||
// dataset->LoadIntoMemory();
|
||||
// dataset->SetTrainerNum(2);
|
||||
// dataset->GlobalShuffle();
|
||||
class Dataset {
|
||||
public:
|
||||
Dataset() {}
|
||||
virtual ~Dataset() {}
|
||||
// set file list
|
||||
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
|
||||
// set readers' num
|
||||
virtual void SetThreadNum(int thread_num) = 0;
|
||||
// set workers' num
|
||||
virtual void SetTrainerNum(int trainer_num) = 0;
|
||||
// set fleet send batch size
|
||||
virtual void SetFleetSendBatchSize(int64_t size) = 0;
|
||||
// set fs name and ugi
|
||||
virtual void SetHdfsConfig(const std::string& fs_name,
|
||||
const std::string& fs_ugi) = 0;
|
||||
// set data fedd desc, which contains:
|
||||
// data feed name, batch size, slots
|
||||
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
|
||||
// get file list
|
||||
virtual const std::vector<std::string>& GetFileList() = 0;
|
||||
// get thread num
|
||||
virtual int GetThreadNum() = 0;
|
||||
// get worker num
|
||||
virtual int GetTrainerNum() = 0;
|
||||
// get fleet send batch size
|
||||
virtual int64_t GetFleetSendBatchSize() = 0;
|
||||
// get hdfs config
|
||||
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
|
||||
// get data fedd desc
|
||||
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
|
||||
// get readers, the reader num depend both on thread num
|
||||
// and filelist size
|
||||
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
||||
GetReaders() = 0;
|
||||
// register message handler between workers
|
||||
virtual void RegisterClientToClientMsgHandler() = 0;
|
||||
// load all data into memory
|
||||
virtual void LoadIntoMemory() = 0;
|
||||
// release all memory data
|
||||
virtual void ReleaseMemory() = 0;
|
||||
// local shuffle data
|
||||
virtual void LocalShuffle() = 0;
|
||||
// global shuffle data
|
||||
virtual void GlobalShuffle() = 0;
|
||||
// create readers
|
||||
virtual void CreateReaders() = 0;
|
||||
// destroy readers
|
||||
virtual void DestroyReaders() = 0;
|
||||
|
||||
protected:
|
||||
virtual int ReceiveFromClient(int msg_type, int client_id,
|
||||
const std::string& msg) = 0;
|
||||
};
|
||||
|
||||
// DatasetImpl is the implementation of Dataset,
|
||||
// it holds memory data if user calls load_into_memory
|
||||
template <typename T>
|
||||
class DatasetImpl : public Dataset {
|
||||
public:
|
||||
DatasetImpl();
|
||||
virtual ~DatasetImpl() {}
|
||||
|
||||
virtual void SetFileList(const std::vector<std::string>& filelist);
|
||||
virtual void SetThreadNum(int thread_num);
|
||||
virtual void SetTrainerNum(int trainer_num);
|
||||
virtual void SetFleetSendBatchSize(int64_t size);
|
||||
virtual void SetHdfsConfig(const std::string& fs_name,
|
||||
const std::string& fs_ugi);
|
||||
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
|
||||
|
||||
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
|
||||
virtual int GetThreadNum() { return thread_num_; }
|
||||
virtual int GetTrainerNum() { return trainer_num_; }
|
||||
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
|
||||
virtual std::pair<std::string, std::string> GetHdfsConfig() {
|
||||
return std::make_pair(fs_name_, fs_ugi_);
|
||||
}
|
||||
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
|
||||
return data_feed_desc_;
|
||||
}
|
||||
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
||||
GetReaders();
|
||||
|
||||
virtual void RegisterClientToClientMsgHandler();
|
||||
virtual void LoadIntoMemory();
|
||||
virtual void ReleaseMemory();
|
||||
virtual void LocalShuffle();
|
||||
virtual void GlobalShuffle();
|
||||
virtual void CreateReaders();
|
||||
virtual void DestroyReaders();
|
||||
|
||||
protected:
|
||||
virtual int ReceiveFromClient(int msg_type, int client_id,
|
||||
const std::string& msg);
|
||||
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
|
||||
std::vector<T> memory_data_;
|
||||
std::mutex mutex_for_update_memory_data_;
|
||||
int thread_num_;
|
||||
paddle::framework::DataFeedDesc data_feed_desc_;
|
||||
int trainer_num_;
|
||||
std::vector<std::string> filelist_;
|
||||
size_t file_idx_;
|
||||
std::mutex mutex_for_pick_file_;
|
||||
std::string fs_name_;
|
||||
std::string fs_ugi_;
|
||||
unsigned int rand_seed;
|
||||
int64_t fleet_send_batch_size_;
|
||||
};
|
||||
|
||||
// use std::vector<MultiSlotType> as data type
|
||||
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
|
||||
public:
|
||||
MultiSlotDataset() {}
|
||||
virtual ~MultiSlotDataset() {}
|
||||
};
|
||||
|
||||
} // end namespace framework
|
||||
} // end namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue