fix go api bug. (#31857)

develop
Wilber 5 years ago committed by GitHub
parent e804f08559
commit 70b67f1029
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,6 +50,7 @@ output_data := value.Interface().([][]float32)
运行
```bash
go mod init github.com/paddlepaddle
export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH
go run ./demo/mobilenet.go
```

@ -13,7 +13,7 @@
// limitations under the License.
package main
import "../paddle"
import "github.com/paddlepaddle/paddle"
import "strings"
import "io/ioutil"
import "strconv"

@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <paddle_c_api.h>
import "C"

@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <stdlib.h>
// #include <paddle_c_api.h>

@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include "paddle_c_api.h"
import "C"
@ -88,7 +88,7 @@ func (predictor *Predictor) GetInputNames() []string {
}
func (predictor *Predictor) GetOutputNames() []string {
names := make([]string, predictor.GetInputNum())
names := make([]string, predictor.GetOutputNum())
for i := 0; i < len(names); i++ {
names[i] = predictor.GetOutputName(i)
}

@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <stdlib.h>
// #include <string.h>
@ -209,7 +209,7 @@ func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Va
value := reflect.Indirect(ptr)
value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
if len(shape) == 1 && value.Len() > 0 {
switch value.Index(1).Kind() {
switch value.Index(0).Kind() {
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
binary.Read(r, Endian(), value.Interface())
return

@ -207,13 +207,16 @@ int PD_GetOutputNum(const PD_Predictor* predictor) {
}
const char* PD_GetInputName(const PD_Predictor* predictor, int n) {
static std::vector<std::string> names = predictor->predictor->GetInputNames();
static std::vector<std::string> names;
names.resize(predictor->predictor->GetInputNames().size());
names[n] = predictor->predictor->GetInputNames()[n];
return names[n].c_str();
}
const char* PD_GetOutputName(const PD_Predictor* predictor, int n) {
static std::vector<std::string> names =
predictor->predictor->GetOutputNames();
static std::vector<std::string> names;
names.resize(predictor->predictor->GetOutputNames().size());
names[n] = predictor->predictor->GetOutputNames()[n];
return names[n].c_str();
}

Loading…
Cancel
Save