commit
91c6a7926e
@ -1,5 +1,5 @@
|
||||
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
||||
go_library(paddle_pserver_cclient STATIC)
|
||||
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
|
||||
if(WITH_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
@ -0,0 +1,125 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/pserver"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultEtcdTimeout is the timeout NewEtcd uses for dialing etcd, for the
// pause between failed connection attempts, and as the timeout stored on the
// EtcdClient it returns.
const DefaultEtcdTimeout = 5 * time.Second
|
||||
|
||||
// EtcdClient is used by pserver client that is a part of trainer process.
|
||||
// TODO:
|
||||
// 1. add watcher to watch the change state of pservers)
|
||||
// 1. add etcd lock)
|
||||
type EtcdClient struct {
|
||||
client *clientv3.Client
|
||||
timeout time.Duration
|
||||
endpoints []string
|
||||
}
|
||||
|
||||
// Desired read ps desired number from etcd.
|
||||
func (p *EtcdClient) Desired() int {
|
||||
var psDesired int
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
resp, err := p.client.Get(ctx, pserver.PsDesired)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
|
||||
time.Sleep(p.timeout)
|
||||
continue
|
||||
}
|
||||
|
||||
kvs := resp.Kvs
|
||||
if len(kvs) == 0 {
|
||||
log.Infoln("Waiting for ps desired registered ...")
|
||||
time.Sleep(p.timeout)
|
||||
continue
|
||||
}
|
||||
|
||||
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
|
||||
if err != nil {
|
||||
log.Errorf("psDesired %s invalid %v", psDesired, err)
|
||||
time.Sleep(p.timeout)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Get psDesired number: %d", psDesired)
|
||||
break
|
||||
}
|
||||
return psDesired
|
||||
}
|
||||
|
||||
// List return the pserver list read from etcd.
|
||||
func (p *EtcdClient) List() []Server {
|
||||
psDesired := p.Desired()
|
||||
|
||||
servers := make([]Server, psDesired)
|
||||
for {
|
||||
for i := 0; i < psDesired; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||
cancel()
|
||||
psKey := pserver.PsPath + strconv.Itoa(i)
|
||||
log.Debugf("checking %s", psKey)
|
||||
resp, err := p.client.Get(ctx, psKey)
|
||||
if err != nil {
|
||||
log.Infof("Get psKey= %s error, %v", psKey, err)
|
||||
time.Sleep(p.timeout)
|
||||
continue
|
||||
}
|
||||
kvs := resp.Kvs
|
||||
if len(kvs) == 0 {
|
||||
log.Infof("Waiting for ps addr registered ...")
|
||||
time.Sleep(p.timeout)
|
||||
continue
|
||||
}
|
||||
|
||||
psAddr := string(resp.Kvs[0].Value)
|
||||
// TODO(Longfei) check the ps address
|
||||
if psAddr == "" {
|
||||
log.Infof("Get psKey = %s, psAddr is empty", psKey)
|
||||
time.Sleep(p.timeout)
|
||||
continue
|
||||
}
|
||||
log.Infof("got value (%s) for key: %s", psAddr, psKey)
|
||||
servers[i].Index = i
|
||||
servers[i].Addr = psAddr
|
||||
}
|
||||
break
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
// NewEtcd create a etcd client to return the state of pserver on etcd.
|
||||
func NewEtcd(endpoints string) *EtcdClient {
|
||||
ep := strings.Split(endpoints, ",")
|
||||
var cli *clientv3.Client
|
||||
var err error
|
||||
for {
|
||||
cli, err = clientv3.New(clientv3.Config{
|
||||
Endpoints: ep,
|
||||
DialTimeout: DefaultEtcdTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("Init etcd connection failed: %v", err)
|
||||
time.Sleep(DefaultEtcdTimeout)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
log.Infof("Connected to etcd: %s\n", endpoints)
|
||||
client := &EtcdClient{
|
||||
client: cli,
|
||||
timeout: DefaultEtcdTimeout,
|
||||
endpoints: ep,
|
||||
}
|
||||
return client
|
||||
}
|
@ -0,0 +1 @@
|
||||
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
|
@ -0,0 +1,104 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include "paddle/platform/dynamic_loader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
namespace dynload {
|
||||
|
||||
std::once_flag cublas_dso_flag;
|
||||
void *cublas_dso_handle = nullptr;
|
||||
|
||||
/**
|
||||
* The following macro definition can generate structs
|
||||
* (for each function) to dynamic load cublas routine
|
||||
* via operator overloading.
|
||||
*
|
||||
* note: default dynamic linked libs
|
||||
*/
|
||||
#ifdef PADDLE_USE_DSO
|
||||
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
||||
struct DynLoad__##__name { \
|
||||
template <typename... Args> \
|
||||
cublasStatus_t operator()(Args... args) { \
|
||||
typedef cublasStatus_t (*cublasFunc)(Args...); \
|
||||
std::call_once(cublas_dso_flag, \
|
||||
paddle::platform::dynload::GetCublasDsoHandle, \
|
||||
&cublas_dso_handle); \
|
||||
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
|
||||
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
|
||||
} \
|
||||
} __name; // struct DynLoad__##__name
|
||||
#else
|
||||
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
||||
struct DynLoad__##__name { \
|
||||
template <typename... Args> \
|
||||
cublasStatus_t operator()(Args... args) { \
|
||||
return __name(args...); \
|
||||
} \
|
||||
} __name; // struct DynLoad__##__name
|
||||
#endif
|
||||
|
||||
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
|
||||
|
||||
// include all needed cublas functions in HPPL
|
||||
// clang-format off
|
||||
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
|
||||
__macro(cublasSgemv) \
|
||||
__macro(cublasDgemv) \
|
||||
__macro(cublasSgemm) \
|
||||
__macro(cublasDgemm) \
|
||||
__macro(cublasSgeam) \
|
||||
__macro(cublasDgeam) \
|
||||
|
||||
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
|
||||
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
|
||||
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
|
||||
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
|
||||
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
|
||||
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
|
||||
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
|
||||
|
||||
#undef DYNAMIC_LOAD_CUBLAS_WRAP
|
||||
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
|
||||
#undef CUBLAS_BLAS_ROUTINE_EACH
|
||||
|
||||
// clang-format on
|
||||
#ifndef PADDLE_TYPE_DOUBLE
|
||||
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
|
||||
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
|
||||
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
|
||||
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
|
||||
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
|
||||
#else
|
||||
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
|
||||
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
|
||||
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
|
||||
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
|
||||
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
|
||||
#endif
|
||||
} // namespace dynload
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue