diff --git a/func.go b/func.go index 1a4552de..2df1f387 100644 --- a/func.go +++ b/func.go @@ -10,7 +10,6 @@ import ( "math" "reflect" "runtime" - "strconv" "sync" "unsafe" @@ -22,6 +21,15 @@ var thePool = sync.Pool{New: func() any { return new(syscall15Args) }} +var structTypeCache sync.Map +var structInstancePool sync.Map // map[reflect.Type]*sync.Pool + +// Pre-computed field names to avoid allocations +var fieldNames = [maxArgs]string{ + "X0", "X1", "X2", "X3", "X4", "X5", "X6", "X7", + "X8", "X9", "X10", "X11", "X12", "X13", "X14", +} + // RegisterLibFunc is a wrapper around RegisterFunc that uses the C function returned from Dlsym(handle, name). // It panics if it can't find the name symbol. func RegisterLibFunc(fptr any, handle uintptr, name string) { @@ -283,8 +291,10 @@ func RegisterFunc(fptr any, cfn uintptr) { } if runtime.GOARCH == "arm64" && runtime.GOOS == "darwin" && (numInts >= numOfIntegerRegisters() || numFloats >= numOfFloatRegisters) && v.Kind() != reflect.Struct { // hit the stack - fields := make([]reflect.StructField, len(args[i:])) + fields := make([]reflect.StructField, 0, 8) + // Build type hash as we build fields (avoids string allocation) + var typeHash uintptr for j, val := range args[i:] { if val.Kind() == reflect.String { ptr := strings.CString(val.String()) @@ -292,17 +302,28 @@ func RegisterFunc(fptr any, cfn uintptr) { val = reflect.ValueOf(ptr) args[i+j] = val } - fields[j] = reflect.StructField{ - Name: "X" + strconv.Itoa(j), - Type: val.Type(), - } + valType := val.Type() + fields = append(fields, reflect.StructField{ + Name: fieldNames[j], + Type: valType, + }) + // Hash the type pointer for cache key (use interface value directly) + typeHash = typeHash*31 ^ uintptr((*[2]uintptr)(unsafe.Pointer(&valType))[1]) } - structType := reflect.StructOf(fields) - structInstance := reflect.New(structType).Elem() + + var structType reflect.Type + if cached, ok := structTypeCache.Load(typeHash); ok { + structType = cached.(reflect.Type) + } else { + structType = reflect.StructOf(fields) + structTypeCache.Store(typeHash, structType) + } + structInstance := getPooledStructInstance(structType) for j, val := range args[i:] { structInstance.Field(j).Set(val) } placeRegisters(structInstance, addFloat, addInt) + returnPooledStructInstance(structType, structInstance) break } keepAlive = addValue(v, keepAlive, addInt, addFloat, addStack, &numInts, &numFloats, &numStack) @@ -476,6 +497,40 @@ func roundUpTo8(val uintptr) uintptr { return (val + 7) &^ 7 } +func getPooledStructInstance(t reflect.Type) reflect.Value { + // Try to load existing pool first (fast path) + if poolInterface, ok := structInstancePool.Load(t); ok { + pool := poolInterface.(*sync.Pool) + ptr := pool.Get() + val := reflect.ValueOf(ptr).Elem() + return val + } + + // Slow path: create new pool (only happens once per type) + newPool := &sync.Pool{ + New: func() any { + return reflect.New(t).Interface() + }, + } + poolInterface, _ := structInstancePool.LoadOrStore(t, newPool) + pool := poolInterface.(*sync.Pool) + ptr := pool.Get() + val := reflect.ValueOf(ptr).Elem() + return val +} + +func returnPooledStructInstance(t reflect.Type, v reflect.Value) { + if poolInterface, ok := structInstancePool.Load(t); ok { + pool := poolInterface.(*sync.Pool) + // Zero all fields before returning to pool + for i := 0; i < v.NumField(); i++ { + v.Field(i).SetZero() + } + ptr := v.Addr().Interface() + pool.Put(ptr) + } +} + func numOfIntegerRegisters() int { switch runtime.GOARCH { case "arm64", "loong64": diff --git a/struct_arm64.go b/struct_arm64.go index b11983f3..8605e77b 100644 --- a/struct_arm64.go +++ b/struct_arm64.go @@ -117,6 +117,8 @@ func placeRegisters(v reflect.Value, addFloat func(uintptr), addInt func(uintptr } else { addInt(uintptr(val)) } + val = 0 + class = _NO_CLASS } switch f.Type().Kind() { case reflect.Struct: diff --git a/struct_test.go b/struct_test.go index b7347e09..4e3483b3 100644 --- a/struct_test.go +++ b/struct_test.go @@ -373,8 +373,9 @@ func TestRegisterFunc_structArgs(t *testing.T) { } var Array4CharsFn func(chars Array4Chars) int32 purego.RegisterLibFunc(&Array4CharsFn, lib, "Array4Chars") - if ret := Array4CharsFn(Array4Chars{a: [...]int8{100, -127, 4, -100}}); ret != expectedSigned { - t.Fatalf("Array4CharsFn returned %#x wanted %#x", ret, expectedSigned) + const expectedSum = 1 + 2 + 4 + 8 + if ret := Array4CharsFn(Array4Chars{a: [...]int8{1, 2, 4, 8}}); ret != expectedSum { + t.Fatalf("Array4CharsFn returned %d wanted %d", ret, expectedSum) } } { @@ -486,6 +487,18 @@ func TestRegisterFunc_structArgs(t *testing.T) { t.Fatalf("FloatAndBool(y: false) = %d, want 0", ret) } } + { + type FourInt32s struct { + f0, f1, f2, f3 int32 + } + var FourInt32sFn func(FourInt32s) int32 + purego.RegisterLibFunc(&FourInt32sFn, lib, "FourInt32s") + result := FourInt32sFn(FourInt32s{100, -127, 4, -100}) + const want = 100 - 127 + 4 - 100 + if result != want { + t.Fatalf("FourInt32s returned %d wanted %d", result, want) + } + } } func TestRegisterFunc_structReturns(t *testing.T) { diff --git a/syscall_bench_test.go b/syscall_bench_test.go new file mode 100644 index 00000000..0e67d467 --- /dev/null +++ b/syscall_bench_test.go @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 The Ebitengine Authors + +package purego_test + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/ebitengine/purego" + "github.com/ebitengine/purego/internal/load" +) + +// BenchmarkCallingMethods compares RegisterFunc, SyscallN, and Callback methods +func BenchmarkCallingMethods(b *testing.B) { + testCases := []struct { + n int + fn any + fnPtr uintptr + cFnPtr uintptr + cFnName string + cCallbackPtr uintptr + cCallbackName string + args []uintptr + expectedSum int + }{ + {1, sum1, 0, 0, "sum1_c", 0, "call_callback1", []uintptr{1}, 1}, + {2, sum2, 0, 0, "sum2_c", 0, "call_callback2", []uintptr{1, 2}, 3}, + {3, sum3, 0, 0, "sum3_c", 0, "call_callback3", []uintptr{1, 2, 3}, 6}, + {5, sum5, 0, 0, "sum5_c", 0, "call_callback5", []uintptr{1, 2, 3, 4, 5}, 15}, + {10, sum10, 0, 0, "sum10_c", 0, "call_callback10", []uintptr{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 55}, + {14, sum15, 0, 0, "sum14_c", 0, "call_callback14", []uintptr{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, 105}, + {15, sum15, 0, 0, "sum15_c", 0, "call_callback15", []uintptr{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, 120}, + } + + // Build C library for benchmarking + libFileName := filepath.Join(b.TempDir(), "libbenchmark.so") + if err := buildSharedLib("CC", libFileName, filepath.Join("testdata", "benchmarktest", "benchmark.c")); err != nil { + b.Skipf("Failed to build C library: %v", err) + } + b.Cleanup(func() { + os.Remove(libFileName) + }) + + libHandle, err := load.OpenLibrary(libFileName) + if err != nil { + b.Fatalf("Failed to load C library: %v", err) + } + defer func() { + if err := load.CloseLibrary(libHandle); err != nil { + b.Fatalf("Failed to close library: %s", err) + } + }() + + // Create callbacks and load C functions + for i := range testCases { + testCases[i].fnPtr = purego.NewCallback(testCases[i].fn) + + cFn, err := load.OpenSymbol(libHandle, testCases[i].cFnName) + if err != nil { + b.Fatalf("Failed to load C function %s: %v", testCases[i].cFnName, err) + } + testCases[i].cFnPtr = cFn + + cCallbackFn, err := load.OpenSymbol(libHandle, testCases[i].cCallbackName) + if err != nil { + b.Fatalf("Failed to load C callback wrapper %s: %v", testCases[i].cCallbackName, err) + } + testCases[i].cCallbackPtr = cCallbackFn + } + + b.Run("RegisterFunc/Callback", func(b *testing.B) { + for _, tc := range testCases { + b.Run(fmt.Sprintf("%dargs", tc.n), func(b *testing.B) { + b.ReportAllocs() + registerFn := makeRegisterFunc(tc.n) + purego.RegisterFunc(registerFn, tc.fnPtr) + + b.ResetTimer() + result := callRegisterFunc(registerFn, tc.n, tc.args, b.N) + b.StopTimer() + + if int(result) != tc.expectedSum { + b.Fatalf("RegisterFunc/Callback: expected sum %d, got %d", tc.expectedSum, result) + } + }) + } + }) + + // Benchmark RegisterFunc with C functions + b.Run("RegisterFunc/CFunc", func(b *testing.B) { + for _, tc := range testCases { + b.Run(fmt.Sprintf("%dargs", tc.n), func(b *testing.B) { + b.ReportAllocs() + registerFn := makeRegisterFunc(tc.n) + purego.RegisterFunc(registerFn, tc.cFnPtr) + + b.ResetTimer() + result := callRegisterFunc(registerFn, tc.n, tc.args, b.N) + b.StopTimer() + + if int(result) != tc.expectedSum { + b.Fatalf("RegisterFunc/CFunc: expected sum %d, got %d", tc.expectedSum, result) + } + }) + } + }) + + // Benchmark SyscallN with Go callbacks + b.Run("SyscallN/Callback", func(b *testing.B) { + for _, tc := range testCases { + b.Run(fmt.Sprintf("%dargs", tc.n), func(b *testing.B) { + b.ReportAllocs() + var result uintptr + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, _, _ = purego.SyscallN(tc.fnPtr, tc.args...) + } + b.StopTimer() + if int(result) != tc.expectedSum { + b.Fatalf("SyscallN/Callback: expected sum %d, got %d", tc.expectedSum, result) + } + }) + } + }) + + // Benchmark SyscallN with C functions + b.Run("SyscallN/CFunc", func(b *testing.B) { + for _, tc := range testCases { + b.Run(fmt.Sprintf("%dargs", tc.n), func(b *testing.B) { + b.ReportAllocs() + var result uintptr + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, _, _ = purego.SyscallN(tc.cFnPtr, tc.args...) + } + b.StopTimer() + if int(result) != tc.expectedSum { + b.Fatalf("SyscallN/CFunc: expected sum %d, got %d", tc.expectedSum, result) + } + }) + } + }) + + // Benchmark round-trip: Go → C → Go callback (realistic use case) + b.Run("RoundTrip", func(b *testing.B) { + for _, tc := range testCases { + b.Run(fmt.Sprintf("%dargs", tc.n), func(b *testing.B) { + b.ReportAllocs() + // Build args: first arg is callback pointer, rest are the arguments + callbackArgs := make([]uintptr, len(tc.args)+1) + callbackArgs[0] = tc.fnPtr + copy(callbackArgs[1:], tc.args) + + // Skip if total args (callback + args) exceeds or meets limit + // SyscallN has issues with exactly 15 or more arguments + if len(callbackArgs) >= 15 { + b.Skipf("Round-trip with %d args + callback (%d total) exceeds/meets SyscallN limit", tc.n, len(callbackArgs)) + } + + var result uintptr + b.ResetTimer() + for i := 0; i < b.N; i++ { + result, _, _ = purego.SyscallN(tc.cCallbackPtr, callbackArgs...) + } + b.StopTimer() + if int(result) != tc.expectedSum { + b.Fatalf("RoundTrip: expected sum %d, got %d", tc.expectedSum, result) + } + }) + } + }) +} + +// makeRegisterFunc creates a function pointer of the appropriate signature +func makeRegisterFunc(n int) any { + switch n { + case 1: + return new(func(uintptr) uintptr) + case 2: + return new(func(uintptr, uintptr) uintptr) + case 3: + return new(func(uintptr, uintptr, uintptr) uintptr) + case 5: + return new(func(uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + case 10: + return new(func(uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + case 14: + return new(func(uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + case 15: + return new(func(uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + default: + return nil + } +} + +// callRegisterFunc calls the registered function with the appropriate number of arguments +func callRegisterFunc(registerFn any, n int, args []uintptr, iterations int) uintptr { + var result uintptr + switch n { + case 1: + f := registerFn.(*func(uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0]) + } + case 2: + f := registerFn.(*func(uintptr, uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0], args[1]) + } + case 3: + f := registerFn.(*func(uintptr, uintptr, uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0], args[1], args[2]) + } + case 5: + f := registerFn.(*func(uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0], args[1], args[2], args[3], args[4]) + } + case 10: + f := registerFn.(*func(uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], args[9]) + } + case 14: + f := registerFn.(*func(uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], args[9], + args[10], args[11], args[12], args[13]) + } + case 15: + f := registerFn.(*func(uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr, uintptr) uintptr) + for i := 0; i < iterations; i++ { + result = (*f)(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], args[9], + args[10], args[11], args[12], args[13], args[14]) + } + } + return result +} + +//go:noinline +func sum1(a1 uintptr) uintptr { return a1 } + +//go:noinline +func sum2(a1, a2 uintptr) uintptr { return a1 + a2 } + +//go:noinline +func sum3(a1, a2, a3 uintptr) uintptr { return a1 + a2 + a3 } + +//go:noinline +func sum5(a1, a2, a3, a4, a5 uintptr) uintptr { return a1 + a2 + a3 + a4 + a5 } + +//go:noinline +func sum10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 uintptr) uintptr { + return a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 +} + +//go:noinline +func sum15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15 uintptr) uintptr { + return a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 +} diff --git a/testdata/benchmarktest/benchmark.c b/testdata/benchmarktest/benchmark.c new file mode 100644 index 00000000..6a3005e2 --- /dev/null +++ b/testdata/benchmarktest/benchmark.c @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 The Ebitengine Authors + +// go:generate sh -c "if [ $(uname -s) = 'Darwin' ]; then cc -dynamiclib -O2 -o +// libnoop.dylib noop.c; else cc -shared -fPIC -O2 -o libnoop.so noop.c; fi" + +long sum1_c(long a1) { return a1; } + +long sum2_c(long a1, long a2) { return a1 + a2; } + +long sum3_c(long a1, long a2, long a3) { return a1 + a2 + a3; } + +long sum5_c(long a1, long a2, long a3, long a4, long a5) { + return a1 + a2 + a3 + a4 + a5; +} + +long sum10_c(long a1, long a2, long a3, long a4, long a5, long a6, long a7, + long a8, long a9, long a10) { + return a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10; +} + +long sum14_c(long a1, long a2, long a3, long a4, long a5, long a6, long a7, long a8, long a9, long a10, long a11, long a12, long a13, long a14, long a15) { + return a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + + a14 + a15; +} + +long sum15_c(long a1, long a2, long a3, long a4, long a5, long a6, long a7, long a8, long a9, long a10, long a11, long a12, long a13, long a14, long a15) { + return a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + + a14 + a15; +} +typedef long (*callback1_t)(long); +long call_callback1(callback1_t cb, long a1) { return cb(a1); } + +typedef long (*callback2_t)(long, long); +long call_callback2(callback2_t cb, long a1, long a2) { return cb(a1, a2); } + +typedef long (*callback3_t)(long, long, long); +long call_callback3(callback3_t cb, long a1, long a2, long a3) { + return cb(a1, a2, a3); +} + +typedef long (*callback5_t)(long, long, long, long, long); +long call_callback5(callback5_t cb, long a1, long a2, long a3, long a4, long a5) { + return cb(a1, a2, a3, a4, a5); +} + +typedef long (*callback10_t)(long, long, long, long, long, long, long, long, long, long); +long call_callback10(callback10_t cb, long a1, long a2, long a3, long a4, long a5, long a6, long a7, long a8, long a9, long a10) { + return cb(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10); +} + +typedef long (*callback14_t)(long, long, long, long, long, long, long, long, long, long, long, long, long, long); +long call_callback14(callback14_t cb, long a1, long a2, long a3, long a4, long a5, long a6, long a7, long a8, long a9, long a10, long a11, long a12, long a13, long a14) { + return cb(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14); +} + +typedef long (*callback15_t)(long, long, long, long, long, long, long, long, long, long, long, long, long, long, long); +long call_callback15(callback15_t cb, long a1, long a2, long a3, long a4, long a5, long a6, long a7, long a8, long a9, long a10, long a11, long a12, long a13, long a14, long a15) { + return cb(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15); +} diff --git a/testdata/structtest/struct_test.c b/testdata/structtest/struct_test.c index 5769f8d1..4cbf8060 100644 --- a/testdata/structtest/struct_test.c +++ b/testdata/structtest/struct_test.c @@ -1,7 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2024 The Ebitengine Authors -#include "stdint.h" +#include +#include +#include #if defined(__x86_64__) || defined(__aarch64__) typedef int64_t GoInt; @@ -361,3 +363,14 @@ struct FloatAndBool { int FloatAndBool(struct FloatAndBool f) { return f.has_value; } + +struct FourInt32s { + int32_t f0; + int32_t f1; + int32_t f2; + int32_t f3; +}; + +int32_t FourInt32s(struct FourInt32s s) { + return s.f0 + s.f1 + s.f2 + s.f3; +}