Merge branch 'develop' of https://github.com/paddlepaddle/paddle into memory_cpu_allocator

gangliao-patch-1
Yi Wang 8 years ago
commit 67481ca871

@ -2,7 +2,6 @@ group: deprecated-2017Q2
language: cpp language: cpp
cache: cache:
directories: directories:
- $HOME/third_party
- $HOME/.ccache - $HOME/.ccache
- $HOME/.cache/pip - $HOME/.cache/pip
sudo: required sudo: required
@ -10,15 +9,13 @@ dist: trusty
os: os:
- linux - linux
env: env:
- JOB=DOCS - JOB=build_doc
- JOB=BUILD_AND_TEST - JOB=check_style
- JOB=PRE_COMMIT
addons: addons:
apt: apt:
packages: packages:
- gcc-4.8 - gcc-4.8
- g++-4.8 - g++-4.8
- gfortran-4.8
- git - git
- build-essential - build-essential
- python - python
@ -35,18 +32,7 @@ addons:
- libtool - libtool
- ccache - ccache
before_install: before_install:
- | - if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
if [ ${JOB} == "BUILD_AND_TEST" ]; then
local change_list=`git diff --name-only $TRAVIS_COMMIT_RANGE`
if [ $? -eq 0 ]; then # if git diff return no zero, then rerun unit test.
if ! echo ${change_list} | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)'
then
echo "Only markdown docs were updated, stopping build process."
exit
fi
fi
fi
- if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version. # protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
@ -55,9 +41,7 @@ before_install:
- | - |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script: script:
- | - paddle/scripts/travis/$JOB.sh
timeout 2580 paddle/scripts/travis/main.sh # 43min timeout
RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi;
notifications: notifications:
email: email:
on_success: change on_success: change

@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y \ apt-get install -y \
git python-pip python-dev openssh-server bison \ git python-pip python-dev openssh-server bison \
wget unzip tar xz-utils bzip2 gzip coreutils \ wget unzip tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \ curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-numpy python-matplotlib gcc g++ \ python-numpy python-matplotlib gcc g++ \
automake locales clang-format-3.8 swig doxygen cmake \ automake locales clang-format-3.8 swig doxygen cmake \

@ -21,7 +21,8 @@ IF(NOT ${CBLAS_FOUND})
SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE)
SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}" SET(CBLAS_LIBRARIES
"${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}"
CACHE FILEPATH "openblas library." FORCE) CACHE FILEPATH "openblas library." FORCE)
SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs) SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs)

@ -14,11 +14,41 @@
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
# Print and set the protobuf library information,
# finish this cmake process and exit from this file.
macro(PROMPT_PROTOBUF_LIB) macro(PROMPT_PROTOBUF_LIB)
SET(protobuf_DEPS ${ARGN})
MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}") MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}")
MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}") MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}")
MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}") MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}")
INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR})
# Assuming that all the protobuf libraries are of the same type, derive the
# library type (STATIC or SHARED) from the file suffix of PROTOBUF_LIBRARY.
# The variable is quoted so that an empty/unset value reaches the
# FATAL_ERROR branch below instead of producing an IF() syntax error.
IF("${PROTOBUF_LIBRARY}" MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$")
  SET(protobuf_LIBTYPE STATIC)
ELSEIF("${PROTOBUF_LIBRARY}" MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}$")
  SET(protobuf_LIBTYPE SHARED)
ELSE()
  MESSAGE(FATAL_ERROR "Unknown library type: ${PROTOBUF_LIBRARY}")
ENDIF()

# Expose the three protobuf libraries as global imported targets.
ADD_LIBRARY(protobuf ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf PROPERTY IMPORTED_LOCATION ${PROTOBUF_LIBRARY})

ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY})

# The protoc library path is stored in PROTOBUF_PROTOC_LIBRARY (the cache
# entry described as "protoc library."); PROTOC_LIBRARY is not set by this
# file, so using it would leave IMPORTED_LOCATION empty.
ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_LIBRARY})

# Every target passed to the macro (collected in protobuf_DEPS) must be
# built before any of the imported libraries is used.
FOREACH(dep ${protobuf_DEPS})
  ADD_DEPENDENCIES(protobuf ${dep})
  ADD_DEPENDENCIES(protobuf_lite ${dep})
  ADD_DEPENDENCIES(protoc ${dep})
ENDFOREACH()

LIST(APPEND external_project_dependencies protobuf)
RETURN() RETURN()
endmacro() endmacro()
macro(SET_PROTOBUF_VERSION) macro(SET_PROTOBUF_VERSION)
@ -43,22 +73,23 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "")
endif() endif()
FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_NAME}) STRING(REPLACE "extern_" "" TARGET_DIR_NAME "${TARGET_NAME}")
SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_NAME}) SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_DIR_NAME})
SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_DIR_NAME})
SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE)
SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE)
SET(${TARGET_NAME}_LITE_LIBRARY SET(${TARGET_NAME}_LITE_LIBRARY
"${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${STATIC_LIBRARY_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(${TARGET_NAME}_LIBRARY SET(${TARGET_NAME}_LIBRARY
"${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${STATIC_LIBRARY_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(${TARGET_NAME}_PROTOC_LIBRARY SET(${TARGET_NAME}_PROTOC_LIBRARY
"${PROTOBUF_INSTALL_DIR}/lib/libprotoc${STATIC_LIBRARY_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${CMAKE_STATIC_LIBRARY_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(${TARGET_NAME}_PROTOC_EXECUTABLE SET(${TARGET_NAME}_PROTOC_EXECUTABLE
"${PROTOBUF_INSTALL_DIR}/bin/protoc${EXECUTABLE_SUFFIX}" "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}"
PARENT_SCOPE) PARENT_SCOPE)
SET(OPTIONAL_CACHE_ARGS "") SET(OPTIONAL_CACHE_ARGS "")
@ -109,6 +140,8 @@ IF(NOT CMAKE_CROSSCOMPILING)
SET_PROTOBUF_VERSION() SET_PROTOBUF_VERSION()
IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0") IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0")
SET(PROTOBUF_FOUND OFF) SET(PROTOBUF_FOUND OFF)
ELSE()
PROMPT_PROTOBUF_LIB()
ENDIF() ENDIF()
ENDIF(PROTOBUF_FOUND) ENDIF(PROTOBUF_FOUND)
ELSE() ELSE()
@ -120,18 +153,22 @@ ELSE()
ENDIF() ENDIF()
IF(NOT PROTOBUF_FOUND) IF(NOT PROTOBUF_FOUND)
build_protobuf(protobuf FALSE) build_protobuf(extern_protobuf FALSE)
LIST(APPEND external_project_dependencies protobuf)
SET(PROTOBUF_INCLUDE_DIR ${protobuf_INCLUDE_DIR} SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR}
CACHE PATH "protobuf include directory." FORCE) CACHE PATH "protobuf include directory." FORCE)
IF(NOT CMAKE_CROSSCOMPILING) SET(PROTOBUF_LITE_LIBRARY ${extern_protobuf_LITE_LIBRARY}
SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_PROTOC_EXECUTABLE} CACHE FILEPATH "protobuf lite library." FORCE)
SET(PROTOBUF_LIBRARY ${extern_protobuf_LIBRARY}
CACHE FILEPATH "protobuf library." FORCE)
SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY}
CACHE FILEPATH "protoc library." FORCE)
IF(CMAKE_CROSSCOMPILING)
PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf)
ELSE()
SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE}
CACHE FILEPATH "protobuf executable." FORCE) CACHE FILEPATH "protobuf executable." FORCE)
PROMPT_PROTOBUF_LIB(extern_protobuf)
ENDIF() ENDIF()
SET(PROTOBUF_LITE_LIBRARY ${protobuf_LITE_LIBRARY} CACHE FILEPATH "protobuf lite library." FORCE)
SET(PROTOBUF_LIBRARY ${protobuf_LIBRARY} CACHE FILEPATH "protobuf library." FORCE)
SET(PROTOBUF_PROTOC_LIBRARY ${protobuf_PROTOC_LIBRARY} CACHE FILEPATH "protoc library." FORCE)
ENDIF(NOT PROTOBUF_FOUND) ENDIF(NOT PROTOBUF_FOUND)
PROMPT_PROTOBUF_LIB()

@ -84,24 +84,6 @@ IF(DEFINED CMAKE_SYSTEM_NAME)
ENDIF() ENDIF()
ENDIF() ENDIF()
# prefix and suffix on different os
IF(WIN32)
SET(LIBRARY_PREFIX "")
SET(SHARED_LIBRARY_SUFFIX ".dll")
SET(STATIC_LIBRARY_SUFFIX ".lib")
SET(EXECUTABLE_SUFFIX ".exe")
ELSE(WIN32)
SET(LIBRARY_PREFIX "lib")
IF(APPLE)
SET(SHARED_LIBRARY_SUFFIX ".dylib")
ELSE(APPLE)
SET(SHARED_LIBRARY_SUFFIX ".so")
ENDIF(APPLE)
SET(STATIC_LIBRARY_SUFFIX ".a")
SET(EXECUTABLE_SUFFIX "")
ENDIF(WIN32)
# external dependencies log output # external dependencies log output
SET(EXTERNAL_PROJECT_LOG_ARGS SET(EXTERNAL_PROJECT_LOG_ARGS
LOG_DOWNLOAD 0 # Wrap download in script to log output LOG_DOWNLOAD 0 # Wrap download in script to log output

@ -99,3 +99,12 @@ value_printer
.. automodule:: paddle.v2.evaluator .. automodule:: paddle.v2.evaluator
:members: value_printer :members: value_printer
:noindex: :noindex:
Detection
=========
detection_map
-------------
.. automodule:: paddle.v2.evaluator
:members: detection_map
:noindex:

@ -1,45 +1,69 @@
package main package main
import ( import (
"fmt"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
) )
func main() { func main() {
port := flag.Int("port", 8080, "port of the master server.") port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
flag.Parse() flag.Parse()
if *faultTolerance { if *endpoints == "" {
panic("fault tolernance not implemented.") log.Warningln("-endpoints not set, fault tolerance not be enabled.")
}
var store master.Store
if *endpoints != "" {
eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP()
if err != nil {
log.Fatal(err)
}
addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil {
log.Fatal(err)
}
} else {
store = &master.InMemStore{}
}
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil {
log.Fatal(err)
} }
s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) err = rpc.Register(s)
err := rpc.Register(s)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
} }

@ -5,18 +5,35 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
) )
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
s := pserver.NewService() level, err := log.ParseLevel(*logLevel)
err := rpc.Register(s) if err != nil {
panic(err)
}
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
if err != nil {
panic(err)
}
err = rpc.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -27,7 +44,9 @@ func main() {
panic(err) panic(err)
} }
log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -47,9 +47,13 @@ func TestGetFinishTask(t *testing.T) {
} }
go func(l net.Listener) { go func(l net.Listener) {
s := NewService(chunkPerTask, time.Second, 1) s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -33,9 +33,13 @@ func TestNextRecord(t *testing.T) {
} }
go func(l net.Listener) { go func(l net.Listener) {
s := master.NewService(10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -0,0 +1,144 @@
package master
import (
"context"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
const (
// DefaultLockPath is the default etcd master lock path.
DefaultLockPath = "/master/lock"
// DefaultStatePath is the default etcd key for master state.
DefaultStatePath = "/master/state"
// DefaultAddrPath is the default etcd key for master address.
DefaultAddrPath = "/master/addr"
)
// EtcdClient is the etcd client that master uses for fault tolerance
// and service registry.
//
// Its Save and Load methods run etcd transactions that only succeed
// while this client still owns the master lock.
type EtcdClient struct {
// lockPath is the etcd path of the master lock; it is reported in
// log messages when the lock has to be re-acquired.
lockPath string
// statePath is the etcd key that Save writes to and Load reads from.
statePath string
// client is the etcd connection used to run the transactions.
client *clientv3.Client
// lock is the master mutex; its IsOwner comparison guards every
// transaction, and it is locked again when ownership is lost.
lock *concurrency.Mutex
}
// NewEtcdClient creates a new EtcdClient.
//
// It connects to the etcd cluster at endpoints (using the package-level
// dialTimeout), opens a session whose lease TTL is ttlSec seconds, blocks
// until the mutex at lockPath is acquired, and then writes addr under
// addrPath in a transaction that only succeeds while the lock is owned.
// statePath is recorded for later Save and Load calls.
//
// On every error return the session and client created so far are closed,
// so a failed call neither leaks the connection nor leaves a lease (and
// possibly the lock) behind until the TTL expires.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
	log.Debugf("Connecting to etcd at %v", endpoints)
	// TODO(helin): gracefully shutdown etcd store. Because etcd
	// store holds a etcd lock, even though the lock will expire
	// when the lease timeout, we need to implement graceful
	// shutdown to release the lock.
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   endpoints,
		DialTimeout: dialTimeout,
	})
	if err != nil {
		return nil, err
	}

	// Release everything acquired so far unless construction succeeds.
	var sess *concurrency.Session
	succeeded := false
	defer func() {
		if succeeded {
			return
		}
		if sess != nil {
			// Closing the session revokes its lease, which also
			// gives up the lock if it was already acquired.
			if cerr := sess.Close(); cerr != nil {
				log.Errorln(cerr)
			}
		}
		if cerr := cli.Close(); cerr != nil {
			log.Errorln(cerr)
		}
	}()

	sess, err = concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
	if err != nil {
		return nil, err
	}

	lock := concurrency.NewMutex(sess, lockPath)
	// It's fine for the lock to get stuck, in this case we have
	// multiple master servers running (only configured to have
	// one master running, but split-brain problem may cause
	// multiple master servers running), and the cluster management
	// software will kill one of them.
	log.Debugf("Trying to acquire lock at %s.", lockPath)
	err = lock.Lock(context.TODO())
	if err != nil {
		return nil, err
	}
	log.Debugf("Successfully acquired lock at %s.", lockPath)

	// Publish the master address, but only while still owning the lock.
	put := clientv3.OpPut(addrPath, addr)
	resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
	if err != nil {
		return nil, err
	}

	if !resp.Succeeded {
		log.Fatal("No longer owns the master lock. Exiting.")
	}

	e := &EtcdClient{
		lockPath:  lockPath,
		statePath: statePath,
		client:    cli,
		lock:      lock,
	}

	succeeded = true
	return e, nil
}
// Save saves the state into the etcd.
//
// The write is a transaction that only commits while this client owns
// the master lock. If ownership was lost, Save tries for up to five
// seconds to take the lock again and then retries the write; if the lock
// cannot be re-acquired the process exits.
func (e *EtcdClient) Save(state []byte) error {
	put := clientv3.OpPut(e.statePath, string(state))
	txn := e.client.Txn(context.TODO()).If(e.lock.IsOwner()).Then(put)
	resp, err := txn.Commit()
	if err != nil {
		return err
	}

	if resp.Succeeded {
		return nil
	}

	log.Errorln("No longer owns the lock, trying to lock again")
	lockCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	lockErr := e.lock.Lock(lockCtx)
	cancel()
	if lockErr != nil {
		// The master lock is gone and could not be taken back,
		// which means another master has already started. The
		// cluster management system should not kill the master
		// that holds the lock and runs correctly, so the most
		// feasible option is to terminate this one. The current
		// state is not saved, but the trainer's RPC call fails,
		// so the trainer will retry.
		log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, lockErr)
	}

	log.Infof("Successfully acquired lock at %s.", e.lockPath)
	return e.Save(state)
}
// Load loads the state from etcd.
//
// The read is a transaction that only commits while this client owns the
// master lock; when ownership was lost the lock is taken again and the
// load is retried. A nil slice with a nil error means no state is stored.
func (e *EtcdClient) Load() ([]byte, error) {
	get := clientv3.OpGet(e.statePath)
	resp, err := e.client.Txn(context.TODO()).If(e.lock.IsOwner()).Then(get).Commit()
	if err != nil {
		return nil, err
	}

	if resp.Succeeded {
		if kvs := resp.Responses[0].GetResponseRange().Kvs; len(kvs) > 0 {
			return kvs[0].Value, nil
		}
		// No state exists
		return nil, nil
	}

	log.Errorln("No longer owns the lock, trying to lock and load again.")
	if err = e.lock.Lock(context.Background()); err != nil {
		return nil, err
	}

	return e.Load()
}

@ -0,0 +1,28 @@
package master
import "sync"
// InMemStore is an in-memory implementation of the Store interface.
//
// It does not tolerate faults that cause the program to crash.
type InMemStore struct {
	mu  sync.Mutex // guards buf
	buf []byte     // last saved state; nil until Save is called with non-nil data
}

// Save saves the state into the in-memory store.
//
// The store keeps its own copy of state, so the caller may reuse or modify
// the slice afterwards without affecting what Load returns. A nil state is
// stored as nil. Save always returns a nil error.
func (m *InMemStore) Save(state []byte) error {
	var saved []byte
	if state != nil {
		saved = make([]byte, len(state))
		copy(saved, state)
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	m.buf = saved
	return nil
}

// Load loads the state from the in-memory store.
//
// It returns a copy of the saved state, or nil when nothing has been saved,
// so callers cannot modify the stored bytes through the returned slice.
// Load always returns a nil error.
func (m *InMemStore) Load() ([]byte, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.buf == nil {
		return nil, nil
	}

	out := make([]byte, len(m.buf))
	copy(out, m.buf)
	return out, nil
}

@ -1,6 +1,9 @@
package master package master
import ( import (
"bytes"
"compress/gzip"
"encoding/gob"
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
@ -12,24 +15,54 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
const (
dialTimeout = 5 * time.Second
)
// Store is the interface for save and load the master state.
type Store interface {
Save([]byte) error
Load() ([]byte, error)
}
// Chunk is a chunk of data consisted of several data instances.
type Chunk struct {
Path string
Index recordio.Index // chunk index
}
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
ID int
Chunks []Chunk
}
type taskEntry struct {
Epoch int
NumTimeout int
Task Task
}
type taskQueues struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []Task
}
// Service is the master server service. // Service is the master server service.
type Service struct { type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
timeoutMax int timeoutMax int
ready chan struct{} ready chan struct{}
store Store
mu sync.Mutex mu sync.Mutex
initDone bool initDone bool
taskQueues taskQueues taskQueues taskQueues
} }
// Recover recovers service state from etcd.
func Recover() (*Service, error) {
// TODO(helin): recover from snapshot state from etcd.
return nil, nil
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 id := 0
if chunksPerTask <= 0 { if chunksPerTask <= 0 {
@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) {
s := &Service{} s := &Service{}
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
@ -66,38 +99,82 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
return s s.store = store
} recovered, err := s.recover()
if err != nil {
return nil, err
}
// Chunk is a chunk of data consisted of several data instances. if recovered {
type Chunk struct { // Recovered. Now the state is already initialized,
Path string // and the master is ready.
Index recordio.Index // chunk index s.initDone = true
} close(s.ready)
log.Info("Master recovered from saved state.")
}
// Task is the basic unit of data instances assigned to trainers. return s, nil
type Task struct {
ID int
Chunks []Chunk
} }
type taskEntry struct { // recover recovers service state from etcd.
Epoch int func (s *Service) recover() (bool, error) {
NumTimeout int state, err := s.store.Load()
Task Task if err != nil {
} return false, err
}
type taskQueues struct { if state == nil {
Todo []taskEntry log.Infoln("No state exists, not recovered.")
Pending map[int]taskEntry // map from task ID to task entry return false, nil
Done []taskEntry }
Failed []Task
log.Infof("Loaded snapshot of size: %d bytes.", len(state))
gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil {
return false, err
}
dec := gob.NewDecoder(gr)
var tqs taskQueues
err = dec.Decode(&tqs)
if err != nil {
return false, err
}
err = gr.Close()
if err != nil {
// Only close failed, recover actually succeed, so
// just log error.
log.Errorln(err)
}
s.taskQueues = tqs
return true, nil
} }
// *must* be called with s.mu being held. // snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error { func (s *Service) snapshot() error {
// TODO(helin): snapshot state on etcd. // TODO(helin): etcd request has a size limit, so the snapshot
return nil // size is limited by the max request size. We should either
// divide the snapshot into smaller chunks and save under
// different keys, or configure the request size to be big
// enough:
// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
enc := gob.NewEncoder(gw)
err := enc.Encode(s.taskQueues)
if err != nil {
return err
}
err = gw.Close()
if err != nil {
return err
}
state := buf.Bytes()
log.Infof("Saving snapshot of size: %d bytes.", len(state))
return s.store.Save(state)
} }
func readChunks(globPaths []string) ([]Chunk, error) { func readChunks(globPaths []string) ([]Chunk, error) {
@ -207,12 +284,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
t.NumTimeout++ t.NumTimeout++
if t.NumTimeout > s.timeoutMax { if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout) log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return return
} }
log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout) log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.taskQueues.Todo = append(s.taskQueues.Todo, t)
} }
} }

@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Errorln(err)
@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
@ -30,9 +31,12 @@ func init() {
port[i] = p port[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -1,9 +1,18 @@
package pserver package pserver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
@ -24,6 +33,9 @@ const (
Float64 Float64
) )
// PsDesired is the etcd path that stores the desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server. // Parameter is a piece of data to sync with the parameter server.
type Parameter struct { type Parameter struct {
Name string Name string
@ -47,14 +59,128 @@ type Service struct {
mu sync.Mutex mu sync.Mutex
opt *optimizer opt *optimizer
paramMap map[string]Parameter paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
} }
// NewService creates a new service. // NewService creates a new service, will bypass etcd registration if no
func NewService() *Service { // endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)} s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait until the ps_desired value changes
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
} }
// InitParam initializes a parameter. // InitParam initializes a parameter.

@ -10,12 +10,15 @@ import (
) )
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
var p pserver.Parameter var p pserver.Parameter
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
err := s.FinishInitParams(0, nil) if err != nil {
t.Error(err)
}
err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
err := s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
} }
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
errCh := make(chan error, 2) errCh := make(chan error, 2)
var wg sync.WaitGroup var wg sync.WaitGroup
@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }

@ -0,0 +1,45 @@
package networkhelper
import (
"errors"
"net"
)
// GetExternalIP reports the first IPv4 address bound to a network interface
// that is up and is not a loopback device.
//
// An error is returned when the interfaces or the addresses of an eligible
// interface cannot be listed, or when no suitable IPv4 address exists.
func GetExternalIP() (string, error) {
	interfaces, err := net.Interfaces()
	if err != nil {
		return "", err
	}

	for _, nic := range interfaces {
		// Ignore interfaces that are down as well as loopback devices.
		if nic.Flags&net.FlagUp == 0 || nic.Flags&net.FlagLoopback != 0 {
			continue
		}

		candidates, err := nic.Addrs()
		if err != nil {
			return "", err
		}

		for _, candidate := range candidates {
			// Only these two address kinds carry an IP; anything else
			// leaves ip nil and is skipped below.
			var ip net.IP
			if network, ok := candidate.(*net.IPNet); ok {
				ip = network.IP
			} else if single, ok := candidate.(*net.IPAddr); ok {
				ip = single.IP
			}
			if ip == nil || ip.IsLoopback() {
				continue
			}

			// To4 yields nil for addresses that are not IPv4.
			if v4 := ip.To4(); v4 != nil {
				return v4.String(), nil
			}
		}
	}

	return "", errors.New("are you connected to the network?")
}

@ -0,0 +1,10 @@
package networkhelper
import "testing"
// TestGetIP fails when GetExternalIP reports an error; the returned address
// itself is not inspected.
func TestGetIP(t *testing.T) {
	if _, lookupErr := GetExternalIP(); lookupErr != nil {
		t.Errorf("GetExternalIP returns error : %v\n", lookupErr)
	}
}

File diff suppressed because it is too large Load Diff

@ -309,35 +309,35 @@ public:
void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) { void addEvaluator(std::unique_ptr<Evaluator>&& evaluator) {
evaluators_.emplace_back(std::move(evaluator)); evaluators_.emplace_back(std::move(evaluator));
} }
virtual void start() { void start() override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->start(); evaluator->start();
} }
} }
virtual void finish() { void finish() override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->finish(); evaluator->finish();
} }
} }
virtual void eval(const NeuralNetwork& nn) override { void eval(const NeuralNetwork& nn) override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->eval(nn); evaluator->eval(nn);
} }
} }
virtual real evalImp(std::vector<Argument>& arguments) { real evalImp(std::vector<Argument>& arguments) override {
(void)arguments; (void)arguments;
return -1; return -1;
} }
virtual void printStats(std::ostream& os) const { void printStats(std::ostream& os) const override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->printStats(os); evaluator->printStats(os);
os << ' '; os << ' ';
} }
} }
virtual void distributeEval(ParameterClient2* client) { void distributeEval(ParameterClient2* client) override {
for (auto& evaluator : evaluators_) { for (auto& evaluator : evaluators_) {
evaluator->distributeEval(client); evaluator->distributeEval(client);
} }
@ -352,7 +352,7 @@ public:
* @brief getNames will return all inside evaluators' names. * @brief getNames will return all inside evaluators' names.
* @param names [out]: return names. * @param names [out]: return names.
*/ */
void getNames(std::vector<std::string>* names) { void getNames(std::vector<std::string>* names) override {
for (auto& eval : evaluators_) { for (auto& eval : evaluators_) {
eval->getNames(names); eval->getNames(names);
} }
@ -361,7 +361,7 @@ public:
/** /**
* @brief getValue could get all inside evaluators' value. * @brief getValue could get all inside evaluators' value.
*/ */
real getValue(const std::string& name, Error* err) const { real getValue(const std::string& name, Error* err) const override {
return this->getMethodHelper<real>( return this->getMethodHelper<real>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) { name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getValue(name, err); return eval->getValue(name, err);
@ -371,7 +371,7 @@ public:
/** /**
* @brief getType could get all inside evaluators' type. * @brief getType could get all inside evaluators' type.
*/ */
std::string getType(const std::string& name, Error* err) const { std::string getType(const std::string& name, Error* err) const override {
return this->getMethodHelper<std::string>( return this->getMethodHelper<std::string>(
name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) { name, err, [&name, err](const std::unique_ptr<Evaluator>& eval) {
return eval->getType(name, err); return eval->getType(name, err);

@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf,
testEvaluator(testConf, testEvaluatorName, batchSize, false); testEvaluator(testConf, testEvaluatorName, batchSize, false);
} }
// Exercises the "detection_map" evaluator through testEvaluatorAll twice:
// first with evaluate_difficult disabled, then with it enabled.
TEST(Evaluator, detection_map) {
  TestConfig config;

  // Inputs handed to the evaluator: a plain data input named "output" and a
  // sequence data input named "label".
  config.inputDefs.push_back({INPUT_DATA, "output", 7});
  config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "label", 6});

  // Evaluator settings shared by both runs.
  auto& evalConf = config.evaluatorConfig;
  evalConf.set_type("detection_map");
  evalConf.set_overlap_threshold(0.5);
  evalConf.set_background_id(0);
  evalConf.set_ap_type("Integral");
  evalConf.set_evaluate_difficult(0);

  // Run 1: evaluate_difficult off.
  evalConf.set_evaluate_difficult(false);
  testEvaluatorAll(config, "detection_map", 100);

  // Run 2: evaluate_difficult on.
  evalConf.set_evaluate_difficult(true);
  testEvaluatorAll(config, "detection_map", 100);
}
TEST(Evaluator, classification_error) { TEST(Evaluator, classification_error) {
TestConfig config; TestConfig config;
config.evaluatorConfig.set_type("classification_error"); config.evaluatorConfig.set_type("classification_error");

@ -14,11 +14,13 @@ limitations under the License. */
#include "ParameterUpdaterHook.h" #include "ParameterUpdaterHook.h"
#include <algorithm>
#include <atomic> #include <atomic>
#include <fstream> #include <fstream>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h" #include "paddle/parameter/Parameter.h"
@ -29,106 +31,76 @@ namespace paddle {
/** /**
* The static pruning hook * The static pruning hook
* * Static means user specify a sparsity_ratio before training started, and the
* Static means user load a mask map before training started. This map will * network will prune the parameters based on the sparsity_ratio. More details
* define which link/weight between neural is disabled. * can be found https://arxiv.org/pdf/1506.02626.pdf.
*/ */
class StaticPruningHook : public IParameterUpdaterHook { class StaticPruningHook : public IParameterUpdaterHook {
public: public:
/** explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
* The Mask Map Header. : initCount_(0) {
* The map file started with this header. sparsityRatio_ = hookConfig.sparsity_ratio();
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/
struct StaticMaskHeader {
uint32_t version;
size_t size;
} __attribute__((__packed__));
explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) {
bool ok = this->loadMaskFile(mask_filename);
if (!ok) {
LOG(WARNING) << "Fail to load mask file " << mask_filename
<< " in current directory, searching in init_model_path";
std::string combineMaskFilename =
path::join(FLAGS_init_model_path, mask_filename);
CHECK(this->loadMaskFile(combineMaskFilename))
<< "Cannot load " << mask_filename << " in ./" << mask_filename
<< " and " << combineMaskFilename;
}
VLOG(3) << mask_filename << " mask size = " << this->mask_.size();
} }
void update(Parameter* para) { static bool sortPairAscend(const std::pair<real, size_t> &pair1,
const std::pair<real, size_t> &pair2) {
return pair1.first > pair2.first;
}
void update(Parameter *para) {
updateThreadChecker_.check(); updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT); auto &vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) { if (vec) {
vec->dotMul(*maskVec_); vec->dotMul(*maskVec_);
} }
} }
void init(Parameter* para) { void generateMask(Parameter *para) {
size_t initCount = this->initCount_.fetch_add(1); VectorPtr maskTemp = Vector::create(para->getSize(), false);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " maskTemp->zeroMem();
"in same ParamterUpdater"; real *maskTempData = maskTemp->getData();
VLOG(3) << "Initialize Parameter " << para; size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
SetDevice device(para->getDeviceId());
auto maskVec = Vector::create(this->mask_.size(), false); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
{ // Initialize maskVec with float mask vector VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
real* dataPtr = maskVec->getData();
size_t i = 0; paraCpuCopy->copyFrom(*paraVec);
for (bool m : mask_) { std::vector<std::pair<real, size_t>> param;
dataPtr[i++] = m ? 1.0 : 0.0;
} for (size_t i = 0; i < para->getSize(); i++)
} param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack. // Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) { if (para->useGpu()) {
maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskVec); maskVec_->copyFrom(*maskTemp);
} else { } else {
maskVec_ = maskVec; maskVec_ = maskTemp;
} }
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
} }
private: void init(Parameter *para) {
bool loadMaskFile(const std::string& mask_filename) { generateMask(para);
std::ifstream fin; size_t initCount = this->initCount_.fetch_add(1);
fin.open(mask_filename); CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
if (fin.is_open()) { "in same ParamterUpdater";
StaticMaskHeader header; VLOG(3) << "Initialize Parameter " << para;
fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader)); SetDevice device(para->getDeviceId());
CHECK_EQ(header.version, 0UL);
mask_.resize(header.size); auto &paraVec = para->getBuf(PARAMETER_VALUE);
uint8_t buf; paraVec->dotMul(*maskVec_);
for (size_t i = 0; i < header.size; ++i, buf <<= 1) {
if (i % 8 == 0) {
fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t));
}
mask_[i] = buf & 0x80;
}
fin.close();
return true;
} else {
return false;
}
} }
private:
SameThreadChecker updateThreadChecker_; SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_; std::atomic<size_t> initCount_;
VectorPtr maskVec_; VectorPtr maskVec_;
std::vector<bool> mask_; real sparsityRatio_;
}; };
IParameterUpdaterHook::IParameterUpdaterHook() {} IParameterUpdaterHook::IParameterUpdaterHook() {}
@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/ */
class StringIntPairHasher { class StringIntPairHasher {
public: public:
size_t operator()(const std::pair<std::string, int>& k) const { size_t operator()(const std::pair<std::string, int> &k) const {
return intHasher_(strHasher_(k.first) + k.second); return intHasher_(strHasher_(k.first) + k.second);
} }
@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
/** /**
* ParameterUpdaterHook actually factory method. * ParameterUpdaterHook actually factory method.
*/ */
static IParameterUpdaterHook* createImpl( static IParameterUpdaterHook *createImpl(
const ParameterUpdaterHookConfig& config) { const ParameterUpdaterHookConfig &config) {
auto& type = config.type(); auto &type = config.type();
if (type == "pruning") { if (type == "pruning") {
if (config.has_purning_mask_filename()) { return new StaticPruningHook(config);
return new StaticPruningHook(config.purning_mask_filename());
}
} }
LOG(FATAL) << "Unknown Hook type: " << type;
return nullptr; return nullptr;
} }
std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create( std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create(
const ParameterConfig& paramConfig, int idx) { const ParameterConfig &paramConfig, int idx) {
std::pair<std::string, int> key = {paramConfig.name(), idx}; std::pair<std::string, int> key = {paramConfig.name(), idx};
return g_hookCache_.get( return g_hookCache_.get(
key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); key, [&] { return createImpl(paramConfig.update_hooks(idx)); });

@ -1,12 +0,0 @@
#!/bin/bash
# Travis job: configure Paddle with coverage enabled, build it, upload the
# coverage report, and install the result.
source ./common.sh

# Point the build at the Python 2.7.12 installation under /opt/python.
export PYTHONPATH=/opt/python/2.7.12/lib/python2.7/site-packages
export PYTHONHOME=/opt/python/2.7.12
export PATH=/opt/python/2.7.12/bin:${PATH}

# EXTRA_CMAKE_OPTS is intentionally unquoted so it expands to separate options.
cmake .. -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DON_TRAVIS=ON -DWITH_COVERAGE=ON -DCOVERALLS_UPLOAD=ON ${EXTRA_CMAKE_OPTS}

# Build with one job per available core. The variable was previously
# misspelled as NRPOC, so the core count was never used and make always ran
# with a single job.
NPROC=$(nproc)
make -j "$NPROC"
make coveralls
sudo make install

@ -1,15 +1,18 @@
#!/bin/bash #!/bin/bash
set -e
# Create the build directory for CMake.
mkdir -p $TRAVIS_BUILD_DIR/build
cd $TRAVIS_BUILD_DIR/build
# Add set -e, cd to directory.
source ./common.sh
# Compile Documentation only. # Compile Documentation only.
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF ${EXTRA_CMAKE_OPTS} cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF
mkdir output mkdir output
make -j `nproc` make -j `nproc`
find .. -name '*whl' | xargs pip install # install all wheels. find .. -name '*whl' | xargs pip install # install all wheels.
rm -rf * rm -rf *
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=ON ${EXTRA_CMAKE_OPTS} cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON
make paddle_docs paddle_docs_cn make -j `nproc` paddle_docs paddle_docs_cn
# check websites for broken links # check websites for broken links
linkchecker doc/en/html/index.html linkchecker doc/en/html/index.html

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save