You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							274 lines
						
					
					
						
							7.3 KiB
						
					
					
				
			
		
		
	
	
							274 lines
						
					
					
						
							7.3 KiB
						
					
					
				// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
						|
 | 
						|
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
// you may not use this file except in compliance with the License.
 | 
						|
// You may obtain a copy of the License at
 | 
						|
 | 
						|
// http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
// Unless required by applicable law or agreed to in writing, software
 | 
						|
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
// See the License for the specific language governing permissions and
 | 
						|
// limitations under the License.
 | 
						|
 | 
						|
package main
 | 
						|
 | 
						|
/*
 | 
						|
#include <string.h>
 | 
						|
typedef enum {
 | 
						|
  PADDLE_ELEMENT_TYPE_INT32   = 0,
 | 
						|
  PADDLE_ELEMENT_TYPE_UINT32  = 1,
 | 
						|
  PADDLE_ELEMENT_TYPE_INT64   = 2,
 | 
						|
  PADDLE_ELEMENT_TYPE_UINT64  = 3,
 | 
						|
  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
 | 
						|
  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
 | 
						|
} paddle_element_type;
 | 
						|
 | 
						|
typedef struct {
 | 
						|
  char*               name;
 | 
						|
  paddle_element_type element_type;
 | 
						|
  unsigned char*      content;
 | 
						|
  int                 content_len;
 | 
						|
} paddle_parameter, paddle_gradient;
 | 
						|
 | 
						|
typedef int paddle_pserver_client;
 | 
						|
#define PSERVER_ERROR -1
 | 
						|
#define PSERVER_OK 0
 | 
						|
*/
 | 
						|
import "C"
 | 
						|
 | 
						|
import (
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"unsafe"
 | 
						|
 | 
						|
	"github.com/PaddlePaddle/Paddle/go/pserver"
 | 
						|
	"github.com/PaddlePaddle/Paddle/go/pserver/client"
 | 
						|
	log "github.com/sirupsen/logrus"
 | 
						|
)
 | 
						|
 | 
						|
var mu sync.Mutex
 | 
						|
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
 | 
						|
var curHandle C.paddle_pserver_client
 | 
						|
 | 
						|
func add(c *client.Client) C.paddle_pserver_client {
 | 
						|
	mu.Lock()
 | 
						|
	defer mu.Unlock()
 | 
						|
	cli := curHandle
 | 
						|
	curHandle++
 | 
						|
	handleMap[cli] = c
 | 
						|
	return cli
 | 
						|
}
 | 
						|
 | 
						|
func get(client C.paddle_pserver_client) *client.Client {
 | 
						|
	mu.Lock()
 | 
						|
	defer mu.Unlock()
 | 
						|
	return handleMap[client]
 | 
						|
}
 | 
						|
 | 
						|
func remove(client C.paddle_pserver_client) *client.Client {
 | 
						|
	mu.Lock()
 | 
						|
	defer mu.Unlock()
 | 
						|
	h := handleMap[client]
 | 
						|
	delete(handleMap, client)
 | 
						|
	return h
 | 
						|
}
 | 
						|
 | 
						|
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
 | 
						|
	if p == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	// create a Go clice backed by a C array, reference:
 | 
						|
	// https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
 | 
						|
	//
 | 
						|
	// Go garbage collector will not interact with this data, need
 | 
						|
	// to be freed properly.
 | 
						|
	return (*[1 << 30]byte)(p)[:len:len]
 | 
						|
}
 | 
						|
 | 
						|
type selector bool
 | 
						|
 | 
						|
func (s selector) Select() (bool, error) {
 | 
						|
	return bool(s), nil
 | 
						|
}
 | 
						|
 | 
						|
func (s selector) Done() error {
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
type lister []client.Server
 | 
						|
 | 
						|
func (l lister) List() []client.Server {
 | 
						|
	return l
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_new_pserver_client
 | 
						|
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
 | 
						|
	a := C.GoString(addrs)
 | 
						|
	as := strings.Split(a, ",")
 | 
						|
	servers := make([]client.Server, len(as))
 | 
						|
	for i := range as {
 | 
						|
		servers[i].Index = i
 | 
						|
		servers[i].Addr = as[i]
 | 
						|
	}
 | 
						|
	c := client.NewClient(lister(servers), len(as), selector(selected != 0))
 | 
						|
	return add(c)
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_new_etcd_pserver_client
 | 
						|
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
 | 
						|
	addr := C.GoString(etcdEndpoints)
 | 
						|
	etcdClient := client.NewEtcd(addr)
 | 
						|
	c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
 | 
						|
	return add(c)
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_pserver_client_release
 | 
						|
func paddle_pserver_client_release(client C.paddle_pserver_client) {
 | 
						|
	remove(client)
 | 
						|
}
 | 
						|
 | 
						|
// paddle_begin_init_params tells trainer if it needs to init the
 | 
						|
// parameters.
 | 
						|
//
 | 
						|
// returns 1 if the trainer needs to init the parameters. 0 if the
 | 
						|
// trainer does not need to init the parameters.
 | 
						|
//
 | 
						|
//export paddle_begin_init_params
 | 
						|
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
 | 
						|
	c := get(client)
 | 
						|
	selected, err := c.BeginInitParams()
 | 
						|
	if err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
 | 
						|
	if selected {
 | 
						|
		return 1
 | 
						|
	}
 | 
						|
	return 0
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_init_param
 | 
						|
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, paramConfig unsafe.Pointer, configLen C.int) C.int {
 | 
						|
	et := pserver.ElementType(param.element_type)
 | 
						|
	name := C.GoString(param.name)
 | 
						|
	content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
 | 
						|
	pc := pserver.ParameterWithConfig{
 | 
						|
		Param:  pserver.Parameter{Name: name, ElementType: et, Content: content},
 | 
						|
		Config: cArrayToSlice(paramConfig, int(configLen)),
 | 
						|
	}
 | 
						|
	c := get(client)
 | 
						|
	err := c.InitParam(pc)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		if err.Error() == pserver.AlreadyInitialized {
 | 
						|
			log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name)
 | 
						|
			return C.PSERVER_OK
 | 
						|
		}
 | 
						|
		log.Errorln(err)
 | 
						|
		return C.PSERVER_ERROR
 | 
						|
	}
 | 
						|
 | 
						|
	return C.PSERVER_OK
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_finish_init_params
 | 
						|
func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
 | 
						|
	c := get(client)
 | 
						|
	err := c.FinishInitParams()
 | 
						|
	if err != nil {
 | 
						|
		if err.Error() == pserver.AlreadyInitialized {
 | 
						|
			log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.")
 | 
						|
			return C.PSERVER_OK
 | 
						|
		}
 | 
						|
 | 
						|
		log.Errorln(err)
 | 
						|
		return C.PSERVER_ERROR
 | 
						|
	}
 | 
						|
 | 
						|
	return C.PSERVER_OK
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_send_grads
 | 
						|
func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient, total C.int) C.int {
 | 
						|
	var gs []pserver.Gradient
 | 
						|
	for i := 0; i < int(total); i++ {
 | 
						|
		grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
 | 
						|
		et := pserver.ElementType(grad.element_type)
 | 
						|
		name := C.GoString(grad.name)
 | 
						|
		content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
 | 
						|
		gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
 | 
						|
	}
 | 
						|
 | 
						|
	c := get(client)
 | 
						|
	err := c.SendGrads(gs)
 | 
						|
	if err != nil {
 | 
						|
		log.Errorln(err)
 | 
						|
		return C.PSERVER_ERROR
 | 
						|
	}
 | 
						|
 | 
						|
	return C.PSERVER_OK
 | 
						|
}
 | 
						|
 | 
						|
//export paddle_get_params
 | 
						|
func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
 | 
						|
	var ns []string
 | 
						|
	for i := 0; i < int(total); i++ {
 | 
						|
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
 | 
						|
		ns = append(ns, C.GoString(param.name))
 | 
						|
	}
 | 
						|
	c := get(client)
 | 
						|
	ps, err := c.GetParams(ns)
 | 
						|
	if err != nil {
 | 
						|
		log.Errorln(err)
 | 
						|
		return C.PSERVER_ERROR
 | 
						|
	}
 | 
						|
 | 
						|
	if len(ps) != len(ns) {
 | 
						|
		pn := make([]string, len(ps))
 | 
						|
		for i, p := range ps {
 | 
						|
			pn[i] = p.Name
 | 
						|
		}
 | 
						|
		log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
 | 
						|
		return C.PSERVER_ERROR
 | 
						|
	}
 | 
						|
 | 
						|
	for i := range ps {
 | 
						|
		if ns[i] != ps[i].Name {
 | 
						|
			pn := make([]string, len(ps))
 | 
						|
			for i, p := range ps {
 | 
						|
				pn[i] = p.Name
 | 
						|
			}
 | 
						|
			log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", "))
 | 
						|
			return C.PSERVER_ERROR
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for i := 0; i < int(total); i++ {
 | 
						|
		p := ps[i]
 | 
						|
		param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
 | 
						|
 | 
						|
		if unsafe.Pointer(param) == nil {
 | 
						|
			log.Errorln("must pre-allocate parameter.")
 | 
						|
			return C.PSERVER_ERROR
 | 
						|
		}
 | 
						|
 | 
						|
		if unsafe.Pointer(param.content) != nil {
 | 
						|
			if int(param.content_len) != len(p.Content) {
 | 
						|
				log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
 | 
						|
				return C.PSERVER_ERROR
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
 | 
						|
		param.content_len = C.int(len(p.Content))
 | 
						|
		param.element_type = C.paddle_element_type(p.ElementType)
 | 
						|
	}
 | 
						|
 | 
						|
	return C.PSERVER_OK
 | 
						|
}
 | 
						|
 | 
						|
func main() {} // Required but ignored
 |