commit eec19fac6675df594355c653abf331a6b17c8057
parent 777ef1489c602ac28cc6a068b2a7a48f77126ce1
Author: sin <sin@2f30.org>
Date: Tue, 29 Mar 2016 10:11:13 +0100
frail attempt at handling "bad packets"
Diffstat:
M | stun.c | | | 49 | +++++++++++++++++++++++++++++++++++++++++-------- |
1 file changed, 41 insertions(+), 8 deletions(-)
diff --git a/stun.c b/stun.c
@@ -46,6 +46,8 @@
#define CHALLENGETIMEO 1 /* in seconds */
#define RECONNECTTIMEO 60 /* in seconds */
#define HDRLEN 2
+#define PKTLENMASK 0xfff
+#define BADPKT 0x8000
#define MTU 1440
enum {
@@ -115,7 +117,7 @@ logerr(char *msg, ...)
int
padto(int len, int blocksize)
{
- return len + blocksize - (len & (blocksize - 1));
+ return len + blocksize - len % blocksize;
}
int
@@ -410,19 +412,36 @@ readnet(int fd, unsigned char *buf, int len)
unsigned char hdr[HDRLEN];
int n, pktlen, paddedlen;
+ /* read unpadded packet length */
n = readall(fd, hdr, sizeof(hdr));
if (n <= 0)
return n;
+
pktlen = unpack16(hdr);
+ pktlen &= PKTLENMASK;
paddedlen = padto(pktlen, AES_BLOCK_SIZE);
+
+ /* attempt to drain an invalid packet */
if (pktlen < 0 || pktlen > len ||
paddedlen < 0 || paddedlen > MTU + AES_BLOCK_SIZE) {
- logwarn("bogus payload length: %d", paddedlen);
- return -1;
+ while (paddedlen) {
+ if (paddedlen > MTU + AES_BLOCK_SIZE)
+ len = MTU + AES_BLOCK_SIZE;
+ else
+ len = paddedlen;
+ n = readall(fd, encbuf, len);
+ if (n <= 0)
+ break;
+ paddedlen -= n;
+ }
+ return BADPKT;
}
+
+ /* read encrypted payload */
n = readall(fd, encbuf, paddedlen);
if (n <= 0)
return n;
+
aesdec(&dec, decbuf, encbuf, paddedlen);
memcpy(buf, decbuf, pktlen);
return pktlen;
@@ -468,7 +487,8 @@ challenge(int netfd)
goto err;
default:
if (pfd[0].revents & (POLLIN | POLLHUP)) {
- if (readnet(netfd, buf, sizeof(buf)) <= 0)
+ ret = readnet(netfd, buf, sizeof(buf));
+ if (ret <= 0 || ret == BADPKT)
goto err;
reply = unpack32(buf);
if (n + 1 == reply) {
@@ -487,8 +507,10 @@ response(int netfd)
{
unsigned char buf[sizeof(uint32_t)];
uint32_t reply;
+ int ret;
- if (readnet(netfd, buf, sizeof(buf)) <= 0)
+ ret = readnet(netfd, buf, sizeof(buf));
+ if (ret <= 0 || ret == BADPKT)
return -1;
reply = unpack32(buf);
pack32(buf, reply + 1);
@@ -512,13 +534,24 @@ tunnel(int netfd, int devfd)
ret = poll(pfd, 2, -1);
if (ret < 0)
logerr("poll failed");
+
if (pfd[0].revents & (POLLERR | POLLNVAL) ||
pfd[1].revents & (POLLERR | POLLNVAL))
logerr("bad fd in poll set");
- if (pfd[0].revents & (POLLIN | POLLHUP))
- if ((n = readnet(netfd, buf, MTU)) <= 0 ||
- (n = writedev(devfd, buf, n)) <= 0)
+
+ if (pfd[0].revents & (POLLIN | POLLHUP)) {
+ n = readnet(netfd, buf, MTU);
+ if (n <= 0)
return 1;
+ if (n == BADPKT) {
+ logwarn("bad packet");
+ } else {
+ n = writedev(devfd, buf, n);
+ if (n <= 0)
+ return 1;
+ }
+ }
+
if (pfd[1].revents & (POLLIN | POLLHUP))
if ((n = readdev(devfd, buf, MTU)) <= 0 ||
(n = writenet(netfd, buf, n)) <= 0)