Merge pull request #2172 from helinwang/cgo
parameter client library: stub and cgo part with functional test.refactor_docs
commit
501b59af69
@ -0,0 +1,34 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
|
||||
message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})")
|
||||
else()
|
||||
# find #include <majel/xx.h>
|
||||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
|
||||
include_directories(${PARENT_DIR})
|
||||
|
||||
# find cmake directory modules
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
|
||||
|
||||
# enable c++11
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
# enable gtest
|
||||
set(THIRD_PARTY_PATH ./third_party)
|
||||
set(WITH_TESTING ON)
|
||||
include(external/gtest)
|
||||
endif()
|
||||
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
|
||||
project(cxx_go CXX C Go)
|
||||
|
||||
include(cmake/golang.cmake)
|
||||
include(cmake/flags.cmake)
|
||||
|
||||
ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver)
|
||||
add_go_library(client STATIC pserver)
|
||||
add_subdirectory(test)
|
@ -0,0 +1,239 @@
|
||||
package main
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
typedef enum {
|
||||
PADDLE_ELEMENT_TYPE_INT32 = 0,
|
||||
PADDLE_ELEMENT_TYPE_UINT32 = 1,
|
||||
PADDLE_ELEMENT_TYPE_INT64 = 2,
|
||||
PADDLE_ELEMENT_TYPE_UINT64 = 3,
|
||||
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
|
||||
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
|
||||
} paddle_element_type;
|
||||
|
||||
typedef struct {
|
||||
char* name;
|
||||
paddle_element_type element_type;
|
||||
unsigned char* content;
|
||||
int content_len;
|
||||
} paddle_parameter, paddle_gradient;
|
||||
|
||||
static inline void paddle_release_param(paddle_parameter* param) {
|
||||
if (param != NULL) {
|
||||
if (param->name != NULL) {
|
||||
free(param->name);
|
||||
}
|
||||
|
||||
if (param->content != NULL) {
|
||||
free(param->content);
|
||||
}
|
||||
|
||||
free(param);
|
||||
}
|
||||
}
|
||||
|
||||
typedef int client;
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
|
||||
)
|
||||
|
||||
var nullPtr = unsafe.Pointer(uintptr(0))
|
||||
var mu sync.Mutex
|
||||
var handleMap = make(map[C.client]*pserver.Client)
|
||||
var curHandle C.client
|
||||
|
||||
func add(c *pserver.Client) C.client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
client := curHandle
|
||||
curHandle++
|
||||
handleMap[client] = c
|
||||
return client
|
||||
}
|
||||
|
||||
func get(client C.client) *pserver.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return handleMap[client]
|
||||
}
|
||||
|
||||
func remove(client C.client) *pserver.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
h := handleMap[client]
|
||||
delete(handleMap, client)
|
||||
return h
|
||||
}
|
||||
|
||||
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
|
||||
if p == nullPtr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// create a Go clice backed by a C array,
|
||||
// reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
return (*[1 << 30]byte)(p)[:len:len]
|
||||
}
|
||||
|
||||
//export paddle_new_pserver_client
|
||||
func paddle_new_pserver_client(addr *C.char) C.client {
|
||||
c := pserver.NewClient(C.GoString(addr))
|
||||
return add(c)
|
||||
}
|
||||
|
||||
//export paddle_pserver_client_release
|
||||
func paddle_pserver_client_release(client C.client) {
|
||||
c := remove(client)
|
||||
c.Cleanup()
|
||||
}
|
||||
|
||||
//export paddle_begin_init_params
|
||||
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int {
|
||||
c := get(client)
|
||||
b := cArrayToSlice(pserver_config, int(config_len))
|
||||
selected, err := c.BeginInitParams(b)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return -1
|
||||
}
|
||||
|
||||
if selected {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
//export paddle_init_param
|
||||
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
|
||||
et := pserver.ElementType(param.element_type)
|
||||
name := C.GoString(param.name)
|
||||
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
|
||||
pc := pserver.ParameterWithConfig{
|
||||
Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
|
||||
Config: cArrayToSlice(param_config, int(config_len)),
|
||||
}
|
||||
c := get(client)
|
||||
err := c.InitParam(pc)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return -1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export paddle_finish_init_params
|
||||
func paddle_finish_init_params(client C.client) C.int {
|
||||
c := get(client)
|
||||
err := c.FinishInitParams()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return -1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export paddle_send_grads
|
||||
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
|
||||
var gs []pserver.Gradient
|
||||
for i := 0; i < int(total); i++ {
|
||||
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
|
||||
et := pserver.ElementType(grad.element_type)
|
||||
name := C.GoString(grad.name)
|
||||
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
|
||||
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
|
||||
}
|
||||
|
||||
c := get(client)
|
||||
err := c.SendGrads(gs)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return -1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export paddle_get_params
|
||||
func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int {
|
||||
var ns []string
|
||||
for i := 0; i < int(total); i++ {
|
||||
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
|
||||
ns = append(ns, C.GoString(name))
|
||||
}
|
||||
c := get(client)
|
||||
ps, err := c.GetParams(ns)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return -1
|
||||
}
|
||||
|
||||
for i := 0; i < int(total); i++ {
|
||||
if i >= len(ps) {
|
||||
break
|
||||
}
|
||||
|
||||
p := ps[i]
|
||||
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
|
||||
nameReady := false
|
||||
contentAllocated := false
|
||||
|
||||
if unsafe.Pointer(param) == nullPtr {
|
||||
param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param))))
|
||||
} else {
|
||||
if unsafe.Pointer(param.name) != nullPtr {
|
||||
if n := C.GoString(param.name); n != p.Name {
|
||||
log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name)
|
||||
C.free(unsafe.Pointer(param.name))
|
||||
} else {
|
||||
nameReady = true
|
||||
}
|
||||
}
|
||||
|
||||
if unsafe.Pointer(param.content) != nullPtr {
|
||||
if int(param.content_len) == len(p.Content) {
|
||||
contentAllocated = true
|
||||
} else {
|
||||
log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content))
|
||||
C.free(unsafe.Pointer(param.content))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !nameReady {
|
||||
param.name = C.CString(p.Name)
|
||||
}
|
||||
if !contentAllocated {
|
||||
param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content))))
|
||||
}
|
||||
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
|
||||
param.content_len = C.int(len(p.Content))
|
||||
param.element_type = C.paddle_element_type(p.ElementType)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export paddle_save_model
|
||||
func paddle_save_model(client C.client, path *C.char) C.int {
|
||||
p := C.GoString(path)
|
||||
c := get(client)
|
||||
err := c.SaveModel(p)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return -1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func main() {} // Required but ignored
|
@ -0,0 +1,44 @@
|
||||
if(NOT CMAKE_Go_COMPILER)
|
||||
if(NOT $ENV{GO_COMPILER} STREQUAL "")
|
||||
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)
|
||||
|
||||
if(CMAKE_Go_FLAGS_ENV_INIT)
|
||||
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
|
||||
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
set(Go_BIN_PATH
|
||||
$ENV{GOPATH}
|
||||
$ENV{GOROOT}
|
||||
$ENV{GOROOT}/../bin
|
||||
$ENV{GO_COMPILER}
|
||||
/usr/bin
|
||||
/usr/local/bin
|
||||
)
|
||||
|
||||
if(CMAKE_Go_COMPILER_INIT)
|
||||
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
|
||||
else()
|
||||
find_program(CMAKE_Go_COMPILER
|
||||
NAMES go
|
||||
PATHS ${Go_BIN_PATH}
|
||||
)
|
||||
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
|
||||
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
|
||||
message("-- The Golang compiler identification is ${VERSION}")
|
||||
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
mark_as_advanced(CMAKE_Go_COMPILER)
|
||||
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in
|
||||
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
|
||||
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -0,0 +1,8 @@
|
||||
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
|
||||
set(CMAKE_Go_COMPILER_LOADED 1)
|
||||
|
||||
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
|
||||
set(CMAKE_Go_LINKER_PREFERENCE 40)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION .o)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -0,0 +1,7 @@
|
||||
if(NOT CMAKE_Go_COMPILE_OBJECT)
|
||||
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_Go_LINK_EXECUTABLE)
|
||||
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
|
||||
endif()
|
@ -0,0 +1 @@
|
||||
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
|
@ -0,0 +1,45 @@
|
||||
# Setting Paddle Compile Flags
|
||||
include(CheckCXXCompilerFlag)
|
||||
include(CheckCCompilerFlag)
|
||||
include(CheckCXXSymbolExists)
|
||||
include(CheckTypeSize)
|
||||
|
||||
function(CheckCompilerCXX11Flag)
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
|
||||
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
|
||||
endif()
|
||||
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
|
||||
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
|
||||
# https://gist.github.com/yamaya/2924292
|
||||
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
|
||||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
|
||||
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
|
||||
endif()
|
||||
else()
|
||||
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
|
||||
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
CheckCompilerCXX11Flag()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
# Common gpu architectures: Kepler, Maxwell
|
||||
foreach(capability 30 35 50)
|
||||
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
|
||||
endforeach()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
|
||||
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
|
||||
endif()
|
||||
|
||||
# Modern gpu architectures: Pascal
|
||||
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
|
||||
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
|
||||
endif()
|
||||
|
||||
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
|
@ -0,0 +1,46 @@
|
||||
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
|
||||
file(MAKE_DIRECTORY ${GOPATH})
|
||||
|
||||
function(ExternalGoProject_Add TARG)
|
||||
add_custom_target(${TARG} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN})
|
||||
endfunction(ExternalGoProject_Add)
|
||||
|
||||
function(add_go_executable NAME)
|
||||
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
|
||||
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
|
||||
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
|
||||
-o "${CMAKE_CURRENT_BINARY_DIR}/${NAME}"
|
||||
${CMAKE_GO_FLAGS} ${GO_SOURCE}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
|
||||
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${NAME} DESTINATION bin)
|
||||
endfunction(add_go_executable)
|
||||
|
||||
|
||||
function(ADD_GO_LIBRARY NAME BUILD_TYPE)
|
||||
if(BUILD_TYPE STREQUAL "STATIC")
|
||||
set(BUILD_MODE -buildmode=c-archive)
|
||||
set(LIB_NAME "lib${NAME}.a")
|
||||
else()
|
||||
set(BUILD_MODE -buildmode=c-shared)
|
||||
if(APPLE)
|
||||
set(LIB_NAME "lib${NAME}.dylib")
|
||||
else()
|
||||
set(LIB_NAME "lib${NAME}.so")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
|
||||
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
|
||||
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
|
||||
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
|
||||
${CMAKE_GO_FLAGS} ${GO_SOURCE}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
|
||||
|
||||
if(NOT BUILD_TYPE STREQUAL "STATIC")
|
||||
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
|
||||
endif()
|
||||
endfunction(ADD_GO_LIBRARY)
|
@ -0,0 +1,8 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
include_directories(/env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/cclient/build/)
|
||||
|
||||
add_executable(main main.c)
|
||||
add_dependencies(main client)
|
||||
set (CMAKE_EXE_LINKER_FLAGS "-pthread")
|
||||
target_link_libraries(main /env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/cclient/build/libclient.a) # ${GTEST_LIBRARIES})
|
@ -0,0 +1,69 @@
|
||||
#include "libclient.h"
|
||||
|
||||
//#include "gtest/gtest.h"
|
||||
|
||||
void panic() {
|
||||
// TODO(helin): fix: gtest using cmake is not working, using this
|
||||
// hacky way for now.
|
||||
*(void*)0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
char addr[] = "localhost:3000";
|
||||
client c = paddle_new_pserver_client(addr);
|
||||
retry:
|
||||
if (paddle_begin_init_params(c, NULL, 0)) {
|
||||
paddle_parameter param;
|
||||
char name_a[] = "param_a";
|
||||
char name_b[] = "param_b";
|
||||
char content[] = {0x00, 0x11, 0x22};
|
||||
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
|
||||
param.name = name_a;
|
||||
param.content = content;
|
||||
param.content_len = 3;
|
||||
if (paddle_init_param(c, param, NULL, 0) != 0) {
|
||||
goto retry;
|
||||
}
|
||||
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
|
||||
param.name = name_b;
|
||||
param.content = content;
|
||||
param.content_len = 3;
|
||||
if (paddle_init_param(c, param, NULL, 0) != 0) {
|
||||
goto retry;
|
||||
}
|
||||
if (paddle_finish_init_params(c) != 0) {
|
||||
goto retry;
|
||||
}
|
||||
} else {
|
||||
panic();
|
||||
}
|
||||
|
||||
char content[] = {0x00, 0x11, 0x22};
|
||||
paddle_gradient grads[2] = {
|
||||
{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3},
|
||||
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
|
||||
|
||||
if (!paddle_send_grads(c, grads, 2)) {
|
||||
panic();
|
||||
}
|
||||
|
||||
paddle_parameter* params[2] = {NULL, NULL};
|
||||
char* names[] = {"param_a", "param_b"};
|
||||
if (!paddle_get_params(c, names, params, 2)) {
|
||||
panic();
|
||||
}
|
||||
|
||||
// get parameters again by reusing the allocated parameter buffers.
|
||||
if (!paddle_get_params(c, names, params, 2)) {
|
||||
panic();
|
||||
}
|
||||
|
||||
paddle_release_param(params[0]);
|
||||
paddle_release_param(params[1]);
|
||||
|
||||
if (!paddle_save_model(c, "/tmp/")) {
|
||||
panic();
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,83 @@
|
||||
package pserver
|
||||
|
||||
// ElementType is the type of elements of a Parameter.
|
||||
type ElementType int
|
||||
|
||||
// Supported element types
|
||||
const (
|
||||
Int32 ElementType = iota
|
||||
UInt32
|
||||
Int64
|
||||
UInt64
|
||||
Float32
|
||||
Float64
|
||||
)
|
||||
|
||||
// Parameter is a piece of data to sync with the parameter server.
|
||||
type Parameter struct {
|
||||
Name string
|
||||
ElementType ElementType
|
||||
Content []byte
|
||||
}
|
||||
|
||||
// ParameterWithConfig contains the parameter and the configuration.
|
||||
type ParameterWithConfig struct {
|
||||
Param Parameter
|
||||
Config []byte // parameter configuration in Proto Buffer format
|
||||
}
|
||||
|
||||
// Gradient is the gradient of the parameter.
|
||||
type Gradient Parameter
|
||||
|
||||
// Client is the client to parameter servers.
|
||||
type Client struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new client.
|
||||
func NewClient(addr string) *Client {
|
||||
return &Client{}
|
||||
}
|
||||
|
||||
// BeginInitParams begins to initialize parameters on parameter
|
||||
// servers.
|
||||
//
|
||||
// BeginInitParams will be called from multiple trainers, only one
|
||||
// trainer will be selected to initialize the parameters on parameter
|
||||
// servers. Other trainers will be blocked until the initialization is
|
||||
// done, and they need to get the initialized parameters from
|
||||
// parameter servers using GetParams.
|
||||
func (c *Client) BeginInitParams(pserverConfigProto []byte) (selected bool, err error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// InitParam initializes the parameter on parameter servers.
|
||||
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinishInitParams tells parameter servers client has sent all
|
||||
// parameters to parameter servers as initialization.
|
||||
func (c *Client) FinishInitParams() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendGrads sends gradients to parameter servers for updating
|
||||
// parameters.
|
||||
func (c *Client) SendGrads(grads []Gradient) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetParams gets parameters from parameter servers.
|
||||
func (c *Client) GetParams(names []string) ([]Parameter, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SaveModel indicates parameters to save the parameter to the given
|
||||
// path.
|
||||
func (c *Client) SaveModel(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup cleans up the client states.
|
||||
func (c *Client) Cleanup() {
|
||||
}
|
Loading…
Reference in new issue