warp-vpn

point to point VPN implementation
git clone git://git.2f30.org/warp-vpn
Log | Files | Refs | README

commit eec19fac6675df594355c653abf331a6b17c8057
parent 777ef1489c602ac28cc6a068b2a7a48f77126ce1
Author: sin <sin@2f30.org>
Date:   Tue, 29 Mar 2016 10:11:13 +0100

frail attempt at handling "bad packets"

Diffstat:
Mstun.c | 49+++++++++++++++++++++++++++++++++++++++++--------
1 file changed, 41 insertions(+), 8 deletions(-)

diff --git a/stun.c b/stun.c @@ -46,6 +46,8 @@ #define CHALLENGETIMEO 1 /* in seconds */ #define RECONNECTTIMEO 60 /* in seconds */ #define HDRLEN 2 +#define PKTLENMASK 0xfff +#define BADPKT 0x8000 #define MTU 1440 enum { @@ -115,7 +117,7 @@ logerr(char *msg, ...) int padto(int len, int blocksize) { - return len + blocksize - (len & (blocksize - 1)); + return len + blocksize - len % blocksize; } int @@ -410,19 +412,36 @@ readnet(int fd, unsigned char *buf, int len) unsigned char hdr[HDRLEN]; int n, pktlen, paddedlen; + /* read unpadded packet length */ n = readall(fd, hdr, sizeof(hdr)); if (n <= 0) return n; + pktlen = unpack16(hdr); + pktlen &= PKTLENMASK; paddedlen = padto(pktlen, AES_BLOCK_SIZE); + + /* attempt to drain an invalid packet */ if (pktlen < 0 || pktlen > len || paddedlen < 0 || paddedlen > MTU + AES_BLOCK_SIZE) { - logwarn("bogus payload length: %d", paddedlen); - return -1; + while (paddedlen) { + if (paddedlen > MTU + AES_BLOCK_SIZE) + len = MTU + AES_BLOCK_SIZE; + else + len = paddedlen; + n = readall(fd, encbuf, len); + if (n <= 0) + break; + paddedlen -= n; + } + return BADPKT; } + + /* read encrypted payload */ n = readall(fd, encbuf, paddedlen); if (n <= 0) return n; + aesdec(&dec, decbuf, encbuf, paddedlen); memcpy(buf, decbuf, pktlen); return pktlen; @@ -468,7 +487,8 @@ challenge(int netfd) goto err; default: if (pfd[0].revents & (POLLIN | POLLHUP)) { - if (readnet(netfd, buf, sizeof(buf)) <= 0) + ret = readnet(netfd, buf, sizeof(buf)); + if (ret <= 0 || ret == BADPKT) goto err; reply = unpack32(buf); if (n + 1 == reply) { @@ -487,8 +507,10 @@ response(int netfd) { unsigned char buf[sizeof(uint32_t)]; uint32_t reply; + int ret; - if (readnet(netfd, buf, sizeof(buf)) <= 0) + ret = readnet(netfd, buf, sizeof(buf)); + if (ret <= 0 || ret == BADPKT) return -1; reply = unpack32(buf); pack32(buf, reply + 1); @@ -512,13 +534,24 @@ tunnel(int netfd, int devfd) ret = poll(pfd, 2, -1); if (ret < 0) logerr("poll failed"); + if (pfd[0].revents & (POLLERR | POLLNVAL) || pfd[1].revents & (POLLERR | POLLNVAL)) logerr("bad fd in poll set"); - if (pfd[0].revents & (POLLIN | POLLHUP)) - if ((n = readnet(netfd, buf, MTU)) <= 0 || - (n = writedev(devfd, buf, n)) <= 0) + + if (pfd[0].revents & (POLLIN | POLLHUP)) { + n = readnet(netfd, buf, MTU); + if (n <= 0) return 1; + if (n == BADPKT) { + logwarn("bad packet"); + } else { + n = writedev(devfd, buf, n); + if (n <= 0) + return 1; + } + } + if (pfd[1].revents & (POLLIN | POLLHUP)) if ((n = readdev(devfd, buf, MTU)) <= 0 || (n = writenet(netfd, buf, n)) <= 0)