Go语言直接使用Windows的IOCP API写一个echo服务器
Go的标准库中Windows下的网络是使用了IOCP的,参见go源码go/src/runtime/netpoll_windows.go
,标准库为了与Epoll、kqueue等不同平台的IO模式使用统一的API,进行了封装。
如果想直接使用Windows的IOCP API编程,比如想按照: Windows下的高效网络模型IOCP完整示例中的流程写,就需要自行封装IOCP相关的API,虽然标准库中封装了很多系统调用,但是不是很全,而且API的函数签名也有一些问题,比如:
1// Deprecated: CreateIoCompletionPort has the wrong function signature. Use x/sys/windows.CreateIoCompletionPort.
2func CreateIoCompletionPort(filehandle Handle, cphandle Handle, key uint32, threadcnt uint32) (Handle, error) {
3 return createIoCompletionPort(filehandle, cphandle, uintptr(key), threadcnt)
4}
5
6// Deprecated: GetQueuedCompletionStatus has the wrong function signature. Use x/sys/windows.GetQueuedCompletionStatus.
7func GetQueuedCompletionStatus(cphandle Handle, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) error {
8 var ukey uintptr
9 var pukey *uintptr
10 if key != nil {
11 ukey = uintptr(*key)
12 pukey = &ukey
13 }
14 err := getQueuedCompletionStatus(cphandle, qty, pukey, overlapped, timeout)
15 if key != nil {
16 *key = uint32(ukey)
17 if uintptr(*key) != ukey && err == nil {
18 err = errorspkg.New("GetQueuedCompletionStatus returned key overflow")
19 }
20 }
21 return err
22}
23
24// Deprecated: PostQueuedCompletionStatus has the wrong function signature. Use x/sys/windows.PostQueuedCompletionStatus.
25func PostQueuedCompletionStatus(cphandle Handle, qty uint32, key uint32, overlapped *Overlapped) error {
26 return postQueuedCompletionStatus(cphandle, qty, uintptr(key), overlapped)
27}
看了一下,其实内部调用的函数签名是没问题的,可以使用Go的魔法指令go:linkname
来解决:
1//go:linkname CreateIoCompletionPort syscall.createIoCompletionPort
2func CreateIoCompletionPort(fileHandle syscall.Handle, cpHandle syscall.Handle, key uintptr, threadCnt uint32) (handle syscall.Handle, err error)
3
4//go:linkname GetQueuedCompletionStatus syscall.getQueuedCompletionStatus
5func GetQueuedCompletionStatus(cpHandle syscall.Handle, qty *uint32, key *uintptr, overlapped **syscall.Overlapped, timeout uint32) (err error)
6
7//go:linkname PostQueuedCompletionStatus syscall.postQueuedCompletionStatus
8func PostQueuedCompletionStatus(cphandle syscall.Handle, qty uint32, key uintptr, overlapped *syscall.Overlapped) (err error)
另外还需要使用到一些API,比如WSACreateEvent
、WSAWaitForMultipleEvents
、WSAResetEvent
、WSAGetOverlappedResult
,就需要自行从Ws2_32.dll
中装载了:
1var (
2 modws2_32 = syscall.NewLazyDLL("Ws2_32.dll")
3
4 procWSACreateEvent = modws2_32.NewProc("WSACreateEvent")
5 procWSAWaitForMultipleEvents = modws2_32.NewProc("WSAWaitForMultipleEvents")
6 procWSAResetEvent = modws2_32.NewProc("WSAResetEvent")
7 procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
8)
9
10func WSACreateEvent() (Handle syscall.Handle, err error) {
11 r1, _, e1 := syscall.SyscallN(procWSACreateEvent.Addr())
12 if r1 == 0 {
13 err = errnoErr(e1)
14 }
15 return syscall.Handle(r1), nil
16}
17
18func WSAWaitForMultipleEvents(cEvents uint32, lpEvent *syscall.Handle, fWaitAll bool, dwTimeout uint32, fAlertable bool) (uint32, error) {
19 var WaitAll, Alertable uint32
20 if fWaitAll {
21 WaitAll = 1
22 }
23 if fAlertable {
24 Alertable = 1
25 }
26 r1, _, e1 := syscall.SyscallN(procWSAWaitForMultipleEvents.Addr(), uintptr(cEvents), uintptr(unsafe.Pointer(lpEvent)), uintptr(WaitAll), uintptr(dwTimeout), uintptr(Alertable))
27 if r1 == syscall.WAIT_FAILED {
28 return 0, errnoErr(e1)
29 }
30 return uint32(r1), nil
31}
32
33func WSAResetEvent(handle syscall.Handle) (err error) {
34 r1, _, e1 := syscall.SyscallN(procWSAResetEvent.Addr(), uintptr(handle))
35 if r1 == 0 {
36 err = errnoErr(e1)
37 }
38 return
39}
40
41func WSAGetOverlappedResult(socket syscall.Handle, overlapped *syscall.Overlapped, transferBytes *uint32, bWait bool, flag *uint32) (err error) {
42 var wait uint32
43 if bWait {
44 wait = 1
45 }
46 r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(socket), uintptr(unsafe.Pointer(overlapped)),
47 uintptr(unsafe.Pointer(transferBytes)), uintptr(wait), uintptr(unsafe.Pointer(flag)))
48 if r1 == 0 {
49 err = errnoErr(e1)
50 }
51 return
52}
笔者尝试了下,完全可以,
直接附上源码:
1package main
2
3import (
4 "errors"
5 "fmt"
6 "os"
7 "runtime"
8 "syscall"
9 "unsafe"
10 _ "unsafe"
11)
12
13//go:linkname CreateIoCompletionPort syscall.createIoCompletionPort
14func CreateIoCompletionPort(fileHandle syscall.Handle, cpHandle syscall.Handle, key uintptr, threadCnt uint32) (handle syscall.Handle, err error)
15
16//go:linkname GetQueuedCompletionStatus syscall.getQueuedCompletionStatus
17func GetQueuedCompletionStatus(cpHandle syscall.Handle, qty *uint32, key *uintptr, overlapped **syscall.Overlapped, timeout uint32) (err error)
18
19//go:linkname PostQueuedCompletionStatus syscall.postQueuedCompletionStatus
20func PostQueuedCompletionStatus(cphandle syscall.Handle, qty uint32, key uintptr, overlapped *syscall.Overlapped) (err error)
21
22//go:linkname errnoErr syscall.errnoErr
23func errnoErr(e syscall.Errno) error
24
25var (
26 modws2_32 = syscall.NewLazyDLL("Ws2_32.dll")
27
28 procWSACreateEvent = modws2_32.NewProc("WSACreateEvent")
29 procWSACloseEvent = modws2_32.NewProc("WSACloseEvent")
30 procWSAWaitForMultipleEvents = modws2_32.NewProc("WSAWaitForMultipleEvents")
31 procWSAResetEvent = modws2_32.NewProc("WSAResetEvent")
32 procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
33)
34
35func WSACreateEvent() (handle syscall.Handle, err error) {
36 r1, _, e1 := syscall.SyscallN(procWSACreateEvent.Addr())
37 if r1 == 0 {
38 err = errnoErr(e1)
39 }
40 return syscall.Handle(r1), err
41}
42
43func WSACloseEvent(handle syscall.Handle) (err error) {
44 r1, _, e1 := syscall.SyscallN(procWSACloseEvent.Addr(), uintptr(handle))
45 if r1 == 0 {
46 err = errnoErr(e1)
47 }
48 return err
49}
50
51func WSAResetEvent(handle syscall.Handle) (err error) {
52 r1, _, e1 := syscall.SyscallN(procWSAResetEvent.Addr(), uintptr(handle))
53 if r1 == 0 {
54 err = errnoErr(e1)
55 }
56 return
57}
58
59func WSAWaitForMultipleEvents(cEvents uint32, lpEvent *syscall.Handle, fWaitAll bool, dwTimeout uint32, fAlertable bool) (uint32, error) {
60 var WaitAll, Alertable uint32
61 if fWaitAll {
62 WaitAll = 1
63 }
64 if fAlertable {
65 Alertable = 1
66 }
67 r1, _, e1 := syscall.SyscallN(procWSAWaitForMultipleEvents.Addr(), uintptr(cEvents), uintptr(unsafe.Pointer(lpEvent)), uintptr(WaitAll), uintptr(dwTimeout), uintptr(Alertable))
68 if r1 == syscall.WAIT_FAILED {
69 return 0, errnoErr(e1)
70 }
71 return uint32(r1), nil
72}
73
74func WSAGetOverlappedResult(socket syscall.Handle, overlapped *syscall.Overlapped, transferBytes *uint32, bWait bool, flag *uint32) (err error) {
75 var wait uint32
76 if bWait {
77 wait = 1
78 }
79 r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(socket), uintptr(unsafe.Pointer(overlapped)),
80 uintptr(unsafe.Pointer(transferBytes)), uintptr(wait), uintptr(unsafe.Pointer(flag)))
81 if r1 == 0 {
82 err = errnoErr(e1)
83 }
84 return
85}
86
87func SetNonBlock(fd syscall.Handle) error {
88 flag := uint32(1)
89 size := uint32(unsafe.Sizeof(flag))
90 ret := uint32(0)
91 ol := syscall.Overlapped{}
92 err := syscall.WSAIoctl(fd, 0x8004667e, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, &ol, 0)
93 if err != nil {
94 return err
95 }
96 return nil
97}
98
99type IOData struct {
100 Overlapped syscall.Overlapped
101 WsaBuf syscall.WSABuf
102 NBytes uint32
103 isRead bool
104 cliSock syscall.Handle
105}
106
107func closeIO(data *IOData) {
108 if data.Overlapped.HEvent != syscall.Handle(0) {
109 WSACloseEvent(data.Overlapped.HEvent)
110 data.Overlapped.HEvent = syscall.Handle(0)
111 }
112 syscall.Closesocket(data.cliSock)
113}
114
115func main() {
116 listenFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
117 if err != nil {
118 return
119 }
120 defer func() {
121 syscall.Closesocket(listenFd)
122 syscall.WSACleanup()
123 }()
124 v4 := &syscall.SockaddrInet4{
125 Port: 6000,
126 Addr: [4]byte{},
127 }
128 err = syscall.Bind(listenFd, v4)
129 if err != nil {
130 return
131 }
132 err = syscall.Listen(listenFd, 0)
133 if err != nil {
134 return
135 }
136
137 hIOCP, err := CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 0)
138 if err != nil {
139 return
140 }
141 count := runtime.NumCPU()
142 for i := 0; i < count; i++ {
143 go workThread(hIOCP)
144 }
145
146 defer func() {
147 for i := 0; i < count; i++ {
148 PostQueuedCompletionStatus(hIOCP, 0, 0, nil)
149 }
150 }()
151
152 for {
153 acceptFd, er := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
154 if er != nil {
155 break
156 }
157 b := make([]byte, 1024)
158 recvD := uint32(0)
159 data := &IOData{
160 Overlapped: syscall.Overlapped{},
161 WsaBuf: syscall.WSABuf{
162 Len: 1024,
163 Buf: &b[0],
164 },
165 NBytes: 1024,
166 isRead: true,
167 cliSock: acceptFd,
168 }
169 data.Overlapped.HEvent, er = WSACreateEvent()
170 if er != nil {
171 fmt.Printf("WSACreateEvent failed:%s", er)
172 closeIO(data)
173 break
174 }
175
176 size := uint32(unsafe.Sizeof(&syscall.SockaddrInet4{}) + 16)
177 er = syscall.AcceptEx(listenFd, acceptFd, data.WsaBuf.Buf, data.WsaBuf.Len-size*2, size, size, &recvD, &data.Overlapped)
178 if er != nil && !errors.Is(er, syscall.ERROR_IO_PENDING) {
179 er = os.NewSyscallError("AcceptEx", er)
180 fmt.Printf("AcceptEx Error:%s", er)
181 closeIO(data)
182 break
183 }
184
185 _, er = WSAWaitForMultipleEvents(1, &data.Overlapped.HEvent, true, 100, false)
186 if er != nil {
187 fmt.Printf("WSAWaitForMultipleEvents Error:%s", er)
188 closeIO(data)
189 break
190 }
191 WSAResetEvent(data.Overlapped.HEvent)
192 flag := uint32(0)
193 er = WSAGetOverlappedResult(acceptFd, &data.Overlapped, &data.NBytes, true, &flag)
194 if er != nil {
195 fmt.Printf("WSAGetOverlappedResult Error:%s", er)
196 closeIO(data)
197 break
198 }
199 if data.NBytes == 0 {
200 closeIO(data)
201 continue
202 }
203 fmt.Printf("client %d connected\n", acceptFd)
204 _, err = CreateIoCompletionPort(acceptFd, hIOCP, 0, 0)
205 if err != nil {
206 fmt.Printf("CreateIoCompletionPort Error:%s", er)
207 closeIO(data)
208 break
209 }
210 postWrite(data)
211 }
212}
213
214func postWrite(data *IOData) (err error) {
215 data.isRead = false
216 // 这里输出一下data指针,让运行时不把data给GC掉,否则就会出问题
217 fmt.Printf("%p cli:%d send %s\n", data, data.cliSock, unsafe.String(data.WsaBuf.Buf, data.NBytes))
218 err = syscall.WSASend(data.cliSock, &data.WsaBuf, 1, &data.NBytes, 0, &data.Overlapped, nil)
219 if err != nil && !errors.Is(err, syscall.ERROR_IO_PENDING) {
220 fmt.Printf("cli:%d send failed: %s\n", data.cliSock, err)
221 closeIO(data)
222 return err
223 }
224 return
225}
226
227func postRead(data *IOData) (err error) {
228 data.NBytes = data.WsaBuf.Len
229 data.isRead = true
230 flag := uint32(0)
231 err = syscall.WSARecv(data.cliSock, &data.WsaBuf, 1, &data.NBytes, &flag, &data.Overlapped, nil)
232 if err != nil && !errors.Is(err, syscall.ERROR_IO_PENDING) {
233 fmt.Printf("cli:%d receive failed: %s\n", data.cliSock, err)
234 closeIO(data)
235 return err
236 }
237 return
238}
239
240func workThread(hIOCP syscall.Handle) {
241 var pOverlapped *syscall.Overlapped
242 var ioSize uint32
243 var key uintptr
244 for {
245 err := GetQueuedCompletionStatus(hIOCP, &ioSize, &key, &pOverlapped, syscall.INFINITE)
246 if err != nil {
247 fmt.Printf("GetQueuedCompletionStatus failed: %s\n", err)
248 return
249 }
250 ioData := (*IOData)(unsafe.Pointer(pOverlapped))
251 if ioSize == 0 {
252 closeIO(ioData)
253 break
254 }
255 ioData.NBytes = ioSize
256 if ioData.isRead {
257 postWrite(ioData)
258 } else {
259 postRead(ioData)
260 }
261 }
262}
源码只是一个示例,可能有不完善的地方,感兴趣的读者可以自行完善。
- 原文作者:Witton
- 原文链接:https://wittonbell.github.io/posts/2024/2024-05-24-Go语言直接使用Windows的IOCP-API写一个echo服务器/
- 版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议. 进行许可,非商业转载请注明出处(作者,原文链接),商业转载请联系作者获得授权。