commit
91c6a7926e
@ -1,5 +1,5 @@
|
|||||||
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
|
||||||
go_library(paddle_pserver_cclient STATIC)
|
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
|
||||||
if(WITH_TESTING)
|
if(WITH_TESTING)
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
endif()
|
endif()
|
@ -0,0 +1,125 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/Paddle/go/pserver"
|
||||||
|
"github.com/coreos/etcd/clientv3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultEtcdTimeout time.Duration = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// EtcdClient is used by pserver client that is a part of trainer process.
|
||||||
|
// TODO:
|
||||||
|
// 1. add watcher to watch the change state of pservers)
|
||||||
|
// 1. add etcd lock)
|
||||||
|
type EtcdClient struct {
|
||||||
|
client *clientv3.Client
|
||||||
|
timeout time.Duration
|
||||||
|
endpoints []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Desired read ps desired number from etcd.
|
||||||
|
func (p *EtcdClient) Desired() int {
|
||||||
|
var psDesired int
|
||||||
|
for {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||||
|
resp, err := p.client.Get(ctx, pserver.PsDesired)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
|
||||||
|
time.Sleep(p.timeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
kvs := resp.Kvs
|
||||||
|
if len(kvs) == 0 {
|
||||||
|
log.Infoln("Waiting for ps desired registered ...")
|
||||||
|
time.Sleep(p.timeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("psDesired %s invalid %v", psDesired, err)
|
||||||
|
time.Sleep(p.timeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Get psDesired number: %d", psDesired)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return psDesired
|
||||||
|
}
|
||||||
|
|
||||||
|
// List return the pserver list read from etcd.
|
||||||
|
func (p *EtcdClient) List() []Server {
|
||||||
|
psDesired := p.Desired()
|
||||||
|
|
||||||
|
servers := make([]Server, psDesired)
|
||||||
|
for {
|
||||||
|
for i := 0; i < psDesired; i++ {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
|
||||||
|
cancel()
|
||||||
|
psKey := pserver.PsPath + strconv.Itoa(i)
|
||||||
|
log.Debugf("checking %s", psKey)
|
||||||
|
resp, err := p.client.Get(ctx, psKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("Get psKey= %s error, %v", psKey, err)
|
||||||
|
time.Sleep(p.timeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kvs := resp.Kvs
|
||||||
|
if len(kvs) == 0 {
|
||||||
|
log.Infof("Waiting for ps addr registered ...")
|
||||||
|
time.Sleep(p.timeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
psAddr := string(resp.Kvs[0].Value)
|
||||||
|
// TODO(Longfei) check the ps address
|
||||||
|
if psAddr == "" {
|
||||||
|
log.Infof("Get psKey = %s, psAddr is empty", psKey)
|
||||||
|
time.Sleep(p.timeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("got value (%s) for key: %s", psAddr, psKey)
|
||||||
|
servers[i].Index = i
|
||||||
|
servers[i].Addr = psAddr
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return servers
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEtcd create a etcd client to return the state of pserver on etcd.
|
||||||
|
func NewEtcd(endpoints string) *EtcdClient {
|
||||||
|
ep := strings.Split(endpoints, ",")
|
||||||
|
var cli *clientv3.Client
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
cli, err = clientv3.New(clientv3.Config{
|
||||||
|
Endpoints: ep,
|
||||||
|
DialTimeout: DefaultEtcdTimeout,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Init etcd connection failed: %v", err)
|
||||||
|
time.Sleep(DefaultEtcdTimeout)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Infof("Connected to etcd: %s\n", endpoints)
|
||||||
|
client := &EtcdClient{
|
||||||
|
client: cli,
|
||||||
|
timeout: DefaultEtcdTimeout,
|
||||||
|
endpoints: ep,
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
|
@ -0,0 +1,104 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include "paddle/platform/dynamic_loader.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace platform {
|
||||||
|
namespace dynload {
|
||||||
|
|
||||||
|
std::once_flag cublas_dso_flag;
|
||||||
|
void *cublas_dso_handle = nullptr;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The following macro definition can generate structs
|
||||||
|
* (for each function) to dynamic load cublas routine
|
||||||
|
* via operator overloading.
|
||||||
|
*
|
||||||
|
* note: default dynamic linked libs
|
||||||
|
*/
|
||||||
|
#ifdef PADDLE_USE_DSO
|
||||||
|
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
||||||
|
struct DynLoad__##__name { \
|
||||||
|
template <typename... Args> \
|
||||||
|
cublasStatus_t operator()(Args... args) { \
|
||||||
|
typedef cublasStatus_t (*cublasFunc)(Args...); \
|
||||||
|
std::call_once(cublas_dso_flag, \
|
||||||
|
paddle::platform::dynload::GetCublasDsoHandle, \
|
||||||
|
&cublas_dso_handle); \
|
||||||
|
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
|
||||||
|
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
|
||||||
|
} \
|
||||||
|
} __name; // struct DynLoad__##__name
|
||||||
|
#else
|
||||||
|
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
||||||
|
struct DynLoad__##__name { \
|
||||||
|
template <typename... Args> \
|
||||||
|
cublasStatus_t operator()(Args... args) { \
|
||||||
|
return __name(args...); \
|
||||||
|
} \
|
||||||
|
} __name; // struct DynLoad__##__name
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
|
||||||
|
|
||||||
|
// include all needed cublas functions in HPPL
|
||||||
|
// clang-format off
|
||||||
|
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
|
||||||
|
__macro(cublasSgemv) \
|
||||||
|
__macro(cublasDgemv) \
|
||||||
|
__macro(cublasSgemm) \
|
||||||
|
__macro(cublasDgemm) \
|
||||||
|
__macro(cublasSgeam) \
|
||||||
|
__macro(cublasDgeam) \
|
||||||
|
|
||||||
|
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
|
||||||
|
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
|
||||||
|
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
|
||||||
|
|
||||||
|
#undef DYNAMIC_LOAD_CUBLAS_WRAP
|
||||||
|
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
|
||||||
|
#undef CUBLAS_BLAS_ROUTINE_EACH
|
||||||
|
|
||||||
|
// clang-format on
|
||||||
|
#ifndef PADDLE_TYPE_DOUBLE
|
||||||
|
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
|
||||||
|
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
|
||||||
|
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
|
||||||
|
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
|
||||||
|
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
|
||||||
|
#else
|
||||||
|
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
|
||||||
|
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
|
||||||
|
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
|
||||||
|
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
|
||||||
|
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
|
||||||
|
#endif
|
||||||
|
} // namespace dynload
|
||||||
|
} // namespace platform
|
||||||
|
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue