You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
256 lines
6.3 KiB
256 lines
6.3 KiB
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package paddle
|
|
|
|
// #cgo CFLAGS: -Ipaddle_c/paddle/include
|
|
// #cgo LDFLAGS: -Lpaddle_c/paddle/lib -lpaddle_fluid_c
|
|
// #include <stdbool.h>
|
|
// #include <stdlib.h>
|
|
// #include <string.h>
|
|
// #include <paddle_c_api.h>
|
|
import "C"
|
|
|
|
import "runtime"
|
|
import "reflect"
|
|
import "unsafe"
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
)
|
|
|
|
type PaddleDType C.PD_DataType
|
|
|
|
const (
|
|
FLOAT32 PaddleDType = C.PD_FLOAT32
|
|
INT32 PaddleDType = C.PD_INT32
|
|
INT64 PaddleDType = C.PD_INT64
|
|
UINT8 PaddleDType = C.PD_UINT8
|
|
UNKDTYPE PaddleDType = C.PD_UNKDTYPE
|
|
)
|
|
|
|
var types = []struct {
|
|
gotype reflect.Type
|
|
dtype PaddleDType
|
|
}{
|
|
{reflect.TypeOf(float32(0)), FLOAT32},
|
|
{reflect.TypeOf(int32(0)), INT32},
|
|
{reflect.TypeOf(int64(0)), INT64},
|
|
{reflect.TypeOf(uint8(0)), UINT8},
|
|
}
|
|
|
|
func TypeOfShape(dtype PaddleDType, shape []int32) reflect.Type {
|
|
var ret reflect.Type
|
|
for _, t := range types {
|
|
if dtype == PaddleDType(t.dtype) {
|
|
ret = t.gotype
|
|
break
|
|
}
|
|
}
|
|
|
|
if ret == nil {
|
|
panic(bug("Data %v type is not support", dtype))
|
|
}
|
|
|
|
for range shape {
|
|
ret = reflect.SliceOf(ret)
|
|
}
|
|
return ret
|
|
}
|
|
|
|
type ZeroCopyTensor struct {
|
|
c *C.PD_ZeroCopyTensor
|
|
name string
|
|
shape []int32
|
|
}
|
|
|
|
func NewZeroCopyTensor() *ZeroCopyTensor {
|
|
c_tensor := C.PD_NewZeroCopyTensor()
|
|
|
|
tensor := &ZeroCopyTensor{c: c_tensor}
|
|
runtime.SetFinalizer(tensor, (*ZeroCopyTensor).finalize)
|
|
return tensor
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) finalize() {
|
|
C.PD_DeleteZeroCopyTensor(tensor.c)
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) Shape() []int32 {
|
|
return tensor.shape
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) Name() string {
|
|
return C.GoString(tensor.c.name)
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) Rename(name string) {
|
|
tensor.name = name
|
|
tensor.c.name = (*C.char)(unsafe.Pointer(tensor.c.name))
|
|
//tensor.c.name = C.CString(tensor.name)
|
|
//defer C.free(unsafe.Pointer(tensor.c.name))
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) Reshape(shape []int32) {
|
|
tensor.shape = make([]int32, len(shape))
|
|
copy(tensor.shape, shape)
|
|
length := C.sizeof_int * C.size_t(len(shape))
|
|
if tensor.c.shape.capacity < C.size_t(length) {
|
|
if tensor.c.shape.capacity != C.size_t(0) {
|
|
C.free(tensor.c.shape.data)
|
|
}
|
|
tensor.c.shape.data = C.malloc(length)
|
|
tensor.c.shape.capacity = length
|
|
}
|
|
tensor.c.shape.length = length
|
|
C.memcpy(tensor.c.shape.data, unsafe.Pointer(&shape[0]), length)
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) DataType() PaddleDType {
|
|
return PaddleDType(tensor.c.dtype)
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) SetValue(value interface{}) {
|
|
val := reflect.ValueOf(value)
|
|
shape, dtype := ShapeAndTypeOf(val)
|
|
tensor.Reshape(shape)
|
|
num := numel(shape)
|
|
length := C.size_t(SizeofDataType(dtype) * num)
|
|
if tensor.c.data.capacity < length {
|
|
if tensor.c.data.capacity != C.size_t(0) {
|
|
C.free(tensor.c.data.data)
|
|
}
|
|
tensor.c.data.data = C.malloc(length)
|
|
tensor.c.data.capacity = length
|
|
}
|
|
tensor.c.data.length = length
|
|
|
|
switch dtype {
|
|
case PaddleDType(UINT8):
|
|
data := val.Interface().([]uint8)
|
|
C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
|
|
case PaddleDType(INT32):
|
|
data := val.Interface().([]int32)
|
|
C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
|
|
case PaddleDType(INT64):
|
|
data := val.Interface().([]int64)
|
|
C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
|
|
case PaddleDType(FLOAT32):
|
|
data := val.Interface().([]float32)
|
|
C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
|
|
}
|
|
tensor.c.dtype = C.PD_DataType(dtype)
|
|
}
|
|
|
|
func TypeOf(dtype PaddleDType, shape []int32) reflect.Type {
|
|
var ret reflect.Type
|
|
for _, t := range types {
|
|
if t.dtype == dtype {
|
|
ret = t.gotype
|
|
break
|
|
}
|
|
}
|
|
|
|
for range shape {
|
|
ret = reflect.SliceOf(ret)
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func (tensor *ZeroCopyTensor) Value() interface{} {
|
|
t := TypeOf(PaddleDType(tensor.c.dtype), tensor.shape)
|
|
value := reflect.New(t)
|
|
c_bytes := tensor.c.data.data
|
|
length := tensor.c.data.length
|
|
var slice []byte
|
|
if unsafe.Sizeof(unsafe.Pointer(nil)) == 8 {
|
|
slice = (*[1<<50 - 1]byte)(unsafe.Pointer(c_bytes))[:length:length]
|
|
} else {
|
|
slice = (*[1 << 30]byte)(unsafe.Pointer(c_bytes))[:length:length]
|
|
}
|
|
r := bytes.NewReader(slice)
|
|
DecodeTensor(r, tensor.Shape(), t, value)
|
|
return reflect.Indirect(value).Interface()
|
|
}
|
|
|
|
func Endian() binary.ByteOrder {
|
|
buf := [2]byte{}
|
|
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
|
|
|
|
var endian binary.ByteOrder
|
|
|
|
switch buf {
|
|
case [2]byte{0xCD, 0xAB}:
|
|
endian = binary.LittleEndian
|
|
case [2]byte{0xAB, 0xCD}:
|
|
endian = binary.BigEndian
|
|
default:
|
|
panic("Could not determine native endianness.")
|
|
}
|
|
return endian
|
|
}
|
|
|
|
func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Value) {
|
|
switch t.Kind() {
|
|
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
|
|
binary.Read(r, Endian(), ptr.Interface())
|
|
case reflect.Slice:
|
|
value := reflect.Indirect(ptr)
|
|
value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
|
|
if len(shape) == 1 && value.Len() > 0 {
|
|
switch value.Index(1).Kind() {
|
|
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
|
|
binary.Read(r, Endian(), value.Interface())
|
|
return
|
|
}
|
|
}
|
|
|
|
for i := 0; i < value.Len(); i++ {
|
|
DecodeTensor(r, shape[1:], t.Elem(), value.Index(i).Addr())
|
|
}
|
|
}
|
|
}
|
|
|
|
func SizeofDataType(dtype PaddleDType) int32 {
|
|
switch dtype {
|
|
case UINT8:
|
|
return int32(C.sizeof_uchar)
|
|
case INT32:
|
|
return int32(C.sizeof_int)
|
|
case INT64:
|
|
return int32(C.sizeof_longlong)
|
|
case FLOAT32:
|
|
return int32(C.sizeof_float)
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func ShapeAndTypeOf(val reflect.Value) (shape []int32, dt PaddleDType) {
|
|
gotype := val.Type()
|
|
for gotype.Kind() == reflect.Array || gotype.Kind() == reflect.Slice {
|
|
shape = append(shape, int32(val.Len()))
|
|
if val.Len() > 0 {
|
|
val = val.Index(0)
|
|
}
|
|
gotype = gotype.Elem()
|
|
}
|
|
|
|
for _, t := range types {
|
|
if gotype.Kind() == t.gotype.Kind() {
|
|
return shape, PaddleDType(t.dtype)
|
|
}
|
|
}
|
|
return shape, dt
|
|
}
|