commit f10bf9bf69a68cac25e77468607f3c7dcce17d53
parent b0a8fd65be5ea2d7c1024d06587c513cb952dfe7
Author: sin <sin@2f30.org>
Date: Thu, 31 Mar 2016 14:37:12 +0100
more robust handling of corrupt packets
Diffstat:
M | stun.c | | | 155 | +++++++++++++++++++++++++++++++++++-------------------------------------------- |
1 file changed, 69 insertions(+), 86 deletions(-)
diff --git a/stun.c b/stun.c
@@ -27,7 +27,7 @@
* All tunneled traffic is encapsulated inside the TCP payload.
* The packet format is shown below:
*
- * [PAYLOAD LENGTH] [IV] [PAYLOAD] [TAG]
+ * [TAG] [IV] [PAYLOAD LENGTH] [PAYLOAD]
*
* Where payload length is 2 octets, IV is 12 octets and tag is 16 octects.
*/
@@ -77,14 +77,13 @@
#endif
#define NOPRIVUSER "nobody"
-#define CHALLENGETIMEO 1 /* in seconds */
-#define RECONNECTTIMEO 60 /* in seconds */
+#define RCVTIMEO 250 /* in milliseconds */
+#define RECONNECTTIMEO 60 /* in seconds */
#define MTU 1412
#define HDRLEN 2
#define IVLEN 12
#define TAGLEN 16
#define MAXPKTLEN (MTU + AES_BLOCK_SIZE + HDRLEN + IVLEN + TAGLEN)
-#define PKTLENMASK 0xfff
#define BADPKT 0x8000
enum {
@@ -223,33 +222,26 @@ readall(int fd, void *buf, int len)
return total;
}
+void
+ms2tv(struct timeval *tv, long ms)
+{
+ tv->tv_sec = ms / 1000;
+ tv->tv_usec = (ms % 1000) * 1000;
+}
+
int
-setrcvtimeo(int fd, time_t sec)
+setrcvtimeo(int fd, long ms)
{
struct timeval tv;
int ret;
- tv.tv_sec = sec;
- tv.tv_usec = 0;
+ ms2tv(&tv, ms);
ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
if (ret < 0)
logwarn("failed to set timeout on socket");
return ret;
}
-int
-setnonblock(int fd, int mode)
-{
- int flags;
-
- flags = fcntl(fd, F_GETFL);
- if (mode)
- flags |= O_NONBLOCK;
- else
- flags &= ~O_NONBLOCK;
- return fcntl(fd, F_SETFL, flags);
-}
-
void
revokeprivs(void)
{
@@ -284,7 +276,8 @@ aesinit(EVP_CIPHER_CTX *ectx, EVP_CIPHER_CTX *dctx)
int
aesenc(EVP_CIPHER_CTX *ctx, unsigned char *ct, unsigned char *pt, int plen,
- unsigned char *key, unsigned char *iv, unsigned char *tag, int taglen)
+ unsigned char *key, unsigned char *iv, unsigned char *aad, int aadlen,
+ unsigned char *tag, int taglen)
{
int clen, flen;
@@ -294,6 +287,9 @@ aesenc(EVP_CIPHER_CTX *ctx, unsigned char *ct, unsigned char *pt, int plen,
if (EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv) != 1)
logerr("EVP_EncryptInit_ex failed");
+ if (EVP_EncryptUpdate(ctx, NULL, &clen, aad, aadlen) != 1)
+ logerr("EVP_EncryptUpdate failed");
+
if (EVP_EncryptUpdate(ctx, ct, &clen, pt, plen) != 1)
logerr("EVP_EncryptUpdate failed");
@@ -308,7 +304,8 @@ aesenc(EVP_CIPHER_CTX *ctx, unsigned char *ct, unsigned char *pt, int plen,
int
aesdec(EVP_CIPHER_CTX *ctx, unsigned char *pt, unsigned char *ct, int clen,
- unsigned char *key, unsigned char *iv, unsigned char *tag, int taglen)
+ unsigned char *key, unsigned char *iv, unsigned char *aad, int aadlen,
+ unsigned char *tag, int taglen)
{
int plen, flen;
@@ -318,6 +315,9 @@ aesdec(EVP_CIPHER_CTX *ctx, unsigned char *pt, unsigned char *ct, int clen,
if (EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv) != 1)
logerr("EVP_DecryptInit_ex failed");
+ if (EVP_DecryptUpdate(ctx, NULL, &plen, aad, aadlen) != 1)
+ logerr("EVP_DecryptUpdate failed");
+
if (EVP_DecryptUpdate(ctx, pt, &plen, ct, clen) != 1)
logerr("EVP_DecryptUpdate failed");
@@ -481,14 +481,14 @@ writenet(int fd, unsigned char *pt, int len)
unsigned char hdr[HDRLEN], iv[IVLEN], tag[TAGLEN];
unsigned char pkt[MAXPKTLEN];
- arc4random_buf(iv, IVLEN);
- len = aesenc(&ectx, payload, pt, len, aeskey, iv, tag, TAGLEN);
pack16(hdr, len);
- memcpy(pkt, hdr, HDRLEN);
- memcpy(&pkt[HDRLEN], iv, IVLEN);
- memcpy(&pkt[HDRLEN + IVLEN], payload, len);
- memcpy(&pkt[HDRLEN + IVLEN + len], tag, TAGLEN);
- len += IVLEN + HDRLEN + TAGLEN;
+ arc4random_buf(iv, IVLEN);
+ aesenc(&ectx, payload, pt, len, aeskey, iv, hdr, HDRLEN, tag, TAGLEN);
+ memcpy(pkt, tag, TAGLEN);
+ memcpy(&pkt[TAGLEN], iv, IVLEN);
+ memcpy(&pkt[TAGLEN + IVLEN], hdr, HDRLEN);
+ memcpy(&pkt[TAGLEN + IVLEN + HDRLEN], payload, len);
+ len += TAGLEN + IVLEN + HDRLEN;
return writeall(fd, pkt, len);
}
@@ -499,43 +499,34 @@ readnet(int fd, unsigned char *pt, int len)
unsigned char hdr[HDRLEN], iv[IVLEN], tag[TAGLEN];
int n, pktlen;
- n = readall(fd, hdr, HDRLEN);
- if (n <= 0)
- return n;
+#define CHECKERR(n) do { \
+ if ((n) == 0) { \
+ return 0; \
+ } else if ((n) < 0) { \
+ if (errno != EWOULDBLOCK) \
+ return -1; \
+ return BADPKT; \
+ } \
+} while (0)
- pktlen = unpack16(hdr);
- pktlen &= PKTLENMASK;
-
- /* discard bad packets */
- if (pktlen > sizeof(payload)) {
- setnonblock(fd, 1);
- while (pktlen) {
- if (pktlen > sizeof(payload))
- len = sizeof(payload);
- else
- len = pktlen;
- n = readall(fd, payload, len);
- if (n <= 0)
- break;
- pktlen -= n;
- }
- setnonblock(fd, 0);
- return BADPKT;
- }
+ n = readall(fd, tag, TAGLEN);
+ CHECKERR(n);
n = readall(fd, iv, IVLEN);
- if (n <= 0)
- return n;
+ CHECKERR(n);
- n = readall(fd, payload, pktlen);
- if (n <= 0)
- return n;
+ n = readall(fd, hdr, HDRLEN);
+ CHECKERR(n);
- n = readall(fd, tag, TAGLEN);
- if (n <= 0)
- return n;
+ pktlen = unpack16(hdr);
+ if (pktlen > sizeof(payload))
+ pktlen = sizeof(payload);
+
+ n = readall(fd, payload, pktlen);
+ CHECKERR(n);
- pktlen = aesdec(&dctx, pt, payload, pktlen, aeskey, iv, tag, TAGLEN);
+ pktlen = aesdec(&dctx, pt, payload, pktlen, aeskey, iv, hdr, HDRLEN,
+ tag, TAGLEN);
if (pktlen < 0)
return BADPKT;
return pktlen;
@@ -549,37 +540,27 @@ challenge(int netfd)
uint64_t n, reply;
int ret;
- ret = setrcvtimeo(netfd, CHALLENGETIMEO);
- if (ret < 0)
- return -1;
arc4random_buf(&n, sizeof(uint64_t));
pack64(buf, n);
if (writenet(netfd, buf, sizeof(uint64_t)) <= 0)
- goto err;
- pfd[0].fd = netfd;
- pfd[0].events = POLLIN;
- ret = poll(pfd, 1, CHALLENGETIMEO * 1000);
- switch (ret) {
- case -1:
- logwarn("poll failed");
- goto err;
- case 0:
- logwarn("challenge-response timed out");
- goto err;
- default:
+ return -1;
+ for (;;) {
+ pfd[0].fd = netfd;
+ pfd[0].events = POLLIN;
+ ret = poll(pfd, 1, -1);
+ if (ret < 0) {
+ logwarn("poll failed");
+ return -1;
+ }
if (pfd[0].revents & (POLLIN | POLLHUP)) {
ret = readnet(netfd, buf, sizeof(uint64_t));
if (ret <= 0 || ret == BADPKT)
- goto err;
+ return -1;
reply = unpack64(buf);
- if (n + 1 == reply) {
- setrcvtimeo(netfd, 0);
+ if (n + 1 == reply)
return 0;
- }
}
}
-err:
- setrcvtimeo(netfd, 0);
return -1;
}
@@ -623,20 +604,20 @@ tunnel(int netfd, int devfd)
if (pfd[0].revents & (POLLIN | POLLHUP)) {
n = readnet(netfd, buf, sizeof(buf));
if (n <= 0)
- return 1;
+ return -1;
if (n == BADPKT) {
logwarn("bad packet");
} else {
n = writedev(devfd, buf, n);
if (n <= 0)
- return 1;
+ return -1;
}
}
if (pfd[1].revents & (POLLIN | POLLHUP))
if ((n = readdev(devfd, buf, MTU)) <= 0 ||
(n = writenet(netfd, buf, n)) <= 0)
- return 1;
+ return -1;
}
return 0;
}
@@ -688,6 +669,7 @@ serversetup(int devfd)
logdbg("remote peer connected: %s",
inet_ntoa(remote.sin_addr));
+ setrcvtimeo(netfd, RCVTIMEO);
setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int));
ret = challenge(netfd);
@@ -742,6 +724,7 @@ clientsetup(int devfd)
return -1;
}
+ setrcvtimeo(netfd, RCVTIMEO);
setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int));
if (debug)
@@ -759,11 +742,11 @@ clientsetup(int devfd)
goto err;
}
- ret = tunnel(netfd, devfd);
+ tunnel(netfd, devfd);
err:
logwarn("connection to %s:%s dropped", host, port);
close(netfd);
- return ret;
+ return -1;
}
void