diff --git a/.travis.yml b/.travis.yml index 87cef10b2b..2c46da71e7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ group: deprecated-2017Q2 language: cpp cache: directories: - - $HOME/third_party - $HOME/.ccache - $HOME/.cache/pip sudo: required @@ -10,15 +9,13 @@ dist: trusty os: - linux env: - - JOB=DOCS - - JOB=BUILD_AND_TEST - - JOB=PRE_COMMIT + - JOB=build_doc + - JOB=check_style addons: apt: packages: - gcc-4.8 - g++-4.8 - - gfortran-4.8 - git - build-essential - python @@ -35,18 +32,7 @@ addons: - libtool - ccache before_install: - - | - if [ ${JOB} == "BUILD_AND_TEST" ]; then - local change_list=`git diff --name-only $TRAVIS_COMMIT_RANGE` - if [ $? -eq 0 ]; then # if git diff return no zero, then rerun unit test. - if ! echo ${change_list} | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)' - then - echo "Only markdown docs were updated, stopping build process." - exit - fi - fi - fi - - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi + - if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # protobuf version. - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker @@ -55,9 +41,7 @@ before_install: - | function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } script: - - | - timeout 2580 paddle/scripts/travis/main.sh # 43min timeout - RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi; + - paddle/scripts/travis/$JOB.sh notifications: email: on_success: change diff --git a/Dockerfile b/Dockerfile index 39af60966b..bf227737c5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/ RUN apt-get update && \ apt-get install -y \ git python-pip python-dev openssh-server bison \ - wget unzip tar xz-utils bzip2 gzip coreutils \ + wget unzip tar xz-utils bzip2 gzip coreutils ntp \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ python-numpy python-matplotlib gcc g++ \ automake locales clang-format-3.8 swig doxygen cmake \ diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 2341e3785b..5b9d9844ed 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -21,7 +21,8 @@ IF(NOT ${CBLAS_FOUND}) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) - SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}" + SET(CBLAS_LIBRARIES + "${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE FILEPATH "openblas library." FORCE) SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 7340394b1e..d43badc1da 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -14,11 +14,41 @@ INCLUDE(ExternalProject) +# Print and set the protobuf library information, +# finish this cmake process and exit from this file. macro(PROMPT_PROTOBUF_LIB) + SET(protobuf_DEPS ${ARGN}) + MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}") MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}") MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}") INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) + + # Assuming that all the protobuf libraries are of the same type. + IF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_STATIC_LIBRARY_SUFFIX}$") + SET(protobuf_LIBTYPE STATIC) + ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}$") + SET(protobuf_LIBTYPE SHARED) + ELSE() + MESSAGE(FATAL_ERROR "Unknown library type: ${PROTOBUF_LIBRARY}") + ENDIF() + + ADD_LIBRARY(protobuf ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protobuf PROPERTY IMPORTED_LOCATION ${PROTOBUF_LIBRARY}) + + ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) + + ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + + FOREACH(dep ${protobuf_DEPS}) + ADD_DEPENDENCIES(protobuf ${dep}) + ADD_DEPENDENCIES(protobuf_lite ${dep}) + ADD_DEPENDENCIES(protoc ${dep}) + ENDFOREACH() + + LIST(APPEND external_project_dependencies protobuf) RETURN() endmacro() macro(SET_PROTOBUF_VERSION) @@ -43,22 +73,23 @@ if (NOT "${PROTOBUF_ROOT}" STREQUAL "") endif() FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) - SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_NAME}) - SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_NAME}) + STRING(REPLACE "extern_" "" TARGET_DIR_NAME "${TARGET_NAME}") + SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_DIR_NAME}) + SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_DIR_NAME}) SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) SET(${TARGET_NAME}_LITE_LIBRARY - "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${STATIC_LIBRARY_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${CMAKE_STATIC_LIBRARY_SUFFIX}" PARENT_SCOPE) SET(${TARGET_NAME}_LIBRARY - "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${STATIC_LIBRARY_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}" PARENT_SCOPE) SET(${TARGET_NAME}_PROTOC_LIBRARY - "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${STATIC_LIBRARY_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${CMAKE_STATIC_LIBRARY_SUFFIX}" PARENT_SCOPE) SET(${TARGET_NAME}_PROTOC_EXECUTABLE - "${PROTOBUF_INSTALL_DIR}/bin/protoc${EXECUTABLE_SUFFIX}" + "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}" PARENT_SCOPE) SET(OPTIONAL_CACHE_ARGS "") @@ -109,6 +140,8 @@ IF(NOT CMAKE_CROSSCOMPILING) SET_PROTOBUF_VERSION() IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0") SET(PROTOBUF_FOUND OFF) + ELSE() + PROMPT_PROTOBUF_LIB() ENDIF() ENDIF(PROTOBUF_FOUND) ELSE() @@ -120,18 +153,22 @@ ELSE() ENDIF() IF(NOT PROTOBUF_FOUND) - build_protobuf(protobuf FALSE) - LIST(APPEND external_project_dependencies protobuf) + build_protobuf(extern_protobuf FALSE) - SET(PROTOBUF_INCLUDE_DIR ${protobuf_INCLUDE_DIR} + SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR} CACHE PATH "protobuf include directory." FORCE) - IF(NOT CMAKE_CROSSCOMPILING) - SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_PROTOC_EXECUTABLE} + SET(PROTOBUF_LITE_LIBRARY ${extern_protobuf_LITE_LIBRARY} + CACHE FILEPATH "protobuf lite library." FORCE) + SET(PROTOBUF_LIBRARY ${extern_protobuf_LIBRARY} + CACHE FILEPATH "protobuf library." FORCE) + SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY} + CACHE FILEPATH "protoc library." FORCE) + + IF(CMAKE_CROSSCOMPILING) + PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf) + ELSE() + SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE} CACHE FILEPATH "protobuf executable." FORCE) + PROMPT_PROTOBUF_LIB(extern_protobuf) ENDIF() - SET(PROTOBUF_LITE_LIBRARY ${protobuf_LITE_LIBRARY} CACHE FILEPATH "protobuf lite library." FORCE) - SET(PROTOBUF_LIBRARY ${protobuf_LIBRARY} CACHE FILEPATH "protobuf library." FORCE) - SET(PROTOBUF_PROTOC_LIBRARY ${protobuf_PROTOC_LIBRARY} CACHE FILEPATH "protoc library." FORCE) ENDIF(NOT PROTOBUF_FOUND) - -PROMPT_PROTOBUF_LIB() \ No newline at end of file diff --git a/cmake/system.cmake b/cmake/system.cmake index 904652413e..3b5cbfdd63 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -84,24 +84,6 @@ IF(DEFINED CMAKE_SYSTEM_NAME) ENDIF() ENDIF() -# prefix and suffix on different os -IF(WIN32) - SET(LIBRARY_PREFIX "") - SET(SHARED_LIBRARY_SUFFIX ".dll") - SET(STATIC_LIBRARY_SUFFIX ".lib") - SET(EXECUTABLE_SUFFIX ".exe") -ELSE(WIN32) - SET(LIBRARY_PREFIX "lib") - IF(APPLE) - SET(SHARED_LIBRARY_SUFFIX ".dylib") - ELSE(APPLE) - SET(SHARED_LIBRARY_SUFFIX ".so") - ENDIF(APPLE) - - SET(STATIC_LIBRARY_SUFFIX ".a") - SET(EXECUTABLE_SUFFIX "") -ENDIF(WIN32) - # external dependencies log output SET(EXTERNAL_PROJECT_LOG_ARGS LOG_DOWNLOAD 0 # Wrap download in script to log output diff --git a/doc/api/v2/config/evaluators.rst b/doc/api/v2/config/evaluators.rst index 39db51fa4a..9ac972fb19 100644 --- a/doc/api/v2/config/evaluators.rst +++ b/doc/api/v2/config/evaluators.rst @@ -99,3 +99,12 @@ value_printer .. automodule:: paddle.v2.evaluator :members: value_printer :noindex: + +Detection +===== + +detection_map +------------- +.. automodule:: paddle.v2.evaluator + :members: detection_map + :noindex: diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 25cd1cafcd..54fa254863 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -1,45 +1,69 @@ package main import ( + "fmt" "net" "net/http" "net/rpc" "strconv" + "strings" "time" "github.com/namsral/flag" + log "github.com/sirupsen/logrus" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" ) func main() { port := flag.Int("port", 8080, "port of the master server.") - - faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") + ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") + endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") flag.Parse() - if *faultTolerance { - panic("fault tolernance not implemented.") + if *endpoints == "" { + log.Warningln("-endpoints not set, fault tolerance not be enabled.") + } + + var store master.Store + if *endpoints != "" { + eps := strings.Split(*endpoints, ",") + ip, err := networkhelper.GetExternalIP() + if err != nil { + log.Fatal(err) + } + addr := fmt.Sprintf("%s:%d", ip, *port) + store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) + if err != nil { + log.Fatal(err) + } + } else { + store = &master.InMemStore{} + } + + s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) + if err != nil { + log.Fatal(err) } - s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) - err := rpc.Register(s) + err = rpc.Register(s) if err != nil { - panic(err) + log.Fatal(err) } rpc.HandleHTTP() l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) if err != nil { - panic(err) + log.Fatal(err) } err = http.Serve(l, nil) if err != nil { - panic(err) + log.Fatal(err) } } diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index f0be251c24..fe1fe5f6f0 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -5,18 +5,35 @@ import ( "net/http" "net/rpc" "strconv" + "time" "github.com/namsral/flag" "github.com/PaddlePaddle/Paddle/go/pserver" + log "github.com/sirupsen/logrus" ) func main() { port := flag.Int("port", 0, "port of the pserver") + etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", + "comma separated endpoint string for pserver to connect to etcd") + etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") + logLevel := flag.String("log-level", "info", + "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() - s := pserver.NewService() - err := rpc.Register(s) + level, err := log.ParseLevel(*logLevel) + if err != nil { + panic(err) + } + log.SetLevel(level) + + timeout := time.Second * time.Duration((*etcdTimeout)) + s, err := pserver.NewService(*etcdEndpoint, timeout) + if err != nil { + panic(err) + } + err = rpc.Register(s) if err != nil { panic(err) } @@ -27,7 +44,9 @@ func main() { panic(err) } + log.Infof("start pserver at port %d", *port) err = http.Serve(l, nil) + if err != nil { panic(err) } diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 00fcca0e2c..251225780a 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -47,9 +47,13 @@ func TestGetFinishTask(t *testing.T) { } go func(l net.Listener) { - s := NewService(chunkPerTask, time.Second, 1) + s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/client_test.go b/go/master/client_test.go index 2b3f873ecf..85a86761c2 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -33,9 +33,13 @@ func TestNextRecord(t *testing.T) { } go func(l net.Listener) { - s := master.NewService(10, time.Second, 1) + s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go new file mode 100644 index 0000000000..b7293a7598 --- /dev/null +++ b/go/master/etcd_client.go @@ -0,0 +1,144 @@ +package master + +import ( + "context" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +const ( + // DefaultLockPath is the default etcd master lock path. + DefaultLockPath = "/master/lock" + // DefaultStatePath is the default etcd key for master state. + DefaultStatePath = "/master/state" + // DefaultAddrPath is the default etcd key for master address. + DefaultAddrPath = "/master/addr" +) + +// EtcdClient is the etcd client that master uses for fault tolerance +// and service registry. +type EtcdClient struct { + lockPath string + statePath string + client *clientv3.Client + lock *concurrency.Mutex +} + +// NewEtcdClient creates a new EtcdClient. +func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { + log.Debugf("Connecting to etcd at %v", endpoints) + // TODO(helin): gracefully shutdown etcd store. Becuase etcd + // store holds a etcd lock, even though the lock will expire + // when the lease timeout, we need to implement graceful + // shutdown to release the lock. + cli, err := clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: dialTimeout, + }) + if err != nil { + return nil, err + } + + sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec)) + if err != nil { + return nil, err + } + + lock := concurrency.NewMutex(sess, lockPath) + // It's fine for the lock to get stuck, in this case we have + // multiple master servers running (only configured to have + // one master running, but split-brain problem may cuase + // multiple master servers running), and the cluster management + // software will kill one of them. + log.Debugf("Trying to acquire lock at %s.", lockPath) + err = lock.Lock(context.TODO()) + if err != nil { + return nil, err + } + log.Debugf("Successfully acquired lock at %s.", lockPath) + + put := clientv3.OpPut(addrPath, string(addr)) + resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Fatal("No longer owns the master lock. Exiting.") + } + + e := &EtcdClient{ + lockPath: lockPath, + statePath: statePath, + client: cli, + lock: lock, + } + + return e, nil +} + +// Save saves the state into the etcd. +func (e *EtcdClient) Save(state []byte) error { + ctx := context.TODO() + put := clientv3.OpPut(e.statePath, string(state)) + resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() + if err != nil { + return err + } + + if !resp.Succeeded { + log.Errorln("No longer owns the lock, trying to lock again") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := e.lock.Lock(ctx) + cancel() + if err != nil { + // We lost the master lock and can not acquire + // it back, it means some other master is + // already started. We don't want cluster + // managment system to kill the master server + // who is holding the lock and running + // correctly. So the most feasible solution is + // to kill current master server. The current + // state is not saved, but the trainer's RPC + // call will fail, so the trainer will retry. + log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) + } + log.Infof("Successfully acquired lock at %s.", e.lockPath) + return e.Save(state) + } + + return nil +} + +// Load loads the state from etcd. +func (e *EtcdClient) Load() ([]byte, error) { + ctx := context.TODO() + get := clientv3.OpGet(e.statePath) + + resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Errorln("No longer owns the lock, trying to lock and load again.") + err = e.lock.Lock(context.Background()) + if err != nil { + return nil, err + } + + return e.Load() + } + + kvs := resp.Responses[0].GetResponseRange().Kvs + if len(kvs) == 0 { + // No state exists + return nil, nil + } + + state := kvs[0].Value + return state, nil +} diff --git a/go/master/inmem_store.go b/go/master/inmem_store.go new file mode 100644 index 0000000000..bcd549b20e --- /dev/null +++ b/go/master/inmem_store.go @@ -0,0 +1,28 @@ +package master + +import "sync" + +// InMemStore is an in memory implementation of Store interface. +// +// It does not tolerate the fault that casues the program to crash. +type InMemStore struct { + mu sync.Mutex + buf []byte +} + +// Save saves the state into the in-memory store. +func (m *InMemStore) Save(state []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.buf = state + return nil +} + +// Load loads the state from the in-memory store. +func (m *InMemStore) Load() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.buf, nil +} diff --git a/go/master/service.go b/go/master/service.go index 55e1e2d1a4..58e68e7448 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -1,6 +1,9 @@ package master import ( + "bytes" + "compress/gzip" + "encoding/gob" "errors" "os" "path/filepath" @@ -12,24 +15,54 @@ import ( "github.com/PaddlePaddle/recordio" ) +const ( + dialTimeout = 5 * time.Second +) + +// Store is the interface for save and load the master state. +type Store interface { + Save([]byte) error + Load() ([]byte, error) +} + +// Chunk is a chunk of data consisted of several data instances. +type Chunk struct { + Path string + Index recordio.Index // chunk index +} + +// Task is the basic unit of data instances assigned to trainers. +type Task struct { + ID int + Chunks []Chunk +} + +type taskEntry struct { + Epoch int + NumTimeout int + Task Task +} + +type taskQueues struct { + Todo []taskEntry + Pending map[int]taskEntry // map from task ID to task entry + Done []taskEntry + Failed []Task +} + // Service is the master server service. type Service struct { chunksPerTask int timeoutDur time.Duration timeoutMax int ready chan struct{} + store Store mu sync.Mutex initDone bool taskQueues taskQueues } -// Recover recovers service state from etcd. -func Recover() (*Service, error) { - // TODO(helin): recover from snapshot state from etcd. - return nil, nil -} - func partition(chunks []Chunk, chunksPerTask int) []taskEntry { id := 0 if chunksPerTask <= 0 { @@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } // NewService creates a new service. -func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { +func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { s := &Service{} s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur @@ -66,38 +99,82 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) s.ready = make(chan struct{}) - return s -} + s.store = store + recovered, err := s.recover() + if err != nil { + return nil, err + } -// Chunk is a chunk of data consisted of several data instances. -type Chunk struct { - Path string - Index recordio.Index // chunk index -} + if recovered { + // Recovered. Now the state is already initialized, + // and the master is ready. + s.initDone = true + close(s.ready) + log.Info("Master recovered from saved state.") + } -// Task is the basic unit of data instances assigned to trainers. -type Task struct { - ID int - Chunks []Chunk + return s, nil } -type taskEntry struct { - Epoch int - NumTimeout int - Task Task -} +// recover recovers service state from etcd. +func (s *Service) recover() (bool, error) { + state, err := s.store.Load() + if err != nil { + return false, err + } -type taskQueues struct { - Todo []taskEntry - Pending map[int]taskEntry // map from task ID to task entry - Done []taskEntry - Failed []Task + if state == nil { + log.Infoln("No state exists, not recovered.") + return false, nil + } + + log.Infof("Loaded snapshot of size: %d bytes.", len(state)) + gr, err := gzip.NewReader(bytes.NewReader(state)) + if err != nil { + return false, err + } + + dec := gob.NewDecoder(gr) + var tqs taskQueues + err = dec.Decode(&tqs) + if err != nil { + return false, err + } + + err = gr.Close() + if err != nil { + // Only close failed, recover actually succeed, so + // just log error. + log.Errorln(err) + } + + s.taskQueues = tqs + return true, nil } -// *must* be called with s.mu being held. +// snapshot *must* be called with s.mu being held. func (s *Service) snapshot() error { - // TODO(helin): snapshot state on etcd. - return nil + // TOOD(helin): etcd request has a size limit, so the snapshot + // size is limited by the max request size. We should either + // divide the snapshot into smaller chunks and save under + // different keys, or configure the request size to be big + // enough: + // https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44 + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + enc := gob.NewEncoder(gw) + err := enc.Encode(s.taskQueues) + if err != nil { + return err + } + err = gw.Close() + if err != nil { + return err + } + + state := buf.Bytes() + log.Infof("Saving snapshot of size: %d bytes.", len(state)) + return s.store.Save(state) } func readChunks(globPaths []string) ([]Chunk, error) { @@ -207,12 +284,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { t.NumTimeout++ if t.NumTimeout > s.timeoutMax { - log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) return } - log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout) s.taskQueues.Todo = append(s.taskQueues.Todo, t) } } diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 92a41b7f54..bbaf43d9f1 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, if err != nil { if err.Error() == pserver.AlreadyInitialized { - log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) + log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name) return C.PSERVER_OK } log.Errorln(err) @@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } @@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } } diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go index d0371a26a1..6ecf1fa08a 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client_test.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/PaddlePaddle/Paddle/go/pserver" ) @@ -30,9 +31,12 @@ func init() { port[i] = p go func(l net.Listener) { - s := pserver.NewService() + s, err := pserver.NewService("", time.Second*5) + if err != nil { + panic(err) + } server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/pserver/service.go b/go/pserver/service.go index 78a2bfaf63..7e2b841dd8 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,9 +1,18 @@ package pserver import ( + "context" "errors" "fmt" + "strconv" + "strings" "sync" + "time" + + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -24,6 +33,9 @@ const ( Float64 ) +// PsDesired is etcd path for store desired pserver count +const PsDesired = "/ps_desired" + // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string @@ -47,14 +59,128 @@ type Service struct { mu sync.Mutex opt *optimizer paramMap map[string]Parameter + + etcdEndpoints string + etcdClient *clientv3.Client + // etcdTimeout is also used as retry intervals. + etcdTimeout time.Duration + // desired number of pservers in the job. + // assume desired will not change during one training job. + desired int + // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. + externalIP string } -// NewService creates a new service. -func NewService() *Service { +// NewService creates a new service, will bypass etcd registration if no +// endpoints specified. +func NewService(endpoints string, timeout time.Duration) (*Service, error) { s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) - return s + s.etcdEndpoints = endpoints + s.etcdTimeout = timeout + + var err error + s.externalIP, err = networkhelper.GetExternalIP() + if err != nil { + return nil, err + } + + if endpoints != "" { + // initialize connection to etcd, try + ep := strings.Split(s.etcdEndpoints, ",") + for { + cli, err := clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: s.etcdTimeout, + }) + if err != nil { + log.Errorf("connect to etcd error: %v", err) + time.Sleep(s.etcdTimeout) + continue + } + s.etcdClient = cli + log.Debugf("inited client to %s", s.etcdEndpoints) + break + } + // wait and set s.desired init value + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := s.etcdClient.Get(ctx, PsDesired) + cancel() + if err != nil { + log.Errorf("getting %s error: %v", PsDesired, err) + time.Sleep(s.etcdTimeout) + continue + } + if len(resp.Kvs) != 0 { + s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("value of %s invalid %v\n", PsDesired, err) + time.Sleep(s.etcdTimeout) + // NOTE: wait util ps_desired value change + continue + } + break + } + } + // try register pserver node on etcd + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := s.registerPserverEtcd(ctx) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(s.etcdTimeout) + continue + } + break + } + } // if endpoints != "" + // Bypass etcd registration if no endpoints specified + return s, nil +} + +// registerPserverEtcd registers pserver node on etcd using transaction. +func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { + registered := false + for i := 0; i < s.desired; i++ { + psKey := "/ps/" + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + ps := c.Get(psKey) + log.Debugf("got value (%s) for key: %s", ps, psKey) + + if ps == "" { + resp, err := s.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } + // find the first id and write info + c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) + log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) + ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) + if kaerr != nil { + log.Errorf("keepalive etcd node error: %v", kaerr) + return kaerr + } + + // Eat the keep alive message so etcd + // will not expire the lease. + go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { + ka := <-ch + log.Debugf("keepalive: %d\n", ka.TTL) + }(ch) + log.Debug("register finished") + registered = true + break + } + } + if registered == true { + return nil + } + return errors.New("not registerd, may due to already have enough pservers") + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) } // InitParam initializes a parameter. diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index b746d13e1c..f317535592 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,12 +10,15 @@ import ( ) func TestFull(t *testing.T) { - s := pserver.NewService() + s, err := pserver.NewService("", time.Second*5) + if err != nil { + t.Error(err) + } var p pserver.Parameter p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) + err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) if err != nil { t.FailNow() } @@ -72,8 +75,11 @@ func TestFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s := pserver.NewService() - err := s.FinishInitParams(0, nil) + s, err := pserver.NewService("", time.Second*5) + if err != nil { + t.Error(err) + } + err = s.FinishInitParams(0, nil) if err != nil { t.FailNow() } @@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s := pserver.NewService() - err := s.SendGrad(pserver.Gradient{}, nil) + s, err := pserver.NewService("", time.Second*5) + err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() } } func TestBlockUntilInitialized(t *testing.T) { - s := pserver.NewService() + s, err := pserver.NewService("", time.Second*5) + if err != nil { + t.Error(err) + } ch := make(chan struct{}, 2) errCh := make(chan error, 2) var wg sync.WaitGroup @@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) + err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) if err != nil { t.FailNow() } diff --git a/go/utils/networkhelper/helper.go b/go/utils/networkhelper/helper.go new file mode 100644 index 0000000000..fbeaea8f5e --- /dev/null +++ b/go/utils/networkhelper/helper.go @@ -0,0 +1,45 @@ +package networkhelper + +import ( + "errors" + "net" +) + +// GetExternalIP returns the ip address of local network interface, not the +// loopback device. +func GetExternalIP() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", err + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() { + continue + } + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + return ip.String(), nil + } + } + return "", errors.New("are you connected to the network?") +} diff --git a/go/utils/networkhelper/helper_test.go b/go/utils/networkhelper/helper_test.go new file mode 100644 index 0000000000..4208f9e358 --- /dev/null +++ b/go/utils/networkhelper/helper_test.go @@ -0,0 +1,10 @@ +package networkhelper + +import "testing" + +func TestGetIP(t *testing.T) { + _, err := GetExternalIP() + if err != nil { + t.Errorf("GetExternalIP returns error : %v\n", err) + } +} diff --git a/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp b/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp new file mode 100644 index 0000000000..9b825db574 --- /dev/null +++ b/paddle/gserver/evaluators/DetectionMAPEvaluator.cpp @@ -0,0 +1,308 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Evaluator.h" +#include "paddle/gserver/layers/DetectionUtil.h" + +using std::map; +using std::vector; +using std::pair; +using std::make_pair; + +namespace paddle { + +/** + * @brief detection map Evaluator + * + * The config file api is detection_map_evaluator. + */ +class DetectionMAPEvaluator : public Evaluator { +public: + DetectionMAPEvaluator() + : evaluateDifficult_(false), cpuOutput_(nullptr), cpuLabel_(nullptr) {} + + virtual void start() { + Evaluator::start(); + allTruePos_.clear(); + allFalsePos_.clear(); + numPos_.clear(); + } + + virtual real evalImp(std::vector& arguments) { + overlapThreshold_ = config_.overlap_threshold(); + backgroundId_ = config_.background_id(); + evaluateDifficult_ = config_.evaluate_difficult(); + apType_ = config_.ap_type(); + + MatrixPtr detectTmpValue = arguments[0].value; + Matrix::resizeOrCreate(cpuOutput_, + detectTmpValue->getHeight(), + detectTmpValue->getWidth(), + false, + false); + + MatrixPtr labelTmpValue = arguments[1].value; + Matrix::resizeOrCreate(cpuLabel_, + labelTmpValue->getHeight(), + labelTmpValue->getWidth(), + false, + false); + + cpuOutput_->copyFrom(*detectTmpValue); + cpuLabel_->copyFrom(*labelTmpValue); + + Argument label = arguments[1]; + const int* labelIndex = label.sequenceStartPositions->getData(false); + size_t batchSize = label.getNumSequences(); + + vector>> allGTBBoxes; + vector>>> allDetectBBoxes; + + for (size_t n = 0; n < batchSize; ++n) { + map> bboxes; + for (int i = labelIndex[n]; i < labelIndex[n + 1]; ++i) { + vector bbox; + getBBoxFromLabelData(cpuLabel_->getData() + i * 6, 1, bbox); + int c = cpuLabel_->getData()[i * 6]; + bboxes[c].push_back(bbox[0]); + } + allGTBBoxes.push_back(bboxes); + } + + size_t n = 0; + const real* cpuOutputData = cpuOutput_->getData(); + for (size_t imgId = 0; imgId < batchSize; ++imgId) { + map>> bboxes; + size_t curImgId = static_cast((cpuOutputData + n * 7)[0]); + while (curImgId == imgId && n < cpuOutput_->getHeight()) { + vector label; + vector score; + vector bbox; + getBBoxFromDetectData(cpuOutputData + n * 7, 1, label, score, bbox); + bboxes[label[0]].push_back(make_pair(score[0], bbox[0])); + ++n; + curImgId = static_cast((cpuOutputData + n * 7)[0]); + } + allDetectBBoxes.push_back(bboxes); + } + + for (size_t n = 0; n < batchSize; ++n) { + for (map>::iterator it = + allGTBBoxes[n].begin(); + it != allGTBBoxes[n].end(); + ++it) { + size_t count = 0; + if (evaluateDifficult_) { + count = it->second.size(); + } else { + for (size_t i = 0; i < it->second.size(); ++i) + if (!(it->second[i].isDifficult)) ++count; + } + if (numPos_.find(it->first) == numPos_.end() && count != 0) { + numPos_[it->first] = count; + } else { + numPos_[it->first] += count; + } + } + } + + // calcTFPos + calcTFPos(batchSize, allGTBBoxes, allDetectBBoxes); + + return 0; + } + + virtual void printStats(std::ostream& os) const { + real mAP = calcMAP(); + os << "Detection mAP=" << mAP; + } + + virtual void distributeEval(ParameterClient2* client) { + LOG(FATAL) << "Distribute detection evaluation not implemented."; + } + +protected: + void calcTFPos(const size_t batchSize, + const vector>>& allGTBBoxes, + const vector>>>& + allDetectBBoxes) { + for (size_t n = 0; n < allDetectBBoxes.size(); ++n) { + if (allGTBBoxes[n].size() == 0) { + for (map>>::const_iterator + it = allDetectBBoxes[n].begin(); + it != allDetectBBoxes[n].end(); + ++it) { + size_t label = it->first; + for (size_t i = 0; i < it->second.size(); ++i) { + allTruePos_[label].push_back(make_pair(it->second[i].first, 0)); + allFalsePos_[label].push_back(make_pair(it->second[i].first, 1)); + } + } + } else { + for (map>>::const_iterator + it = allDetectBBoxes[n].begin(); + it != allDetectBBoxes[n].end(); + ++it) { + size_t label = it->first; + vector> predBBoxes = it->second; + if (allGTBBoxes[n].find(label) == allGTBBoxes[n].end()) { + for (size_t i = 0; i < predBBoxes.size(); ++i) { + allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0)); + allFalsePos_[label].push_back(make_pair(predBBoxes[i].first, 1)); + } + } else { + vector gtBBoxes = + allGTBBoxes[n].find(label)->second; + vector visited(gtBBoxes.size(), false); + // Sort detections in descend order based on scores + std::sort(predBBoxes.begin(), + predBBoxes.end(), + sortScorePairDescend); + for (size_t i = 0; i < predBBoxes.size(); ++i) { + real maxOverlap = -1.0; + size_t maxIdx = 0; + for (size_t j = 0; j < gtBBoxes.size(); ++j) { + real overlap = + jaccardOverlap(predBBoxes[i].second, gtBBoxes[j]); + if (overlap > maxOverlap) { + maxOverlap = overlap; + maxIdx = j; + } + } + if (maxOverlap > overlapThreshold_) { + if (evaluateDifficult_ || + (!evaluateDifficult_ && !gtBBoxes[maxIdx].isDifficult)) { + if (!visited[maxIdx]) { + allTruePos_[label].push_back( + make_pair(predBBoxes[i].first, 1)); + allFalsePos_[label].push_back( + make_pair(predBBoxes[i].first, 0)); + visited[maxIdx] = true; + } else { + allTruePos_[label].push_back( + make_pair(predBBoxes[i].first, 0)); + allFalsePos_[label].push_back( + make_pair(predBBoxes[i].first, 1)); + } + } + } else { + allTruePos_[label].push_back(make_pair(predBBoxes[i].first, 0)); + allFalsePos_[label].push_back( + make_pair(predBBoxes[i].first, 1)); + } + } + } + } + } + } + } + + real calcMAP() const { + real mAP = 0.0; + size_t count = 0; + for (map::const_iterator it = numPos_.begin(); + it != numPos_.end(); + ++it) { + size_t label = it->first; + size_t labelNumPos = it->second; + if (labelNumPos == 0 || allTruePos_.find(label) == allTruePos_.end()) + continue; + vector> labelTruePos = allTruePos_.find(label)->second; + vector> labelFalsePos = + allFalsePos_.find(label)->second; + // Compute average precision. + vector tpCumSum; + getAccumulation(labelTruePos, &tpCumSum); + vector fpCumSum; + getAccumulation(labelFalsePos, &fpCumSum); + std::vector precision, recall; + size_t num = tpCumSum.size(); + // Compute Precision. + for (size_t i = 0; i < num; ++i) { + CHECK_LE(tpCumSum[i], labelNumPos); + precision.push_back(static_cast(tpCumSum[i]) / + static_cast(tpCumSum[i] + fpCumSum[i])); + recall.push_back(static_cast(tpCumSum[i]) / labelNumPos); + } + // VOC2007 style + if (apType_ == "11point") { + vector maxPrecisions(11, 0.0); + int startIdx = num - 1; + for (int j = 10; j >= 0; --j) + for (int i = startIdx; i >= 0; --i) { + if (recall[i] < j / 10.) { + startIdx = i; + if (j > 0) maxPrecisions[j - 1] = maxPrecisions[j]; + break; + } else { + if (maxPrecisions[j] < precision[i]) + maxPrecisions[j] = precision[i]; + } + } + for (int j = 10; j >= 0; --j) mAP += maxPrecisions[j] / 11; + ++count; + } else if (apType_ == "Integral") { + // Nature integral + real averagePrecisions = 0.; + real prevRecall = 0.; + for (size_t i = 0; i < num; ++i) { + if (fabs(recall[i] - prevRecall) > 1e-6) + averagePrecisions += precision[i] * fabs(recall[i] - prevRecall); + prevRecall = recall[i]; + } + mAP += averagePrecisions; + ++count; + } else { + LOG(FATAL) << "Unkown ap version: " << apType_; + } + } + if (count != 0) mAP /= count; + return mAP * 100; + } + + void getAccumulation(vector> inPairs, + vector* accuVec) const { + std::stable_sort( + inPairs.begin(), inPairs.end(), sortScorePairDescend); + accuVec->clear(); + size_t sum = 0; + for (size_t i = 0; i < inPairs.size(); ++i) { + sum += inPairs[i].second; + accuVec->push_back(sum); + } + } + + std::string getTypeImpl() const { return "detection_map"; } + + real getValueImpl() const { return calcMAP(); } + +private: + real overlapThreshold_; // overlap threshold when determining whether matched + bool evaluateDifficult_; // whether evaluate difficult ground truth + size_t backgroundId_; // class index of background + std::string apType_; // how to calculate mAP (Integral or 11point) + + MatrixPtr cpuOutput_; + MatrixPtr cpuLabel_; + + map numPos_; // counts of true objects each classification + map>> + allTruePos_; // true positive prediction + map>> + allFalsePos_; // false positive prediction +}; + +REGISTER_EVALUATOR(detection_map, DetectionMAPEvaluator); + +} // namespace paddle diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 514c0759e1..2e839f6405 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -309,35 +309,35 @@ public: void addEvaluator(std::unique_ptr&& evaluator) { evaluators_.emplace_back(std::move(evaluator)); } - virtual void start() { + void start() override { for (auto& evaluator : evaluators_) { evaluator->start(); } } - virtual void finish() { + void finish() override { for (auto& evaluator : evaluators_) { evaluator->finish(); } } - virtual void eval(const NeuralNetwork& nn) override { + void eval(const NeuralNetwork& nn) override { for (auto& evaluator : evaluators_) { evaluator->eval(nn); } } - virtual real evalImp(std::vector& arguments) { + real evalImp(std::vector& arguments) override { (void)arguments; return -1; } - virtual void printStats(std::ostream& os) const { + void printStats(std::ostream& os) const override { for (auto& evaluator : evaluators_) { evaluator->printStats(os); os << ' '; } } - virtual void distributeEval(ParameterClient2* client) { + void distributeEval(ParameterClient2* client) override { for (auto& evaluator : evaluators_) { evaluator->distributeEval(client); } @@ -352,7 +352,7 @@ public: * @brief getNames will return all inside evaluators' names. * @param names [out]: return names. */ - void getNames(std::vector* names) { + void getNames(std::vector* names) override { for (auto& eval : evaluators_) { eval->getNames(names); } @@ -361,7 +361,7 @@ public: /** * @brief getValue could get all inside evaluators' value. */ - real getValue(const std::string& name, Error* err) const { + real getValue(const std::string& name, Error* err) const override { return this->getMethodHelper( name, err, [&name, err](const std::unique_ptr& eval) { return eval->getValue(name, err); @@ -371,7 +371,7 @@ public: /** * @brief getType could get all inside evaluators' type. */ - std::string getType(const std::string& name, Error* err) const { + std::string getType(const std::string& name, Error* err) const override { return this->getMethodHelper( name, err, [&name, err](const std::unique_ptr& eval) { return eval->getType(name, err); diff --git a/paddle/gserver/tests/test_Evaluator.cpp b/paddle/gserver/tests/test_Evaluator.cpp index 4f5fdbb37c..93996392d2 100644 --- a/paddle/gserver/tests/test_Evaluator.cpp +++ b/paddle/gserver/tests/test_Evaluator.cpp @@ -138,6 +138,23 @@ void testEvaluatorAll(TestConfig testConf, testEvaluator(testConf, testEvaluatorName, batchSize, false); } +TEST(Evaluator, detection_map) { + TestConfig config; + config.evaluatorConfig.set_type("detection_map"); + config.evaluatorConfig.set_overlap_threshold(0.5); + config.evaluatorConfig.set_background_id(0); + config.evaluatorConfig.set_ap_type("Integral"); + config.evaluatorConfig.set_evaluate_difficult(0); + + config.inputDefs.push_back({INPUT_DATA, "output", 7}); + config.inputDefs.push_back({INPUT_SEQUENCE_DATA, "label", 6}); + config.evaluatorConfig.set_evaluate_difficult(false); + testEvaluatorAll(config, "detection_map", 100); + + config.evaluatorConfig.set_evaluate_difficult(true); + testEvaluatorAll(config, "detection_map", 100); +} + TEST(Evaluator, classification_error) { TestConfig config; config.evaluatorConfig.set_type("classification_error"); diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index f826e8448c..c8b47687f5 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -14,11 +14,13 @@ limitations under the License. */ #include "ParameterUpdaterHook.h" +#include #include #include #include #include #include +#include #include "paddle/math/Vector.h" #include "paddle/parameter/Parameter.h" @@ -29,106 +31,76 @@ namespace paddle { /** * The static pruning hook - * - * Static means user load a mask map before training started. This map will - * define which link/weight between neural is disabled. + * Static means user specify a sparsity_ratio before training started, and the + * network will prune the parameters based on the sparsity_ratio. More details + * can be found https://arxiv.org/pdf/1506.02626.pdf. */ + class StaticPruningHook : public IParameterUpdaterHook { public: - /** - * The Mask Map Header. - * The map file started with this header. - * - * In Version 0, reset file will be: - * contains header.size bit, each bit means such weight is enabled or not. - * if bit is 1, then such weight is enabled. - * at end, the file will round to byte, and the low bits of end byte will be - * filled by zero. - * - */ - struct StaticMaskHeader { - uint32_t version; - size_t size; - } __attribute__((__packed__)); - - explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) { - bool ok = this->loadMaskFile(mask_filename); - if (!ok) { - LOG(WARNING) << "Fail to load mask file " << mask_filename - << " in current directory, searching in init_model_path"; - std::string combineMaskFilename = - path::join(FLAGS_init_model_path, mask_filename); - CHECK(this->loadMaskFile(combineMaskFilename)) - << "Cannot load " << mask_filename << " in ./" << mask_filename - << " and " << combineMaskFilename; - } - VLOG(3) << mask_filename << " mask size = " << this->mask_.size(); + explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig) + : initCount_(0) { + sparsityRatio_ = hookConfig.sparsity_ratio(); } - void update(Parameter* para) { + static bool sortPairAscend(const std::pair &pair1, + const std::pair &pair2) { + return pair1.first > pair2.first; + } + + void update(Parameter *para) { updateThreadChecker_.check(); - auto& vec = para->getBuf(PARAMETER_GRADIENT); + auto &vec = para->getBuf(PARAMETER_GRADIENT); if (vec) { vec->dotMul(*maskVec_); } } - void init(Parameter* para) { - size_t initCount = this->initCount_.fetch_add(1); - CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " - "in same ParamterUpdater"; - VLOG(3) << "Initialize Parameter " << para; - SetDevice device(para->getDeviceId()); + void generateMask(Parameter *para) { + VectorPtr maskTemp = Vector::create(para->getSize(), false); + maskTemp->zeroMem(); + real *maskTempData = maskTemp->getData(); + size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); - auto maskVec = Vector::create(this->mask_.size(), false); - { // Initialize maskVec with float mask vector - real* dataPtr = maskVec->getData(); - size_t i = 0; - for (bool m : mask_) { - dataPtr[i++] = m ? 1.0 : 0.0; - } - } + VectorPtr paraVec = para->getBuf(PARAMETER_VALUE); + VectorPtr paraCpuCopy = Vector::create(para->getSize(), false); + + paraCpuCopy->copyFrom(*paraVec); + std::vector> param; + + for (size_t i = 0; i < para->getSize(); i++) + param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i)); + + std::partial_sort( + param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); + for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0; // Currently just use a mask vector for hack. - // @TODO(yuyang18): Implemented the mask operation in vector. if (para->useGpu()) { - maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); - maskVec_->copyFrom(*maskVec); + maskVec_ = Vector::create(para->getSize(), para->useGpu()); + maskVec_->copyFrom(*maskTemp); } else { - maskVec_ = maskVec; + maskVec_ = maskTemp; } - - auto& vec = para->getBuf(PARAMETER_VALUE); - vec->dotMul(*maskVec_); } -private: - bool loadMaskFile(const std::string& mask_filename) { - std::ifstream fin; - fin.open(mask_filename); - if (fin.is_open()) { - StaticMaskHeader header; - fin.read(reinterpret_cast(&header), sizeof(StaticMaskHeader)); - CHECK_EQ(header.version, 0UL); - mask_.resize(header.size); - uint8_t buf; - for (size_t i = 0; i < header.size; ++i, buf <<= 1) { - if (i % 8 == 0) { - fin.read(reinterpret_cast(&buf), sizeof(uint8_t)); - } - mask_[i] = buf & 0x80; - } - fin.close(); - return true; - } else { - return false; - } + void init(Parameter *para) { + generateMask(para); + size_t initCount = this->initCount_.fetch_add(1); + CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " + "in same ParamterUpdater"; + VLOG(3) << "Initialize Parameter " << para; + SetDevice device(para->getDeviceId()); + + auto ¶Vec = para->getBuf(PARAMETER_VALUE); + paraVec->dotMul(*maskVec_); } +private: SameThreadChecker updateThreadChecker_; std::atomic initCount_; VectorPtr maskVec_; - std::vector mask_; + real sparsityRatio_; }; IParameterUpdaterHook::IParameterUpdaterHook() {} @@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {} */ class StringIntPairHasher { public: - size_t operator()(const std::pair& k) const { + size_t operator()(const std::pair &k) const { return intHasher_(strHasher_(k.first) + k.second); } @@ -162,19 +134,19 @@ static WeakKVCache, /** * ParameterUpdaterHook actually factory method. */ -static IParameterUpdaterHook* createImpl( - const ParameterUpdaterHookConfig& config) { - auto& type = config.type(); +static IParameterUpdaterHook *createImpl( + const ParameterUpdaterHookConfig &config) { + auto &type = config.type(); if (type == "pruning") { - if (config.has_purning_mask_filename()) { - return new StaticPruningHook(config.purning_mask_filename()); - } + return new StaticPruningHook(config); } + + LOG(FATAL) << "Unknown Hook type: " << type; return nullptr; } std::shared_ptr IParameterUpdaterHook::create( - const ParameterConfig& paramConfig, int idx) { + const ParameterConfig ¶mConfig, int idx) { std::pair key = {paramConfig.name(), idx}; return g_hookCache_.get( key, [&] { return createImpl(paramConfig.update_hooks(idx)); }); diff --git a/paddle/scripts/travis/build_and_test.sh b/paddle/scripts/travis/build_and_test.sh deleted file mode 100755 index f2cbc56165..0000000000 --- a/paddle/scripts/travis/build_and_test.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -source ./common.sh - -NPROC=1 -export PYTHONPATH=/opt/python/2.7.12/lib/python2.7/site-packages -export PYTHONHOME=/opt/python/2.7.12 -export PATH=/opt/python/2.7.12/bin:${PATH} -cmake .. -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DON_TRAVIS=ON -DWITH_COVERAGE=ON -DCOVERALLS_UPLOAD=ON ${EXTRA_CMAKE_OPTS} -NRPOC=`nproc` -make -j $NPROC -make coveralls -sudo make install diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/build_doc.sh similarity index 84% rename from paddle/scripts/travis/docs.sh rename to paddle/scripts/travis/build_doc.sh index c784293695..88264d8c26 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/build_doc.sh @@ -1,15 +1,18 @@ #!/bin/bash +set -e + +# Create the build directory for CMake. +mkdir -p $TRAVIS_BUILD_DIR/build +cd $TRAVIS_BUILD_DIR/build -# Add set -e, cd to directory. -source ./common.sh # Compile Documentation only. -cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF ${EXTRA_CMAKE_OPTS} +cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF mkdir output make -j `nproc` find .. -name '*whl' | xargs pip install # install all wheels. rm -rf * -cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_Fortran_COMPILER=/usr/bin/gfortran-4.8 -DWITH_GPU=OFF -DWITH_DOC=ON ${EXTRA_CMAKE_OPTS} -make paddle_docs paddle_docs_cn +cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON +make -j `nproc` paddle_docs paddle_docs_cn # check websites for broken links linkchecker doc/en/html/index.html diff --git a/paddle/scripts/travis/precommit.sh b/paddle/scripts/travis/check_style.sh similarity index 54% rename from paddle/scripts/travis/precommit.sh rename to paddle/scripts/travis/check_style.sh index 7a59b1131d..4754bdd4c8 100755 --- a/paddle/scripts/travis/precommit.sh +++ b/paddle/scripts/travis/check_style.sh @@ -1,14 +1,14 @@ #!/bin/bash function abort(){ - echo "Your commit not fit PaddlePaddle code style" 1>&2 - echo "Please use pre-commit scripts to auto-format your code" 1>&2 + echo "Your change doesn't follow PaddlePaddle's code style." 1>&2 + echo "Please use pre-commit to reformat your code and git push again." 1>&2 exit 1 } trap 'abort' 0 set -e -source common.sh -cd .. + +cd $TRAVIS_BUILD_DIR export PATH=/usr/bin:$PATH pre-commit install clang-format --version diff --git a/paddle/scripts/travis/common.sh b/paddle/scripts/travis/common.sh deleted file mode 100755 index f05c7530a3..0000000000 --- a/paddle/scripts/travis/common.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -e -mkdir -p ../../../build -cd ../../../build -mkdir -p $HOME/third_party -EXTRA_CMAKE_OPTS="-DTHIRD_PARTY_PATH=${HOME}/third_party" diff --git a/paddle/scripts/travis/main.sh b/paddle/scripts/travis/main.sh deleted file mode 100755 index 13f2552d29..0000000000 --- a/paddle/scripts/travis/main.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -cd `dirname $0` - -if [ ${JOB} == "BUILD_AND_TEST" ]; then - ./build_and_test.sh -elif [ ${JOB} == "DOCS" ]; then - ./docs.sh -elif [ ${JOB} == "PRE_COMMIT" ]; then - ./precommit.sh -else - echo Unknown job ${JOB} - exit 1 -fi diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 29270829bb..ebe4f5cbb5 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -489,6 +489,15 @@ message EvaluatorConfig { // Used by ClassificationErrorEvaluator // top # classification error optional int32 top_k = 13 [default = 1]; + + // Used by DetectionMAPEvaluator + optional double overlap_threshold = 14 [default = 0.5]; + + optional int32 background_id = 15 [default = 0]; + + optional bool evaluate_difficult = 16 [default = false]; + + optional string ap_type = 17 [default = "11point"]; } message LinkConfig { diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index cbcd0af598..580d663246 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -25,8 +25,10 @@ enum ParameterInitStrategy { } message ParameterUpdaterHookConfig { + // hook type such as 'pruning' required string type = 1; - optional string purning_mask_filename = 2; + // this represents the ratio of zero element to be set by the Parameter + optional double sparsity_ratio = 2 [default = 0.6]; } message ParameterConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index c11dc09a8b..58e4902f57 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1280,20 +1280,23 @@ def parse_maxout(maxout, input_layer_name, maxout_conf): # Define an evaluator @config_func -def Evaluator( - name, - type, - inputs, - chunk_scheme=None, - num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - top_k=None, - delimited=None, - excluded_chunk_types=None, ): +def Evaluator(name, + type, + inputs, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + top_k=None, + delimited=None, + excluded_chunk_types=None, + overlap_threshold=None, + background_id=None, + evaluate_difficult=None, + ap_type=None): evaluator = g_config.model_config.evaluators.add() evaluator.type = type evaluator.name = MakeLayerNameInSubmodel(name) @@ -1327,6 +1330,18 @@ def Evaluator( if excluded_chunk_types: evaluator.excluded_chunk_types.extend(excluded_chunk_types) + if overlap_threshold is not None: + evaluator.overlap_threshold = overlap_threshold + + if background_id is not None: + evaluator.background_id = background_id + + if evaluate_difficult is not None: + evaluator.evaluate_difficult = evaluate_difficult + + if ap_type is not None: + evaluator.ap_type = ap_type + class LayerBase(object): def __init__( @@ -3124,11 +3139,11 @@ def Layer(name, type, **xargs): @config_func def ParameterHook(type, **kwargs): if type == 'pruning': - mask_filename = kwargs.get('mask_filename', None) - assert mask_filename is not None hook = ParameterUpdaterHookConfig() hook.type = type - hook.purning_mask_filename = mask_filename + sparsity_ratio = kwargs.get('sparsity_ratio', None) + if sparsity_ratio is not None: + hook.sparsity_ratio = sparsity_ratio return hook else: return None @@ -3236,13 +3251,13 @@ def Parameter(name, if update_hooks is not None: if hasattr(update_hooks, '__call__'): - update_hooks = update_hooks(para.name) + update_hooks = update_hooks() if isinstance(update_hooks, list): for hook in update_hooks: para.update_hooks.extend([hook]) else: - para.update_hooks.extend(update_hooks) + para.update_hooks.extend([update_hooks]) g_parameter_map[name] = para if initializer is not None: diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 4100697c9c..9b9f979bb6 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -14,7 +14,8 @@ from paddle.trainer.config_parser import * __all__ = [ - 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute' + 'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute', + 'ExtraLayerAttribute' ] @@ -55,6 +56,40 @@ def is_compatible_with(x, Type): return False +class HookAttribute(object): + """ + Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs + during training process of a layer with parameters, such as img_conv layer, fc layer. + + :param type: Hook type, currently supported types: + 'pruning' : user specify a sparsity_ratio before training started, and the + network will prune the parameters based on the sparsity_ratio. + eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6) + The specific usage can be paddle.layer.img_conv(input=img, filter_size=3, + num_channels=3, num_filters=64, + param_attr=ParameterAttribute(update_hooks=hk) ) + The pruning details can be found https://arxiv.org/pdf/1506.02626.pdf + :type type: string + + :param sparsity_ratio: Must be specified if hook type is 'pruning', + it represents the ratio of the zero elements to be set by the Parameter. + :type sparsity_ratio: float or None + + """ + + def __init__(self, type, sparsity_ratio=None): + self.type = type + self.sparsity_ratio = sparsity_ratio + if self.sparsity_ratio is not None: + assert is_compatible_with( + self.sparsity_ratio, + float), 'sparisity_ratio must be float type' + assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparsity_ratio must be a float between [0, 1] ' + + def __call__(self): + return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) + + class ParameterAttribute(object): """ Parameter Attributes object. To fine-tuning network training process, user @@ -114,6 +149,7 @@ class ParameterAttribute(object): momentum=None, gradient_clipping_threshold=None, sparse_update=False, + update_hooks=None, initializer=None): self.attr = {} @@ -169,6 +205,9 @@ class ParameterAttribute(object): if initializer is not None: self.attr['initializer'] = initializer + if update_hooks: + self.attr['update_hooks'] = update_hooks + def set_default_parameter_name(self, name): """ Set default parameter name. If parameter not set, then will use default @@ -244,5 +283,6 @@ class ExtraLayerAttribute(object): return attr.attr +HookAttr = HookAttribute ParamAttr = ParameterAttribute ExtraAttr = ExtraLayerAttribute diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index a5234f3e47..44d52edfa7 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -21,7 +21,8 @@ __all__ = [ "chunk_evaluator", "sum_evaluator", "column_sum_evaluator", "value_printer_evaluator", "gradient_printer_evaluator", "maxid_printer_evaluator", "maxframe_printer_evaluator", - "seqtext_printer_evaluator", "classification_error_printer_evaluator" + "seqtext_printer_evaluator", "classification_error_printer_evaluator", + "detection_map_evaluator" ] @@ -31,10 +32,11 @@ class EvaluatorAttribute(object): FOR_RANK = 1 << 2 FOR_PRINT = 1 << 3 FOR_UTILS = 1 << 4 + FOR_DETECTION = 1 << 5 KEYS = [ "for_classification", "for_regression", "for_rank", "for_print", - "for_utils" + "for_utils", "for_detection" ] @staticmethod @@ -57,22 +59,25 @@ def evaluator(*attrs): return impl -def evaluator_base( - input, - type, - label=None, - weight=None, - name=None, - chunk_scheme=None, - num_chunk_types=None, - classification_threshold=None, - positive_label=None, - dict_file=None, - result_file=None, - num_results=None, - delimited=None, - top_k=None, - excluded_chunk_types=None, ): +def evaluator_base(input, + type, + label=None, + weight=None, + name=None, + chunk_scheme=None, + num_chunk_types=None, + classification_threshold=None, + positive_label=None, + dict_file=None, + result_file=None, + num_results=None, + delimited=None, + top_k=None, + excluded_chunk_types=None, + overlap_threshold=None, + background_id=None, + evaluate_difficult=None, + ap_type=None): """ Evaluator will evaluate the network status while training/testing. @@ -107,6 +112,14 @@ def evaluator_base( :type weight: LayerOutput. :param top_k: number k in top-k error rate :type top_k: int + :param overlap_threshold: In detection tasks to filter detection results + :type overlap_threshold: float + :param background_id: Identifier of background class + :type background_id: int + :param evaluate_difficult: Whether to evaluate difficult objects + :type evaluate_difficult: bool + :param ap_type: How to calculate average persicion + :type ap_type: str """ # inputs type assertions. assert classification_threshold is None or isinstance( @@ -136,7 +149,61 @@ def evaluator_base( delimited=delimited, num_results=num_results, top_k=top_k, - excluded_chunk_types=excluded_chunk_types, ) + excluded_chunk_types=excluded_chunk_types, + overlap_threshold=overlap_threshold, + background_id=background_id, + evaluate_difficult=evaluate_difficult, + ap_type=ap_type) + + +@evaluator(EvaluatorAttribute.FOR_DETECTION) +@wrap_name_default() +def detection_map_evaluator(input, + label, + overlap_threshold=0.5, + background_id=0, + evaluate_difficult=False, + ap_type="11point", + name=None): + """ + Detection mAP Evaluator. It will print mean Average Precision (mAP) for detection. + + The detection mAP Evaluator based on the output of detection_output layer counts + the true positive and the false positive bbox and integral them to get the + mAP. + + The simple usage is: + + .. code-block:: python + + eval = detection_map_evaluator(input=det_output,label=lbl) + + :param input: Input layer. + :type input: LayerOutput + :param label: Label layer. + :type label: LayerOutput + :param overlap_threshold: The bbox overlap threshold of a true positive. + :type overlap_threshold: float + :param background_id: The background class index. + :type background_id: int + :param evaluate_difficult: Whether evaluate a difficult ground truth. + :type evaluate_difficult: bool + """ + if not isinstance(input, list): + input = [input] + + if label: + input.append(label) + + evaluator_base( + name=name, + type="detection_map", + input=input, + label=label, + overlap_threshold=overlap_threshold, + background_id=background_id, + evaluate_difficult=evaluate_difficult, + ap_type=ap_type) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b8ce0373c0..84ed160773 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -3839,7 +3839,8 @@ def classification_cost(input, weight=None, name=None, evaluator=classification_error_evaluator, - layer_attr=None): + layer_attr=None, + coeff=1.): """ classification cost Layer. @@ -3855,6 +3856,8 @@ def classification_cost(input, :param evaluator: Evaluator method. :param layer_attr: layer's extra attribute. :type layer_attr: ExtraLayerAttribute + :param coeff: The coefficient affects the gradient in the backward. + :type coeff: float :return: LayerOutput object. :rtype: LayerOutput """ @@ -3868,6 +3871,7 @@ def classification_cost(input, name=name, type="multi-class-cross-entropy", inputs=ipts, + coeff=coeff, **ExtraLayerAttribute.to_kwargs(layer_attr)) def __add_evaluator__(e): diff --git a/python/paddle/v2/attr.py b/python/paddle/v2/attr.py index 32f78614e7..5d23894d73 100644 --- a/python/paddle/v2/attr.py +++ b/python/paddle/v2/attr.py @@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs __all__ = [ "Param", "Extra", + "Hook", ] Param = paddle.trainer_config_helpers.attrs.ParameterAttribute Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute +Hook = paddle.trainer_config_helpers.attrs.HookAttribute for each in paddle.trainer_config_helpers.attrs.__all__: globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each)