warp-vpn

point to point VPN implementation
git clone git://git.2f30.org/warp-vpn
Log | Files | Refs | README

commit f10bf9bf69a68cac25e77468607f3c7dcce17d53
parent b0a8fd65be5ea2d7c1024d06587c513cb952dfe7
Author: sin <sin@2f30.org>
Date:   Thu, 31 Mar 2016 14:37:12 +0100

more robust handling of corrupt packets

Diffstat:
Mstun.c | 155+++++++++++++++++++++++++++++++++++--------------------------------------------
1 file changed, 69 insertions(+), 86 deletions(-)

diff --git a/stun.c b/stun.c @@ -27,7 +27,7 @@ * All tunneled traffic is encapsulated inside the TCP payload. * The packet format is shown below: * - * [PAYLOAD LENGTH] [IV] [PAYLOAD] [TAG] + * [TAG] [IV] [PAYLOAD LENGTH] [PAYLOAD] * * Where payload length is 2 octets, IV is 12 octets and tag is 16 octects. */ @@ -77,14 +77,13 @@ #endif #define NOPRIVUSER "nobody" -#define CHALLENGETIMEO 1 /* in seconds */ -#define RECONNECTTIMEO 60 /* in seconds */ +#define RCVTIMEO 250 /* in milliseconds */ +#define RECONNECTTIMEO 60 /* in seconds */ #define MTU 1412 #define HDRLEN 2 #define IVLEN 12 #define TAGLEN 16 #define MAXPKTLEN (MTU + AES_BLOCK_SIZE + HDRLEN + IVLEN + TAGLEN) -#define PKTLENMASK 0xfff #define BADPKT 0x8000 enum { @@ -223,33 +222,26 @@ readall(int fd, void *buf, int len) return total; } +void +ms2tv(struct timeval *tv, long ms) +{ + tv->tv_sec = ms / 1000; + tv->tv_usec = (ms % 1000) * 1000; +} + int -setrcvtimeo(int fd, time_t sec) +setrcvtimeo(int fd, long ms) { struct timeval tv; int ret; - tv.tv_sec = sec; - tv.tv_usec = 0; + ms2tv(&tv, ms); ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); if (ret < 0) logwarn("failed to set timeout on socket"); return ret; } -int -setnonblock(int fd, int mode) -{ - int flags; - - flags = fcntl(fd, F_GETFL); - if (mode) - flags |= O_NONBLOCK; - else - flags &= ~O_NONBLOCK; - return fcntl(fd, F_SETFL, flags); -} - void revokeprivs(void) { @@ -284,7 +276,8 @@ aesinit(EVP_CIPHER_CTX *ectx, EVP_CIPHER_CTX *dctx) int aesenc(EVP_CIPHER_CTX *ctx, unsigned char *ct, unsigned char *pt, int plen, - unsigned char *key, unsigned char *iv, unsigned char *tag, int taglen) + unsigned char *key, unsigned char *iv, unsigned char *aad, int aadlen, + unsigned char *tag, int taglen) { int clen, flen; @@ -294,6 +287,9 @@ aesenc(EVP_CIPHER_CTX *ctx, unsigned char *ct, unsigned char *pt, int plen, if (EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv) != 1) logerr("EVP_EncryptInit_ex failed"); + if (EVP_EncryptUpdate(ctx, NULL, &clen, aad, aadlen) != 1) + logerr("EVP_EncryptUpdate failed"); + if (EVP_EncryptUpdate(ctx, ct, &clen, pt, plen) != 1) logerr("EVP_EncryptUpdate failed"); @@ -308,7 +304,8 @@ aesenc(EVP_CIPHER_CTX *ctx, unsigned char *ct, unsigned char *pt, int plen, int aesdec(EVP_CIPHER_CTX *ctx, unsigned char *pt, unsigned char *ct, int clen, - unsigned char *key, unsigned char *iv, unsigned char *tag, int taglen) + unsigned char *key, unsigned char *iv, unsigned char *aad, int aadlen, + unsigned char *tag, int taglen) { int plen, flen; @@ -318,6 +315,9 @@ aesdec(EVP_CIPHER_CTX *ctx, unsigned char *pt, unsigned char *ct, int clen, if (EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv) != 1) logerr("EVP_DecryptInit_ex failed"); + if (EVP_DecryptUpdate(ctx, NULL, &plen, aad, aadlen) != 1) + logerr("EVP_DecryptUpdate failed"); + if (EVP_DecryptUpdate(ctx, pt, &plen, ct, clen) != 1) logerr("EVP_DecryptUpdate failed"); @@ -481,14 +481,14 @@ writenet(int fd, unsigned char *pt, int len) unsigned char hdr[HDRLEN], iv[IVLEN], tag[TAGLEN]; unsigned char pkt[MAXPKTLEN]; - arc4random_buf(iv, IVLEN); - len = aesenc(&ectx, payload, pt, len, aeskey, iv, tag, TAGLEN); pack16(hdr, len); - memcpy(pkt, hdr, HDRLEN); - memcpy(&pkt[HDRLEN], iv, IVLEN); - memcpy(&pkt[HDRLEN + IVLEN], payload, len); - memcpy(&pkt[HDRLEN + IVLEN + len], tag, TAGLEN); - len += IVLEN + HDRLEN + TAGLEN; + arc4random_buf(iv, IVLEN); + aesenc(&ectx, payload, pt, len, aeskey, iv, hdr, HDRLEN, tag, TAGLEN); + memcpy(pkt, tag, TAGLEN); + memcpy(&pkt[TAGLEN], iv, IVLEN); + memcpy(&pkt[TAGLEN + IVLEN], hdr, HDRLEN); + memcpy(&pkt[TAGLEN + IVLEN + HDRLEN], payload, len); + len += TAGLEN + IVLEN + HDRLEN; return writeall(fd, pkt, len); } @@ -499,43 +499,34 @@ readnet(int fd, unsigned char *pt, int len) unsigned char hdr[HDRLEN], iv[IVLEN], tag[TAGLEN]; int n, pktlen; - n = readall(fd, hdr, HDRLEN); - if (n <= 0) - return n; +#define CHECKERR(n) do { \ + if ((n) == 0) { \ + return 0; \ + } else if ((n) < 0) { \ + if (errno != EWOULDBLOCK) \ + return -1; \ + return BADPKT; \ + } \ +} while (0) - pktlen = unpack16(hdr); - pktlen &= PKTLENMASK; - - /* discard bad packets */ - if (pktlen > sizeof(payload)) { - setnonblock(fd, 1); - while (pktlen) { - if (pktlen > sizeof(payload)) - len = sizeof(payload); - else - len = pktlen; - n = readall(fd, payload, len); - if (n <= 0) - break; - pktlen -= n; - } - setnonblock(fd, 0); - return BADPKT; - } + n = readall(fd, tag, TAGLEN); + CHECKERR(n); n = readall(fd, iv, IVLEN); - if (n <= 0) - return n; + CHECKERR(n); - n = readall(fd, payload, pktlen); - if (n <= 0) - return n; + n = readall(fd, hdr, HDRLEN); + CHECKERR(n); - n = readall(fd, tag, TAGLEN); - if (n <= 0) - return n; + pktlen = unpack16(hdr); + if (pktlen > sizeof(payload)) + pktlen = sizeof(payload); + + n = readall(fd, payload, pktlen); + CHECKERR(n); - pktlen = aesdec(&dctx, pt, payload, pktlen, aeskey, iv, tag, TAGLEN); + pktlen = aesdec(&dctx, pt, payload, pktlen, aeskey, iv, hdr, HDRLEN, + tag, TAGLEN); if (pktlen < 0) return BADPKT; return pktlen; @@ -549,37 +540,27 @@ challenge(int netfd) uint64_t n, reply; int ret; - ret = setrcvtimeo(netfd, CHALLENGETIMEO); - if (ret < 0) - return -1; arc4random_buf(&n, sizeof(uint64_t)); pack64(buf, n); if (writenet(netfd, buf, sizeof(uint64_t)) <= 0) - goto err; - pfd[0].fd = netfd; - pfd[0].events = POLLIN; - ret = poll(pfd, 1, CHALLENGETIMEO * 1000); - switch (ret) { - case -1: - logwarn("poll failed"); - goto err; - case 0: - logwarn("challenge-response timed out"); - goto err; - default: + return -1; + for (;;) { + pfd[0].fd = netfd; + pfd[0].events = POLLIN; + ret = poll(pfd, 1, -1); + if (ret < 0) { + logwarn("poll failed"); + return -1; + } if (pfd[0].revents & (POLLIN | POLLHUP)) { ret = readnet(netfd, buf, sizeof(uint64_t)); if (ret <= 0 || ret == BADPKT) - goto err; + return -1; reply = unpack64(buf); - if (n + 1 == reply) { - setrcvtimeo(netfd, 0); + if (n + 1 == reply) return 0; - } } } -err: - setrcvtimeo(netfd, 0); return -1; } @@ -623,20 +604,20 @@ tunnel(int netfd, int devfd) if (pfd[0].revents & (POLLIN | POLLHUP)) { n = readnet(netfd, buf, sizeof(buf)); if (n <= 0) - return 1; + return -1; if (n == BADPKT) { logwarn("bad packet"); } else { n = writedev(devfd, buf, n); if (n <= 0) - return 1; + return -1; } } if (pfd[1].revents & (POLLIN | POLLHUP)) if ((n = readdev(devfd, buf, MTU)) <= 0 || (n = writenet(netfd, buf, n)) <= 0) - return 1; + return -1; } return 0; } @@ -688,6 +669,7 @@ serversetup(int devfd) logdbg("remote peer connected: %s", inet_ntoa(remote.sin_addr)); + setrcvtimeo(netfd, RCVTIMEO); setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int)); ret = challenge(netfd); @@ -742,6 +724,7 @@ clientsetup(int devfd) return -1; } + setrcvtimeo(netfd, RCVTIMEO); setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int)); if (debug) @@ -759,11 +742,11 @@ clientsetup(int devfd) goto err; } - ret = tunnel(netfd, devfd); + tunnel(netfd, devfd); err: logwarn("connection to %s:%s dropped", host, port); close(netfd); - return ret; + return -1; } void