stun

simple point to point tunnel
git clone git://git.2f30.org/stun
Log | Files | Refs | README

commit ac9a90a4cfa4d6650dd45b2f9102d29498737c77
parent a1f1de8150759ccf09dfabc1b02620462891a5bf
Author: sin <sin@2f30.org>
Date:   Tue, 12 Apr 2016 14:10:33 +0100

use non-blocking sockets

cleanup will follow

Diffstat:
Mauth.c | 85++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
Mcrypto.c | 2--
Mlog.c | 2--
Mnetpkt.c | 211++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------
Mstun.c | 26+++++++++-----------------
Mstun.h | 14+++++++++-----
Mutil.c | 54++++++++++++++----------------------------------------
7 files changed, 240 insertions(+), 154 deletions(-)

diff --git a/auth.c b/auth.c @@ -1,5 +1,3 @@ -#include <sys/time.h> - #include <poll.h> #include <stdint.h> #include <stdlib.h> @@ -14,37 +12,40 @@ int challenge(int netfd) { unsigned char buf[sizeof(uint64_t)]; + size_t outlen; struct pollfd pfd[1]; uint64_t n, reply; int ret; - arc4random_buf(&n, sizeof(uint64_t)); + arc4random_buf(&n, sizeof(buf)); pack64(buf, n); - if (netwrite(netfd, buf, sizeof(uint64_t)) <= 0) - return -1; - - pfd[0].fd = netfd; - pfd[0].events = POLLIN; - ret = poll(pfd, 1, RCVTIMEO); - if (ret < 0) { - logwarn("poll failed"); - return -1; - } else if (ret == 0) { - logwarn("challenge-response timed out"); + if (netwrite(netfd, buf, sizeof(buf), &outlen) == PKTFAILED) return -1; - } - if (pfd[0].revents & (POLLIN | POLLHUP)) { - ret = netread(netfd, buf, sizeof(uint64_t)); - if (ret <= 0) { + for (;;) { + pfd[0].fd = netfd; + pfd[0].events = POLLIN; + ret = poll(pfd, 1, RCVTIMEO); + if (ret < 0) { + logwarn("poll failed"); return -1; - } else if (ret == BADPKT) { - logwarn("bad packet"); + } else if (ret == 0) { + logwarn("challenge-response timed out"); return -1; } - reply = unpack64(buf); - if (n + 1 == reply) - return 0; + + if (pfd[0].revents & (POLLIN | POLLHUP)) { + ret = netread(netfd, buf, sizeof(buf), &outlen); + if (ret == PKTFAILED) { + return -1; + } else if (ret == PKTCOMPLETE) { + if (outlen != sizeof(buf)) + return -1; + reply = unpack64(buf); + if (n + 1 == reply) + return 0; + } + } } return -1; } @@ -53,19 +54,37 @@ int response(int netfd) { unsigned char buf[sizeof(uint64_t)]; + size_t outlen; + struct pollfd pfd[1]; uint64_t reply; int ret; - ret = netread(netfd, buf, sizeof(uint64_t)); - if (ret <= 0) { - return -1; - } else if (ret == BADPKT) { - logwarn("bad packet"); - return -1; + for (;;) { + pfd[0].fd = netfd; + pfd[0].events = POLLIN; + ret = poll(pfd, 1, RCVTIMEO); + if (ret < 0) { + logwarn("poll failed"); + return -1; + } else if (ret == 0) { + logwarn("challenge-response timed out"); + return -1; + } + + if (pfd[0].revents & (POLLIN | POLLHUP)) { + ret = netread(netfd, buf, sizeof(buf), &outlen); + if (ret == PKTFAILED) { + return -1; + } else if (ret == PKTCOMPLETE) { + if (outlen != sizeof(buf)) + return -1; + reply = unpack64(buf); + pack64(buf, reply + 1); + if (netwrite(netfd, buf, sizeof(buf), &outlen) == PKTFAILED) + return -1; + break; + } + } } - reply = unpack64(buf); - pack64(buf, reply + 1); - if (netwrite(netfd, buf, sizeof(uint64_t)) <= 0) - return -1; return 0; } diff --git a/crypto.c b/crypto.c @@ -1,5 +1,3 @@ -#include <sys/time.h> - #include <stdint.h> #include <string.h> diff --git a/log.c b/log.c @@ -1,5 +1,3 @@ -#include <sys/time.h> - #include <stdarg.h> #include <stdint.h> #include <stdio.h> diff --git a/netpkt.c b/netpkt.c @@ -1,8 +1,7 @@ -#include <sys/time.h> - #include <errno.h> #include <stdint.h> #include <stdlib.h> +#include <unistd.h> #if defined(__linux__) #include <bsd/stdlib.h> @@ -10,71 +9,173 @@ #include "stun.h" -static unsigned char *outpkt; -static unsigned char *inpkt; -static size_t maxpktlen; +enum { + STATEINITIAL, + STATENONCE, + STATEHDR, + STATEPAYLOAD, + STATETAG, + STATEOPEN, + STATEDISCARD +}; + +static unsigned char *wbuf; +static unsigned char *rbuf; +static size_t rbuftotal, rbufrem; +static size_t maxbuflen; static size_t noncelen; static size_t taglen; +static int state = STATEINITIAL; int -netwrite(int fd, unsigned char *pt, int ptlen) +netwrite(int fd, unsigned char *pt, size_t ptlen, size_t *outlen) { - size_t pktlen = noncelen + HDRLEN + ptlen + taglen; - size_t outlen; - - if (pktlen > maxpktlen) - return -1; + unsigned char *p = wbuf; + size_t buflen = noncelen + HDRLEN + ptlen + taglen; + int n, total = 0; - arc4random_buf(outpkt, noncelen); - pack16(&outpkt[noncelen], ptlen); - if (!cryptoseal(&outpkt[noncelen + HDRLEN], &outlen, - ptlen + taglen, outpkt, noncelen, - pt, ptlen, &outpkt[noncelen], HDRLEN)) { + arc4random_buf(wbuf, noncelen); + pack16(&wbuf[noncelen], ptlen); + if (!cryptoseal(&wbuf[noncelen + HDRLEN], outlen, + ptlen + taglen, wbuf, noncelen, + pt, ptlen, &wbuf[noncelen], HDRLEN)) { logwarn("cryptoseal failed"); return -1; } - return writeall(fd, outpkt, pktlen); + *outlen = ptlen; + + while (buflen > 0) { + n = write(fd, p + total, buflen); + if (n == 0) { + return PKTFAILED; + } else if (n < 0) { + if (errno == EWOULDBLOCK) + continue; + return PKTFAILED; + } + total += n; + buflen -= n; + } + + return PKTCOMPLETE; } -/* - * Read one complete packet off the network. If the payload - * length has been tampered with the tag will either not match - * or the read will timeout after RCVTIMEO ms. Timing out is - * necessary to make sure the two ends synchronize again. - */ int -netread(int fd, unsigned char *pt, int ptlen) +netread(int fd, unsigned char *pt, size_t ptlen, size_t *outlen) { - size_t pktlen = noncelen + HDRLEN + ptlen + taglen; - size_t outlen; - int n, ctlen; - - if (pktlen > maxpktlen) - return -1; + int n; - if ((n = readall(fd, inpkt, noncelen)) <= 0) - goto err; - if ((n = readall(fd, &inpkt[noncelen], HDRLEN)) <= 0) - goto err; - /* if payload len is bogus cap it */ - if ((ctlen = unpack16(&inpkt[noncelen])) > ptlen) - ctlen = ptlen; - if ((n = readall(fd, &inpkt[noncelen + HDRLEN], ctlen + taglen)) <= 0) - goto err; - - if (!cryptoopen(pt, &outlen, ptlen, inpkt, noncelen, - &inpkt[noncelen + HDRLEN], ctlen + taglen, - &inpkt[noncelen], HDRLEN)) { - logwarn("cryptoopen failed"); - return BADPKT; + for (;;) { + switch (state) { + case STATEINITIAL: + rbuftotal = 0; + rbufrem = noncelen; + state = STATENONCE; + break; + case STATENONCE: + while (rbufrem > 0) { + n = read(fd, rbuf + rbuftotal, rbufrem); + if (n == 0) { + return PKTFAILED; + } else if (n < 0) { + if (errno == EWOULDBLOCK) + return PKTPARTIAL; + return PKTFAILED; + } + rbuftotal += n; + rbufrem -= n; + } + if (rbufrem == 0) { + rbufrem = HDRLEN; + state = STATEHDR; + } + break; + case STATEHDR: + while (rbufrem > 0) { + n = read(fd, rbuf + rbuftotal, rbufrem); + if (n == 0) { + return PKTFAILED; + } else if (n < 0) { + if (errno == EWOULDBLOCK); + return PKTPARTIAL; + return PKTFAILED; + } + rbuftotal += n; + rbufrem -= n; + } + if (rbufrem == 0) { + n = unpack16(&rbuf[noncelen]); + if (n > ptlen) { + rbufrem = MAXPAYLOADLEN; + state = STATEDISCARD; + } else { + rbufrem = n; + state = STATEPAYLOAD; + } + } + break; + case STATEPAYLOAD: + while (rbufrem > 0) { + n = read(fd, rbuf + rbuftotal, rbufrem); + if (n == 0) { + return PKTFAILED; + } else if (n < 0) { + if (errno == EWOULDBLOCK); + return PKTPARTIAL; + return PKTFAILED; + } + rbuftotal += n; + rbufrem -= n; + } + if (rbufrem == 0) { + rbufrem = taglen; + state = STATETAG; + } + break; + case STATETAG: + while (rbufrem > 0) { + n = read(fd, rbuf + rbuftotal, rbufrem); + if (n == 0) { + return PKTFAILED; + } else if (n < 0) { + if (errno == EWOULDBLOCK); + return PKTPARTIAL; + return PKTFAILED; + } + rbuftotal += n; + rbufrem -= n; + } + if (rbufrem == 0) { + rbufrem = taglen; + state = STATEOPEN; + } + break; + case STATEOPEN: + if (!cryptoopen(pt, outlen, ptlen, rbuf, noncelen, + &rbuf[noncelen + HDRLEN], + rbuftotal - noncelen - HDRLEN, + &rbuf[noncelen], HDRLEN)) { + logwarn("cryptoopen failed"); + return PKTFAILED; + } + state = STATEINITIAL; + return PKTCOMPLETE; + case STATEDISCARD: + for (;;) { + n = read(fd, rbuf, rbufrem); + if (n == 0) { + return PKTFAILED; + } else if (n < 0) { + if (errno == EWOULDBLOCK) + break; + return PKTFAILED; + } + } + state = STATEINITIAL; + break; + } } - return outlen; -err: - if (n == 0) - return 0; - if (errno != EWOULDBLOCK) - return -1; - return BADPKT; + return PKTFAILED; } void @@ -82,9 +183,9 @@ netinit(void) { noncelen = cryptononcelen(); taglen = cryptotaglen(); - maxpktlen = noncelen + HDRLEN + MAXPAYLOADLEN + taglen; - if (!(outpkt = malloc(maxpktlen))) + maxbuflen = noncelen + HDRLEN + MAXPAYLOADLEN + taglen; + if (!(wbuf = malloc(maxbuflen))) logerr("oom"); - if (!(inpkt = malloc(maxpktlen))) + if (!(rbuf = malloc(maxbuflen))) logerr("oom"); } diff --git a/stun.c b/stun.c @@ -58,7 +58,6 @@ #include <stdint.h> #include <stdlib.h> #include <string.h> -#include <time.h> #include <unistd.h> #include "stun.h" @@ -76,6 +75,7 @@ int tunnel(int netfd, int devfd) { unsigned char buf[MAXPAYLOADLEN]; + size_t outlen; struct pollfd pfd[2]; int n; @@ -88,22 +88,18 @@ tunnel(int netfd, int devfd) logerr("poll failed"); if (pfd[0].revents & (POLLIN | POLLHUP)) { - n = netread(netfd, buf, MAXPAYLOADLEN); - if (n <= 0) + n = netread(netfd, buf, sizeof(buf), &outlen); + if (n == PKTFAILED) return -1; - else if (n == BADPKT) - logwarn("bad packet"); - else - devwrite(devfd, buf, n); + else if (n == PKTCOMPLETE) + devwrite(devfd, buf, outlen); } if (pfd[1].revents & (POLLIN | POLLHUP)) { n = devread(devfd, buf, MAXPAYLOADLEN); - if (n > 0) { - n = netwrite(netfd, buf, n); - if (n <= 0) + if (n > 0) + if (netwrite(netfd, buf, n, &outlen) == PKTFAILED) return -1; - } } } return 0; @@ -114,7 +110,6 @@ serversetup(int devfd) { struct addrinfo hints, *ai, *p; struct sockaddr_in remote; - struct timeval tv; int ret, netfd, listenfd; memset(&hints, 0, sizeof(hints)); @@ -162,8 +157,7 @@ serversetup(int devfd) logdbg("remote peer connected: %s", inet_ntoa(remote.sin_addr)); - ms2tv(&tv, RCVTIMEO); - setsockopt(netfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + setnonblock(netfd, 1); setsockopt(netfd, SOL_SOCKET, SO_KEEPALIVE, (int []){1}, sizeof(int)); setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int)); @@ -185,7 +179,6 @@ int clientsetup(int devfd) { struct addrinfo hints, *ai, *p; - struct timeval tv; int ret, netfd; memset(&hints, 0, sizeof(hints)); @@ -212,8 +205,7 @@ clientsetup(int devfd) if (debug) logdbg("connected to %s:%s", host, port); - ms2tv(&tv, RCVTIMEO); - setsockopt(netfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + setnonblock(netfd, 1); setsockopt(netfd, SOL_SOCKET, SO_KEEPALIVE, (int []){1}, sizeof(int)); setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int)); diff --git a/stun.h b/stun.h @@ -10,6 +10,12 @@ #define DEFCIPHER "chacha20-poly1305" enum { + PKTFAILED, + PKTPARTIAL, + PKTCOMPLETE +}; + +enum { TUNDEV, TAPDEV }; @@ -46,8 +52,8 @@ void logwarn(char *, ...); void logerr(char *, ...); /* netpkt.c */ -int netwrite(int, unsigned char *, int); -int netread(int, unsigned char *, int); +int netwrite(int, unsigned char *, size_t, size_t *); +int netread(int, unsigned char *, size_t, size_t *); void netinit(void); /* util.c */ @@ -55,7 +61,5 @@ void pack16(unsigned char *, uint16_t); uint16_t unpack16(unsigned char *); void pack64(unsigned char *, uint64_t); uint64_t unpack64(unsigned char *); -int writeall(int, void *, int); -int readall(int, void *, int); -void ms2tv(struct timeval *, long); void revokeprivs(void); +int setnonblock(int, int); diff --git a/util.c b/util.c @@ -1,6 +1,6 @@ #include <sys/types.h> -#include <sys/time.h> +#include <fcntl.h> #include <grp.h> #include <pwd.h> #include <stdint.h> @@ -47,45 +47,6 @@ unpack64(unsigned char *buf) (uint64_t)buf[7]; } -int -writeall(int fd, void *buf, int len) -{ - unsigned char *p = buf; - int n, total = 0; - - while (len > 0) { - n = write(fd, p + total, len); - if (n <= 0) - break; - total += n; - len -= n; - } - return total; -} - -int -readall(int fd, void *buf, int len) -{ - unsigned char *p = buf; - int n, total = 0; - - while (len > 0) { - n = read(fd, p + total, len); - if (n <= 0) - break; - total += n; - len -= n; - } - return total; -} - -void -ms2tv(struct timeval *tv, long ms) -{ - tv->tv_sec = ms / 1000; - tv->tv_usec = (ms % 1000) * 1000; -} - void revokeprivs(void) { @@ -98,3 +59,16 @@ revokeprivs(void) setresuid(pw->pw_uid, pw->pw_uid, pw->pw_uid) < 0) logerr("failed to revoke privs"); } + +int +setnonblock(int fd, int mode) +{ + int flags; + + flags = fcntl(fd, F_GETFL); + if (mode) + flags |= O_NONBLOCK; + else + flags &= ~O_NONBLOCK; + return fcntl(fd, F_SETFL, flags); +}