Go的标准库中Windows下的网络是使用了IOCP的,参见go源码go/src/runtime/netpoll_windows.go,标准库为了与Epoll、kqueue等不同平台的IO模式使用统一的API,进行了封装。

如果想直接使用Windows的IOCP API编程,比如想按照《Windows下的高效网络模型IOCP完整示例》中的流程来写,就需要自行封装IOCP相关的API。虽然标准库中封装了很多系统调用,但并不完整,而且部分API的函数签名也有问题,比如:

 1// Deprecated: CreateIoCompletionPort has the wrong function signature. Use x/sys/windows.CreateIoCompletionPort.
 2func CreateIoCompletionPort(filehandle Handle, cphandle Handle, key uint32, threadcnt uint32) (Handle, error) {
 3	return createIoCompletionPort(filehandle, cphandle, uintptr(key), threadcnt)
 4}
 5
 6// Deprecated: GetQueuedCompletionStatus has the wrong function signature. Use x/sys/windows.GetQueuedCompletionStatus.
 7func GetQueuedCompletionStatus(cphandle Handle, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) error {
 8	var ukey uintptr
 9	var pukey *uintptr
10	if key != nil {
11		ukey = uintptr(*key)
12		pukey = &ukey
13	}
14	err := getQueuedCompletionStatus(cphandle, qty, pukey, overlapped, timeout)
15	if key != nil {
16		*key = uint32(ukey)
17		if uintptr(*key) != ukey && err == nil {
18			err = errorspkg.New("GetQueuedCompletionStatus returned key overflow")
19		}
20	}
21	return err
22}
23
24// Deprecated: PostQueuedCompletionStatus has the wrong function signature. Use x/sys/windows.PostQueuedCompletionStatus.
25func PostQueuedCompletionStatus(cphandle Handle, qty uint32, key uint32, overlapped *Overlapped) error {
26	return postQueuedCompletionStatus(cphandle, qty, uintptr(key), overlapped)
27}

看了一下,其实内部调用的函数签名是没问题的,可以使用Go的魔法指令go:linkname来解决:

1//go:linkname CreateIoCompletionPort syscall.createIoCompletionPort
2func CreateIoCompletionPort(fileHandle syscall.Handle, cpHandle syscall.Handle, key uintptr, threadCnt uint32) (handle syscall.Handle, err error)
3
4//go:linkname GetQueuedCompletionStatus syscall.getQueuedCompletionStatus
5func GetQueuedCompletionStatus(cpHandle syscall.Handle, qty *uint32, key *uintptr, overlapped **syscall.Overlapped, timeout uint32) (err error)
6
7//go:linkname PostQueuedCompletionStatus syscall.postQueuedCompletionStatus
8func PostQueuedCompletionStatus(cphandle syscall.Handle, qty uint32, key uintptr, overlapped *syscall.Overlapped) (err error)

另外还需要使用到一些API,比如 WSACreateEvent、WSAWaitForMultipleEvents、WSAResetEvent、WSAGetOverlappedResult,这些就需要自行从Ws2_32.dll中装载了:

 1var (
 2	modws2_32 = syscall.NewLazyDLL("Ws2_32.dll")
 3
 4	procWSACreateEvent           = modws2_32.NewProc("WSACreateEvent")
 5	procWSAWaitForMultipleEvents = modws2_32.NewProc("WSAWaitForMultipleEvents")
 6	procWSAResetEvent            = modws2_32.NewProc("WSAResetEvent")
 7	procWSAGetOverlappedResult   = modws2_32.NewProc("WSAGetOverlappedResult")
 8)
 9
10func WSACreateEvent() (Handle syscall.Handle, err error) {
11	r1, _, e1 := syscall.SyscallN(procWSACreateEvent.Addr())
12	if r1 == 0 {
13		err = errnoErr(e1)
14	}
15	return syscall.Handle(r1), nil
16}
17
18func WSAWaitForMultipleEvents(cEvents uint32, lpEvent *syscall.Handle, fWaitAll bool, dwTimeout uint32, fAlertable bool) (uint32, error) {
19	var WaitAll, Alertable uint32
20	if fWaitAll {
21		WaitAll = 1
22	}
23	if fAlertable {
24		Alertable = 1
25	}
26	r1, _, e1 := syscall.SyscallN(procWSAWaitForMultipleEvents.Addr(), uintptr(cEvents), uintptr(unsafe.Pointer(lpEvent)), uintptr(WaitAll), uintptr(dwTimeout), uintptr(Alertable))
27	if r1 == syscall.WAIT_FAILED {
28		return 0, errnoErr(e1)
29	}
30	return uint32(r1), nil
31}
32
33func WSAResetEvent(handle syscall.Handle) (err error) {
34	r1, _, e1 := syscall.SyscallN(procWSAResetEvent.Addr(), uintptr(handle))
35	if r1 == 0 {
36		err = errnoErr(e1)
37	}
38	return
39}
40
41func WSAGetOverlappedResult(socket syscall.Handle, overlapped *syscall.Overlapped, transferBytes *uint32, bWait bool, flag *uint32) (err error) {
42	var wait uint32
43	if bWait {
44		wait = 1
45	}
46	r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(socket), uintptr(unsafe.Pointer(overlapped)),
47		uintptr(unsafe.Pointer(transferBytes)), uintptr(wait), uintptr(unsafe.Pointer(flag)))
48	if r1 == 0 {
49		err = errnoErr(e1)
50	}
51	return
52}

笔者尝试了一下,完全可行:

在这里插入图片描述

直接附上源码:

  1package main
  2
  3import (
  4	"errors"
  5	"fmt"
  6	"os"
  7	"runtime"
  8	"syscall"
  9	"unsafe"
 10	_ "unsafe"
 11)
 12
 13//go:linkname CreateIoCompletionPort syscall.createIoCompletionPort
 14func CreateIoCompletionPort(fileHandle syscall.Handle, cpHandle syscall.Handle, key uintptr, threadCnt uint32) (handle syscall.Handle, err error)
 15
 16//go:linkname GetQueuedCompletionStatus syscall.getQueuedCompletionStatus
 17func GetQueuedCompletionStatus(cpHandle syscall.Handle, qty *uint32, key *uintptr, overlapped **syscall.Overlapped, timeout uint32) (err error)
 18
 19//go:linkname PostQueuedCompletionStatus syscall.postQueuedCompletionStatus
 20func PostQueuedCompletionStatus(cphandle syscall.Handle, qty uint32, key uintptr, overlapped *syscall.Overlapped) (err error)
 21
 22//go:linkname errnoErr syscall.errnoErr
 23func errnoErr(e syscall.Errno) error
 24
 25var (
 26	modws2_32 = syscall.NewLazyDLL("Ws2_32.dll")
 27
 28	procWSACreateEvent           = modws2_32.NewProc("WSACreateEvent")
 29	procWSACloseEvent            = modws2_32.NewProc("WSACloseEvent")
 30	procWSAWaitForMultipleEvents = modws2_32.NewProc("WSAWaitForMultipleEvents")
 31	procWSAResetEvent            = modws2_32.NewProc("WSAResetEvent")
 32	procWSAGetOverlappedResult   = modws2_32.NewProc("WSAGetOverlappedResult")
 33)
 34
 35func WSACreateEvent() (handle syscall.Handle, err error) {
 36	r1, _, e1 := syscall.SyscallN(procWSACreateEvent.Addr())
 37	if r1 == 0 {
 38		err = errnoErr(e1)
 39	}
 40	return syscall.Handle(r1), err
 41}
 42
 43func WSACloseEvent(handle syscall.Handle) (err error) {
 44	r1, _, e1 := syscall.SyscallN(procWSACloseEvent.Addr(), uintptr(handle))
 45	if r1 == 0 {
 46		err = errnoErr(e1)
 47	}
 48	return err
 49}
 50
 51func WSAResetEvent(handle syscall.Handle) (err error) {
 52	r1, _, e1 := syscall.SyscallN(procWSAResetEvent.Addr(), uintptr(handle))
 53	if r1 == 0 {
 54		err = errnoErr(e1)
 55	}
 56	return
 57}
 58
 59func WSAWaitForMultipleEvents(cEvents uint32, lpEvent *syscall.Handle, fWaitAll bool, dwTimeout uint32, fAlertable bool) (uint32, error) {
 60	var WaitAll, Alertable uint32
 61	if fWaitAll {
 62		WaitAll = 1
 63	}
 64	if fAlertable {
 65		Alertable = 1
 66	}
 67	r1, _, e1 := syscall.SyscallN(procWSAWaitForMultipleEvents.Addr(), uintptr(cEvents), uintptr(unsafe.Pointer(lpEvent)), uintptr(WaitAll), uintptr(dwTimeout), uintptr(Alertable))
 68	if r1 == syscall.WAIT_FAILED {
 69		return 0, errnoErr(e1)
 70	}
 71	return uint32(r1), nil
 72}
 73
 74func WSAGetOverlappedResult(socket syscall.Handle, overlapped *syscall.Overlapped, transferBytes *uint32, bWait bool, flag *uint32) (err error) {
 75	var wait uint32
 76	if bWait {
 77		wait = 1
 78	}
 79	r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(socket), uintptr(unsafe.Pointer(overlapped)),
 80		uintptr(unsafe.Pointer(transferBytes)), uintptr(wait), uintptr(unsafe.Pointer(flag)))
 81	if r1 == 0 {
 82		err = errnoErr(e1)
 83	}
 84	return
 85}
 86
 87func SetNonBlock(fd syscall.Handle) error {
 88	flag := uint32(1)
 89	size := uint32(unsafe.Sizeof(flag))
 90	ret := uint32(0)
 91	ol := syscall.Overlapped{}
 92	err := syscall.WSAIoctl(fd, 0x8004667e, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, &ol, 0)
 93	if err != nil {
 94		return err
 95	}
 96	return nil
 97}
 98
 99type IOData struct {
100	Overlapped syscall.Overlapped
101	WsaBuf     syscall.WSABuf
102	NBytes     uint32
103	isRead     bool
104	cliSock    syscall.Handle
105}
106
107func closeIO(data *IOData) {
108	if data.Overlapped.HEvent != syscall.Handle(0) {
109		WSACloseEvent(data.Overlapped.HEvent)
110		data.Overlapped.HEvent = syscall.Handle(0)
111	}
112	syscall.Closesocket(data.cliSock)
113}
114
115func main() {
116	listenFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
117	if err != nil {
118		return
119	}
120	defer func() {
121		syscall.Closesocket(listenFd)
122		syscall.WSACleanup()
123	}()
124	v4 := &syscall.SockaddrInet4{
125		Port: 6000,
126		Addr: [4]byte{},
127	}
128	err = syscall.Bind(listenFd, v4)
129	if err != nil {
130		return
131	}
132	err = syscall.Listen(listenFd, 0)
133	if err != nil {
134		return
135	}
136
137	hIOCP, err := CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 0)
138	if err != nil {
139		return
140	}
141	count := runtime.NumCPU()
142	for i := 0; i < count; i++ {
143		go workThread(hIOCP)
144	}
145
146	defer func() {
147		for i := 0; i < count; i++ {
148			PostQueuedCompletionStatus(hIOCP, 0, 0, nil)
149		}
150	}()
151
152	for {
153		acceptFd, er := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
154		if er != nil {
155			break
156		}
157		b := make([]byte, 1024)
158		recvD := uint32(0)
159		data := &IOData{
160			Overlapped: syscall.Overlapped{},
161			WsaBuf: syscall.WSABuf{
162				Len: 1024,
163				Buf: &b[0],
164			},
165			NBytes:  1024,
166			isRead:  true,
167			cliSock: acceptFd,
168		}
169		data.Overlapped.HEvent, er = WSACreateEvent()
170		if er != nil {
171			fmt.Printf("WSACreateEvent failed:%s", er)
172			closeIO(data)
173			break
174		}
175
176		size := uint32(unsafe.Sizeof(&syscall.SockaddrInet4{}) + 16)
177		er = syscall.AcceptEx(listenFd, acceptFd, data.WsaBuf.Buf, data.WsaBuf.Len-size*2, size, size, &recvD, &data.Overlapped)
178		if er != nil && !errors.Is(er, syscall.ERROR_IO_PENDING) {
179			er = os.NewSyscallError("AcceptEx", er)
180			fmt.Printf("AcceptEx Error:%s", er)
181			closeIO(data)
182			break
183		}
184
185		_, er = WSAWaitForMultipleEvents(1, &data.Overlapped.HEvent, true, 100, false)
186		if er != nil {
187			fmt.Printf("WSAWaitForMultipleEvents Error:%s", er)
188			closeIO(data)
189			break
190		}
191		WSAResetEvent(data.Overlapped.HEvent)
192		flag := uint32(0)
193		er = WSAGetOverlappedResult(acceptFd, &data.Overlapped, &data.NBytes, true, &flag)
194		if er != nil {
195			fmt.Printf("WSAGetOverlappedResult Error:%s", er)
196			closeIO(data)
197			break
198		}
199		if data.NBytes == 0 {
200			closeIO(data)
201			continue
202		}
203		fmt.Printf("client %d connected\n", acceptFd)
204		_, err = CreateIoCompletionPort(acceptFd, hIOCP, 0, 0)
205		if err != nil {
206			fmt.Printf("CreateIoCompletionPort Error:%s", er)
207			closeIO(data)
208			break
209		}
210		postWrite(data)
211	}
212}
213
214func postWrite(data *IOData) (err error) {
215	data.isRead = false
216	// 这里输出一下data指针,让运行时不把data给GC掉,否则就会出问题
217	fmt.Printf("%p cli:%d send %s\n", data, data.cliSock, unsafe.String(data.WsaBuf.Buf, data.NBytes))
218	err = syscall.WSASend(data.cliSock, &data.WsaBuf, 1, &data.NBytes, 0, &data.Overlapped, nil)
219	if err != nil && !errors.Is(err, syscall.ERROR_IO_PENDING) {
220		fmt.Printf("cli:%d send failed: %s\n", data.cliSock, err)
221		closeIO(data)
222		return err
223	}
224	return
225}
226
227func postRead(data *IOData) (err error) {
228	data.NBytes = data.WsaBuf.Len
229	data.isRead = true
230	flag := uint32(0)
231	err = syscall.WSARecv(data.cliSock, &data.WsaBuf, 1, &data.NBytes, &flag, &data.Overlapped, nil)
232	if err != nil && !errors.Is(err, syscall.ERROR_IO_PENDING) {
233		fmt.Printf("cli:%d receive failed: %s\n", data.cliSock, err)
234		closeIO(data)
235		return err
236	}
237	return
238}
239
240func workThread(hIOCP syscall.Handle) {
241	var pOverlapped *syscall.Overlapped
242	var ioSize uint32
243	var key uintptr
244	for {
245		err := GetQueuedCompletionStatus(hIOCP, &ioSize, &key, &pOverlapped, syscall.INFINITE)
246		if err != nil {
247			fmt.Printf("GetQueuedCompletionStatus failed: %s\n", err)
248			return
249		}
250		ioData := (*IOData)(unsafe.Pointer(pOverlapped))
251		if ioSize == 0 {
252			closeIO(ioData)
253			break
254		}
255		ioData.NBytes = ioSize
256		if ioData.isRead {
257			postWrite(ioData)
258		} else {
259			postRead(ioData)
260		}
261	}
262}

源码只是一个示例,可能有不完善的地方,感兴趣的读者可以自行完善。