1 /*
2  * Hunt - A refined core library for D programming language.
3  *
4  * Copyright (C) 2018-2019 HuntLabs
5  *
6  * Website: https://www.huntlabs.net/
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module hunt.io.socket.IOCP;
13 
14 // dfmt off
15 version (HAVE_IOCP) : 
16 
17 pragma(lib, "Ws2_32");
18 // dfmt on
19 
20 import hunt.collection.ByteBuffer;
21 import hunt.io.socket.Common;
22 import hunt.logging;
23 import hunt.Functions;
24 import hunt.concurrency.thread.Helper;
25 
26 import core.sys.windows.windows;
27 import core.sys.windows.winsock2;
28 import core.sys.windows.mswsock;
29 
30 import std.conv;
31 import std.exception;
32 import std.format;
33 import std.process;
34 import std.socket;
35 import std.stdio;
36 
/**
TCP Server

Listening socket driven by IOCP: `doAccept` posts an overlapped AcceptEx on a
freshly created client socket, and `onAccept` finishes that accepted connection,
hands it to the handler, and posts the next AcceptEx while still registered.
*/
abstract class AbstractListener : AbstractSocketChannel {
    /// Socket option (SO_UPDATE_ACCEPT_CONTEXT) set on an AcceptEx-accepted
    /// socket, with the listening socket as its value.
    private enum int UpdateAcceptContextOption = 0x700B;

    /// AcceptEx stores the local and the remote address in the output buffer,
    /// each taking (address size + 16) bytes, so the buffer must hold both.
    private enum size_t minAcceptBufferSize = 2 * (sockaddr_storage.sizeof + 16);

    this(Selector loop, AddressFamily family = AddressFamily.INET, size_t bufferSize = 4 * 1024) {
        super(loop, ChannelType.Accept);
        setFlag(ChannelFlag.Read, true);
        // Never allocate less than AcceptEx writes, or it would overrun _buffer.
        _buffer = new ubyte[bufferSize < minAcceptBufferSize ? minAcceptBufferSize : bufferSize];
        this.socket = new TcpSocket(family);

        loadWinsockExtension(this.handle);
    }

    mixin CheckIocpError;

    /**
     * Create a new client socket and post an overlapped AcceptEx for it.
     * Errors other than a pending operation are recorded via `checkErro`.
     */
    protected void doAccept() {
        _iocp.channel = this;
        _iocp.operation = IocpOperation.accept;
        _clientSocket = new Socket(this.localAddress.addressFamily,
                SocketType.STREAM, ProtocolType.TCP);
        DWORD dwBytesReceived = 0;

        version (HUNT_DEBUG) {
            tracef("client socket: acceptor=%s  inner socket=%s", this.handle,
                    _clientSocket.handle());
        }
        uint sockaddrSize = cast(uint) sockaddr_storage.sizeof;
        // https://docs.microsoft.com/en-us/windows/desktop/api/mswsock/nf-mswsock-acceptex
        BOOL ret = AcceptEx(this.handle, cast(SOCKET) _clientSocket.handle, _buffer.ptr,
                0, sockaddrSize + 16, sockaddrSize + 16, &dwBytesReceived, &_iocp.overlapped);
        // Check before any logging: logging may overwrite the last-error value.
        checkErro(ret, FALSE);
        version (HUNT_DEBUG)
            trace("AcceptEx return: ", ret);
    }

    /**
     * Complete an accepted connection and queue the next accept.
     *
     * Sets the update-accept-context option on the accepted socket, passes it
     * to `handler` (if non-null), then calls `doAccept` again while registered.
     *
     * Returns: always true.
     */
    protected bool onAccept(scope AcceptHandler handler) {
        version (HUNT_DEBUG)
            trace("a new connection coming...");
        this.clearError();
        SOCKET slisten = cast(SOCKET) this.handle;
        SOCKET slink = cast(SOCKET) this._clientSocket.handle;
        version (HUNT_DEBUG)
            tracef("slisten=%s, slink=%s", slisten, slink);
        if (setsockopt(slink, SocketOptionLevel.SOCKET, UpdateAcceptContextOption,
                cast(void*)&slisten, slisten.sizeof) == SOCKET_ERROR) {
            // Best effort: report the failure instead of silently ignoring it.
            warningf("SO_UPDATE_ACCEPT_CONTEXT failed: slink=%s, error=%d",
                    slink, WSAGetLastError());
        }
        if (handler !is null)
            handler(this._clientSocket);

        version (HUNT_DEBUG)
            trace("accepting next connection...");
        if (this.isRegistered)
            this.doAccept();
        return true;
    }

    override void onClose() {
        // TODO: created by Administrator @ 2018-3-27 15:51:52
    }

    private IocpContext _iocp;
    private WSABUF _dataWriteBuffer;
    private ubyte[] _buffer;
    private Socket _clientSocket;
}
105 
106 
/**
TCP stream: a connected TCP socket, used both by clients (connected through
`doConnect`) and on the server side for accepted connections.

Reads and writes are overlapped Winsock operations (WSARecv / WSASend /
ConnectEx); their completions come back through `onRead`, `onWrite` and
`onWriteDone`.
*/
abstract class AbstractStream : AbstractSocketChannel, Stream {
    /// Called with each chunk of received bytes.
    DataReceivedHandler dataReceivedHandler;
    DataWrittenHandler sentHandler;
    protected AddressFamily _family;

    this(Selector loop, AddressFamily family = AddressFamily.INET, size_t bufferSize = 4096 * 2) {
        super(loop, ChannelType.TCP);
        setFlag(ChannelFlag.Read, true);
        setFlag(ChannelFlag.Write, true);

        version (HUNT_DEBUG)
            trace("Buffer size for read: ", bufferSize);
        _readBuffer = new ubyte[bufferSize];
        this.socket = new TcpSocket(family);

        loadWinsockExtension(this.handle);
    }

    mixin CheckIocpError;

    override void onRead() {
        version (HUNT_DEBUG)
            trace("ready to read");
        _inRead = false;
        super.onRead();
    }

    override void onWrite() {
        _inWrite = false;
        super.onWrite();
    }

    /// Post an overlapped WSARecv that fills `_readBuffer`.
    protected void beginRead() {
        _inRead = true;
        _dataReadBuffer.len = cast(uint) _readBuffer.length;
        _dataReadBuffer.buf = cast(char*) _readBuffer.ptr;
        _iocpread.channel = this;
        _iocpread.operation = IocpOperation.read;
        DWORD dwReceived = 0;
        DWORD dwFlags = 0;

        version (HUNT_DEBUG)
            tracef("start receiving by handle[fd=%d] ", this.socket.handle);

        // https://docs.microsoft.com/en-us/windows/desktop/api/winsock2/nf-winsock2-wsarecv
        int nRet = WSARecv(cast(SOCKET) this.socket.handle, &_dataReadBuffer, 1u, &dwReceived, &dwFlags,
                &_iocpread.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);

        checkErro(nRet, SOCKET_ERROR);
    }

    /**
     * Bind the socket to the address from `createAddress` (ConnectEx requires
     * a bound socket), then post an overlapped ConnectEx to `addr`.
     */
    protected void doConnect(Address addr) {
        Address binded = createAddress(this.socket.addressFamily);
        this.socket.bind(binded);
        _iocpwrite.channel = this;
        _iocpwrite.operation = IocpOperation.connect;
        int nRet = ConnectEx(cast(SOCKET) this.socket.handle(),
                cast(SOCKADDR*) addr.name(), addr.nameLen(), null, 0, null,
                &_iocpwrite.overlapped);
        // ConnectEx returns a BOOL (FALSE on failure), not SOCKET_ERROR.
        // Comparing against SOCKET_ERROR meant a failed connect was never detected.
        checkErro(nRet, FALSE);
    }

    /**
     * Post an overlapped WSASend of `_dataWriteBuffer`.
     * Returns: the byte count WSASend reported synchronously (0 when pending).
     * Closes the channel if the send failed.
     */
    private uint doWrite() {
        _inWrite = true;
        DWORD dwFlags = 0;
        DWORD dwSent = 0;
        _iocpwrite.channel = this;
        _iocpwrite.operation = IocpOperation.write;
        version (HUNT_DEBUG) {
            size_t bufferLength = sendDataBuffer.length;
            tracef("To be written %d nbytes by handle[fd=%d]", bufferLength, this.socket.handle());
            if (bufferLength > 32)
                tracef("%(%02X %) ...", sendDataBuffer[0 .. 32]);
            else
                tracef("%(%02X %)", sendDataBuffer[0 .. $]);
        }

        int nRet = WSASend(cast(SOCKET) this.socket.handle(), &_dataWriteBuffer, 1, &dwSent,
                dwFlags, &_iocpwrite.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);

        // Check right away: the debug logging below could overwrite the
        // last-error value before checkErro reads it.
        checkErro(nRet, SOCKET_ERROR);

        version (HUNT_DEBUG) {
            if (dwSent != _dataWriteBuffer.len)
                warningf("dwSent=%d, BufferLength=%d", dwSent, _dataWriteBuffer.len);
        }

        if (this.isError) {
            errorf("Socket error on write: fd=%d, message=%s", this.handle, this.erroString);
            this.close();
        }

        return dwSent;
    }

    /**
     * Handle a completed read of `readLen` bytes.
     * A positive length is delivered to `dataReceivedHandler` and the next read
     * is posted; zero means the peer disconnected; anything else is an error.
     */
    protected void doRead() {
        this.clearError();
        version (HUNT_DEBUG)
            tracef("start reading: %d nbytes", this.readLen);

        if (readLen > 0) {
            if (dataReceivedHandler !is null)
                dataReceivedHandler(this._readBuffer[0 .. readLen]);
            version (HUNT_DEBUG)
                tracef("reading done: %d nbytes", this.readLen);

            // continue reading
            this.beginRead();
        } else if (readLen == 0) {
            version (HUNT_DEBUG) {
                if (_remoteAddress !is null)
                    warningf("connection broken: %s", _remoteAddress.toString());
            }
            onDisconnected();
        } else {
            // Record the error in every build; previously debug builds only logged it.
            version (HUNT_DEBUG)
                warningf("undefined behavior on thread %d", getTid());
            this._error = true;
            this._erroString = "undefined behavior on thread";
        }
    }

    /**
     * Start writing `data` unless a write is already in progress.
     * Returns: bytes sent synchronously, or 0 when busy or pending.
     */
    // TODO: created by Administrator @ 2018-4-18 10:15:20
    // Send a big block of data
    protected size_t tryWrite(const ubyte[] data) {
        if (_isWritting) {
            warning("Busy in writting on thread: ");
            return 0;
        }
        version (HUNT_DEBUG)
            trace("start to write");
        _isWritting = true;

        clearError();
        setWriteBuffer(data);
        size_t nBytes = doWrite();

        return nBytes;
    }

    /// Start writing the buffer at the front of `_writeQueue`, if any.
    protected void tryWrite() {
        if (_writeQueue.empty)
            return;

        version (HUNT_DEBUG)
            trace("start writting...");
        _isWritting = true;
        clearError();

        writeBuffer = _writeQueue.front();
        const(ubyte)[] data = writeBuffer.remaining();
        setWriteBuffer(data);
        size_t nBytes = doWrite();

        if (nBytes < data.length) { // to fix the corrupted data 
            version (HUNT_DEBUG)
                warningf("remaining data: %d / %d ", data.length - nBytes, data.length);
            sendDataBuffer = data.dup;
        }
    }

    /// Point the WSASend buffer descriptor at `data`.
    private void setWriteBuffer(in ubyte[] data) {
        version (HUNT_DEBUG)
            tracef("data length: %d nbytes", data.length);

        sendDataBuffer = data; //data[writeLen .. $]; // TODO: need more tests
        _dataWriteBuffer.buf = cast(char*) sendDataBuffer.ptr;
        _dataWriteBuffer.len = cast(uint) sendDataBuffer.length;
    }

    /**
     * Called by selector after data sent
     * Note: It's only for IOCP selector: 
    */
    void onWriteDone(size_t nBytes) {
        version (HUNT_DEBUG)
            tracef("finishing writting: %d bytes", nBytes);
        if (isWriteCancelling) {
            _isWritting = false;
            isWriteCancelling = false;
            _writeQueue.clear(); // clean the data buffer 
            return;
        }

        if (writeBuffer.pop(nBytes)) {
            if (_writeQueue.deQueue() is null) {
                version (HUNT_DEBUG)
                    warning("_writeQueue is empty!");
            }

            writeBuffer.finish();
            _isWritting = false;

            version (HUNT_DEBUG)
                tracef("writting done: %d bytes", nBytes);

            tryWrite();
        } else {
            tracef("remaining nbytes: %d", sendDataBuffer.length - nBytes);
            // FIXME: Needing refactor or cleanup -@Administrator at 2018-6-12 13:56:17
            // sendDataBuffer corrupted
            setWriteBuffer(sendDataBuffer[nBytes .. $]); // send remaining
            nBytes = doWrite();
        }
    }

    /// Request that the current write be dropped when it completes.
    void cancelWrite() {
        isWriteCancelling = true;
    }

    /// Mark the stream closed and notify `disconnectionHandler`.
    protected void onDisconnected() {
        _isConnected = false;
        _isClosed = true;
        if (disconnectionHandler !is null)
            disconnectionHandler();
    }

    bool _isConnected; //if server side always true.
    SimpleEventHandler disconnectionHandler;

    protected WriteBufferQueue _writeQueue;
    protected bool isWriteCancelling = false;
    private const(ubyte)[] _readBuffer;
    private const(ubyte)[] sendDataBuffer;
    private StreamWriteBuffer writeBuffer;

    private IocpContext _iocpread;
    private IocpContext _iocpwrite;

    private WSABUF _dataReadBuffer;
    private WSABUF _dataWriteBuffer;

    private bool _inWrite;
    private bool _inRead;
}
363 
/**
UDP Socket

Datagram socket driven by IOCP: `doRead` posts an overlapped WSARecvFrom, and
`tryRead` hands the completed datagram, with its sender address, to a callback.
*/
abstract class AbstractDatagramSocket : AbstractSocketChannel {
    /// Constructs a UDP socket of the given address family (IPv4 by default)
    /// and prepares an any-port bind address for that family.
    this(Selector loop, AddressFamily family = AddressFamily.INET) {
        super(loop, ChannelType.UDP);
        setFlag(ChannelFlag.Read, true);
        setFlag(ChannelFlag.ETMode, false);

        this.socket = new UdpSocket(family);
        _readBuffer = new UdpDataObject();
        _readBuffer.data = new ubyte[4096 * 2];

        if (family == AddressFamily.INET)
            _bindAddress = new InternetAddress(InternetAddress.PORT_ANY);
        else if (family == AddressFamily.INET6)
            _bindAddress = new Internet6Address(Internet6Address.PORT_ANY);
        else
            _bindAddress = new UnknownAddress();
    }

    /// Bind to `addr`; does nothing if the socket is already bound.
    final void bind(Address addr) {
        if (_binded)
            return;
        _bindAddress = addr;
        socket.bind(_bindAddress);
        _binded = true;
    }

    final bool isBind() {
        return _binded;
    }

    Address bindAddr() {
        return _bindAddress;
    }

    /// Bind to the current bind address if `bind` has not been called yet.
    override void start() {
        if (!_binded) {
            socket.bind(_bindAddress);
            _binded = true;
        }
    }

    private UdpDataObject _readBuffer;
    protected bool _binded = false;
    protected Address _bindAddress;

    version (HAVE_IOCP) {
        mixin CheckIocpError;

        /// Post an overlapped WSARecvFrom into `_readBuffer.data`; the sender
        /// address is written to `remoteAddr` / `remoteAddrLen`.
        void doRead() {
            version (HUNT_DEBUG)
                trace("Receiving......");

            _dataReadBuffer.len = cast(uint) _readBuffer.data.length;
            _dataReadBuffer.buf = cast(char*) _readBuffer.data.ptr;
            _iocpread.channel = this;
            _iocpread.operation = IocpOperation.read;
            // Offer the full storage size. remoteAddr must be able to hold any
            // sender address, including IPv6 (28 bytes).
            remoteAddrLen = cast(int) remoteAddr.sizeof;

            DWORD dwReceived = 0;
            DWORD dwFlags = 0;

            int nRet = WSARecvFrom(cast(SOCKET) this.handle, &_dataReadBuffer,
                    cast(uint) 1, &dwReceived, &dwFlags, cast(SOCKADDR*)&remoteAddr, &remoteAddrLen,
                    &_iocpread.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);
            checkErro(nRet, SOCKET_ERROR);
        }

        /// Convert `remoteAddr` into a std.socket Address, dispatching on the
        /// address family stored in it (not on its length).
        Address buildAddress() {
            immutable family = remoteAddr.ss_family;
            if (family == AddressFamily.INET)
                return new InternetAddress(*cast(sockaddr_in*)&remoteAddr);
            if (family == AddressFamily.INET6)
                return new Internet6Address(*cast(sockaddr_in6*)&remoteAddr);
            // Unsupported family: return a placeholder instead of misreading the bytes.
            return new UnknownAddress();
        }

        /**
         * Deliver the completed datagram to `read`.
         * A zero-length read calls `read(null)`. Otherwise the callback sees the
         * received slice and sender address, and the next receive is posted
         * while the channel is registered.
         * Returns: always false.
         */
        bool tryRead(scope ReadCallBack read) {
            this.clearError();
            if (this.readLen == 0) {
                read(null);
            } else {
                ubyte[] data = this._readBuffer.data;
                this._readBuffer.data = data[0 .. this.readLen];
                this._readBuffer.addr = this.buildAddress();
                scope (exit)
                    this._readBuffer.data = data;
                read(this._readBuffer);
                // Restore the full buffer before posting the next receive.
                this._readBuffer.data = data;
                if (this.isRegistered)
                    this.doRead();
            }
            return false;
        }

        IocpContext _iocpread;
        WSABUF _dataReadBuffer;

        /// Sender address of the last datagram. A plain `sockaddr` (16 bytes)
        /// is too small for IPv6 senders, so full `sockaddr_storage` is used.
        sockaddr_storage remoteAddr;
        int remoteAddrLen;
    }

}
475 
/**
Mixin that adds `checkErro(ret, erro)` to a channel class.

When a Winsock/IOCP call has returned its failure value `erro` and a non-zero
error code is pending, the channel's `_error` flag and `_erroString` message are
set. The exception is ERROR_IO_PENDING, which only means an overlapped operation
was queued.
*/
mixin template CheckIocpError() {
    /**
     * Params:
     *   ret  = value returned by the Winsock/IOCP call
     *   erro = the value that call returns on failure (e.g. SOCKET_ERROR or FALSE)
     */
    void checkErro(int ret, int erro = 0) {
        // Read the error code before anything else (logging included) can
        // overwrite it. WSAGetLastError is the documented accessor for Winsock
        // (and AcceptEx/ConnectEx) failures.
        DWORD dwLastError = cast(DWORD) WSAGetLastError();
        version (HUNT_DEBUG)
            infof("erro=%d, dwLastError=%d", erro, dwLastError);
        if (ret != erro || dwLastError == 0)
            return;

        if (ERROR_IO_PENDING != dwLastError) { // pending overlapped I/O is not an error
            import hunt.system.Error;
            warningf("erro=%d, dwLastError=%d", erro, dwLastError);
            this._error = true;
            this._erroString = getErrorMessage(dwLastError);
        }
    }
}
494 
/// Kind of overlapped operation recorded in an IocpContext
/// (members are numbered 0..5 in declaration order).
enum IocpOperation { accept, connect, read, write, event, close }
503 
/**
 * State attached to one posted overlapped operation. The address of
 * `overlapped` is what AcceptEx, ConnectEx, WSARecv, WSARecvFrom and WSASend
 * receive in this module.
 */
struct IocpContext {
    /// Kept as the first member. NOTE(review): presumably the selector casts a
    /// completed OVERLAPPED* back to IocpContext*; confirm before reordering.
    OVERLAPPED overlapped;
    /// Which operation was posted.
    IocpOperation operation;
    /// Channel that posted the operation (a class reference, so null by default).
    AbstractChannel channel;
}
509 
// Winsock's overlapped structure is the Win32 OVERLAPPED under another name.
alias WSAOVERLAPPED = OVERLAPPED;
alias LPWSAOVERLAPPED = WSAOVERLAPPED*;

// Winsock extension functions. loadWinsockExtension resolves them at runtime
// through GetFunctionPointer/WSAIoctl; until then they are null.
__gshared LPFN_ACCEPTEX AcceptEx;
__gshared LPFN_CONNECTEX ConnectEx;
/*__gshared LPFN_DISCONNECTEX DisconnectEx;
__gshared LPFN_GETACCEPTEXSOCKADDRS GetAcceptexSockAddrs;
__gshared LPFN_TRANSMITFILE TransmitFile;
__gshared LPFN_TRANSMITPACKETS TransmitPackets;
__gshared LPFN_WSARECVMSG WSARecvMsg;
__gshared LPFN_WSASENDMSG WSASendMsg;*/
521 
/// Set once WSAStartup has succeeded, so the module destructor calls
/// WSACleanup only for a successful startup.
private __gshared bool isWinsockStarted = false;

/// Initialise Winsock 2.2 for the process. WSAStartup returns its error code
/// directly, so that code is reported on failure.
shared static this() {
    WSADATA wsaData;
    int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (iResult != 0) {
        stderr.writeln("unable to load Winsock! error code: ", iResult);
        return;
    }
    isWinsockStarted = true;
}

/// Release Winsock, but only if startup succeeded; WSACleanup must be paired
/// with a successful WSAStartup.
shared static ~this() {
    if (isWinsockStarted) {
        WSACleanup();
        isWinsockStarted = false;
    }
}
533 
/**
 * Resolve the Winsock extension functions AcceptEx and ConnectEx into the
 * module-level pointers of the same names. Only the first successful call
 * does any work.
 *
 * Params:
 *   socket = an open socket, used as the handle for the WSAIoctl lookups.
 * Throws: ErrnoException (via errnoEnforce) if a pointer cannot be obtained.
 */
void loadWinsockExtension(SOCKET socket) {
    if (isApiLoaded)
        return;

    GUID guid; // assigned and read by the code that GET_FUNC_POINTER generates
    mixin(GET_FUNC_POINTER("WSAID_ACCEPTEX", "AcceptEx"));
    mixin(GET_FUNC_POINTER("WSAID_CONNECTEX", "ConnectEx"));
    /* mixin(GET_FUNC_POINTER("WSAID_DISCONNECTEX", "DisconnectEx"));
     mixin(GET_FUNC_POINTER("WSAID_GETACCEPTEXSOCKADDRS", "GetAcceptexSockAddrs"));
     mixin(GET_FUNC_POINTER("WSAID_TRANSMITFILE", "TransmitFile"));
     mixin(GET_FUNC_POINTER("WSAID_TRANSMITPACKETS", "TransmitPackets"));
     mixin(GET_FUNC_POINTER("WSAID_WSARECVMSG", "WSARecvMsg"));*/

    // Mark as loaded only after every lookup succeeded. If one throws above,
    // the next call retries instead of leaving null function pointers behind.
    isApiLoaded = true;
}

private __gshared bool isApiLoaded = false;
553 
/**
 * Load the Winsock extension function identified by `guid` into `pfn`, using
 * WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) on `sock`.
 *
 * Returns: true on success. On failure, logs the Winsock error code and
 * returns false.
 */
private bool GetFunctionPointer(FuncPointer)(SOCKET sock, ref FuncPointer pfn, ref GUID guid) {
    DWORD dwBytesReturned = 0;
    immutable int ret = WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, guid.sizeof,
            &pfn, pfn.sizeof, &dwBytesReturned, null, null);
    if (ret == SOCKET_ERROR) {
        // Capture the Winsock error (WSAGetLastError) before logging can change it.
        immutable int errorCode = WSAGetLastError();
        error("Get function failed with error:", errorCode);
        return false;
    }

    return true;
}
564 
/**
 * Generate (usually at compile time, for `mixin`) the statements that load one
 * Winsock extension function pointer.
 *
 * The generated code assigns `GuidValue` to a variable `guid` that must exist
 * in the mixing scope. It then calls `GetFunctionPointer(<socket>, <pft>, guid)`
 * and calls `errnoEnforce(false, ...)` if that returns false.
 *
 * Params:
 *   GuidValue = source text of the GUID expression, e.g. "WSAID_ACCEPTEX"
 *   pft       = name of the function-pointer variable to fill
 *   socket    = name of the socket variable to query (default "socket")
 * Returns: the generated D statements.
 */
private string GET_FUNC_POINTER(string GuidValue, string pft, string socket = "socket") {
    return " guid = " ~ GuidValue ~ ";"
        ~ "if( !GetFunctionPointer( " ~ socket ~ ", " ~ pft
        ~ ", guid ) ) { errnoEnforce(false,\"get function error!\"); } ";
}
571 
// Winsock ioctl code building blocks, named after the Windows SDK macros of
// the same names. `uint` is the type DWORD aliases.
enum : uint {
    IOCPARAM_MASK = 0x7f,
    IOC_VOID = 0x20000000, // no data transfer
    IOC_OUT = 0x40000000,  // data returned to the caller
    IOC_IN = 0x80000000,   // data supplied by the caller
    IOC_INOUT = IOC_IN | IOC_OUT
}

// Ioctl "owner" field values.
enum IOC_UNIX = 0x00000000;
enum IOC_WS2 = 0x08000000;
enum IOC_PROTOCOL = 0x10000000;
enum IOC_VENDOR = 0x18000000;

// Compose an ioctl code from a direction flag, an owner field (x) and a command number (y).
enum _WSAIO(int x, int y) = IOC_VOID | x | y;
enum _WSAIOR(int x, int y) = IOC_OUT | x | y;
enum _WSAIOW(int x, int y) = IOC_IN | x | y;
enum _WSAIORW(int x, int y) = IOC_INOUT | x | y;

// Winsock 2 ioctl command codes (owner IOC_WS2).
enum SIO_ASSOCIATE_HANDLE = _WSAIOW!(IOC_WS2, 1);
enum SIO_ENABLE_CIRCULAR_QUEUEING = _WSAIO!(IOC_WS2, 2);
enum SIO_FIND_ROUTE = _WSAIOR!(IOC_WS2, 3);
enum SIO_FLUSH = _WSAIO!(IOC_WS2, 4);
enum SIO_GET_BROADCAST_ADDRESS = _WSAIOR!(IOC_WS2, 5);
// Used by GetFunctionPointer to look up AcceptEx/ConnectEx.
enum SIO_GET_EXTENSION_FUNCTION_POINTER = _WSAIORW!(IOC_WS2, 6);
enum SIO_GET_QOS = _WSAIORW!(IOC_WS2, 7);
enum SIO_GET_GROUP_QOS = _WSAIORW!(IOC_WS2, 8);
enum SIO_MULTIPOINT_LOOPBACK = _WSAIOW!(IOC_WS2, 9);
enum SIO_MULTICAST_SCOPE = _WSAIOW!(IOC_WS2, 10);
enum SIO_SET_QOS = _WSAIOW!(IOC_WS2, 11);
enum SIO_SET_GROUP_QOS = _WSAIOW!(IOC_WS2, 12);
enum SIO_TRANSLATE_HANDLE = _WSAIORW!(IOC_WS2, 13);
enum SIO_ROUTING_INTERFACE_QUERY = _WSAIORW!(IOC_WS2, 20);
enum SIO_ROUTING_INTERFACE_CHANGE = _WSAIOW!(IOC_WS2, 21);
enum SIO_ADDRESS_LIST_QUERY = _WSAIOR!(IOC_WS2, 22);
enum SIO_ADDRESS_LIST_CHANGE = _WSAIO!(IOC_WS2, 23);
enum SIO_QUERY_TARGET_PNP_HANDLE = _WSAIOR!(IOC_WS2, 24);
enum SIO_NSP_NOTIFY_CHANGE = _WSAIOW!(IOC_WS2, 25);
620 
// Winsock 2 prototypes declared locally; Ws2_32 is linked by the pragma(lib)
// at the top of this module. The `extern (Windows):` and `nothrow:` labels
// apply to every declaration that follows them in this module.
extern (Windows):
nothrow:
// Overlapped receive on a connected socket (posted by AbstractStream.beginRead).
int WSARecv(SOCKET, LPWSABUF, DWORD, LPDWORD, LPDWORD, LPWSAOVERLAPPED,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE);
// Not called in the code visible here; see the Microsoft Winsock reference.
int WSARecvDisconnect(SOCKET, LPWSABUF);
// Overlapped datagram receive that also reports the sender address
// (posted by AbstractDatagramSocket.doRead).
int WSARecvFrom(SOCKET, LPWSABUF, DWORD, LPDWORD, LPDWORD, SOCKADDR*, LPINT,
        LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE);

// Overlapped send on a connected socket (posted by AbstractStream.doWrite).
int WSASend(SOCKET, LPWSABUF, DWORD, LPDWORD, DWORD, LPWSAOVERLAPPED,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE);
// Not called in the code visible here; see the Microsoft Winsock reference.
int WSASendDisconnect(SOCKET, LPWSABUF);
// Not called in the code visible here; datagram send to an explicit address.
int WSASendTo(SOCKET, LPWSABUF, DWORD, LPDWORD, DWORD, const(SOCKADDR)*, int,
        LPWSAOVERLAPPED, LPWSAOVERLAPPED_COMPLETION_ROUTINE);