commit ac9a90a4cfa4d6650dd45b2f9102d29498737c77
parent a1f1de8150759ccf09dfabc1b02620462891a5bf
Author: sin <sin@2f30.org>
Date: Tue, 12 Apr 2016 14:10:33 +0100
use non-blocking sockets
cleanup will follow
Diffstat:
M | auth.c | | | 85 | ++++++++++++++++++++++++++++++++++++++++++++++++------------------------------- |
M | crypto.c | | | 2 | -- |
M | log.c | | | 2 | -- |
M | netpkt.c | | | 211 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------- |
M | stun.c | | | 26 | +++++++++----------------- |
M | stun.h | | | 14 | +++++++++----- |
M | util.c | | | 54 | ++++++++++++++---------------------------------------- |
7 files changed, 240 insertions(+), 154 deletions(-)
diff --git a/auth.c b/auth.c
@@ -1,5 +1,3 @@
-#include <sys/time.h>
-
#include <poll.h>
#include <stdint.h>
#include <stdlib.h>
@@ -14,37 +12,40 @@ int
challenge(int netfd)
{
unsigned char buf[sizeof(uint64_t)];
+ size_t outlen;
struct pollfd pfd[1];
uint64_t n, reply;
int ret;
- arc4random_buf(&n, sizeof(uint64_t));
+ arc4random_buf(&n, sizeof(buf));
pack64(buf, n);
- if (netwrite(netfd, buf, sizeof(uint64_t)) <= 0)
- return -1;
-
- pfd[0].fd = netfd;
- pfd[0].events = POLLIN;
- ret = poll(pfd, 1, RCVTIMEO);
- if (ret < 0) {
- logwarn("poll failed");
- return -1;
- } else if (ret == 0) {
- logwarn("challenge-response timed out");
+ if (netwrite(netfd, buf, sizeof(buf), &outlen) == PKTFAILED)
return -1;
- }
- if (pfd[0].revents & (POLLIN | POLLHUP)) {
- ret = netread(netfd, buf, sizeof(uint64_t));
- if (ret <= 0) {
+ for (;;) {
+ pfd[0].fd = netfd;
+ pfd[0].events = POLLIN;
+ ret = poll(pfd, 1, RCVTIMEO);
+ if (ret < 0) {
+ logwarn("poll failed");
return -1;
- } else if (ret == BADPKT) {
- logwarn("bad packet");
+ } else if (ret == 0) {
+ logwarn("challenge-response timed out");
return -1;
}
- reply = unpack64(buf);
- if (n + 1 == reply)
- return 0;
+
+ if (pfd[0].revents & (POLLIN | POLLHUP)) {
+ ret = netread(netfd, buf, sizeof(buf), &outlen);
+ if (ret == PKTFAILED) {
+ return -1;
+ } else if (ret == PKTCOMPLETE) {
+ if (outlen != sizeof(buf))
+ return -1;
+ reply = unpack64(buf);
+ if (n + 1 == reply)
+ return 0;
+ }
+ }
}
return -1;
}
@@ -53,19 +54,37 @@ int
response(int netfd)
{
unsigned char buf[sizeof(uint64_t)];
+ size_t outlen;
+ struct pollfd pfd[1];
uint64_t reply;
int ret;
- ret = netread(netfd, buf, sizeof(uint64_t));
- if (ret <= 0) {
- return -1;
- } else if (ret == BADPKT) {
- logwarn("bad packet");
- return -1;
+ for (;;) {
+ pfd[0].fd = netfd;
+ pfd[0].events = POLLIN;
+ ret = poll(pfd, 1, RCVTIMEO);
+ if (ret < 0) {
+ logwarn("poll failed");
+ return -1;
+ } else if (ret == 0) {
+ logwarn("challenge-response timed out");
+ return -1;
+ }
+
+ if (pfd[0].revents & (POLLIN | POLLHUP)) {
+ ret = netread(netfd, buf, sizeof(buf), &outlen);
+ if (ret == PKTFAILED) {
+ return -1;
+ } else if (ret == PKTCOMPLETE) {
+ if (outlen != sizeof(buf))
+ return -1;
+ reply = unpack64(buf);
+ pack64(buf, reply + 1);
+ if (netwrite(netfd, buf, sizeof(buf), &outlen) == PKTFAILED)
+ return -1;
+ break;
+ }
+ }
}
- reply = unpack64(buf);
- pack64(buf, reply + 1);
- if (netwrite(netfd, buf, sizeof(uint64_t)) <= 0)
- return -1;
return 0;
}
diff --git a/crypto.c b/crypto.c
@@ -1,5 +1,3 @@
-#include <sys/time.h>
-
#include <stdint.h>
#include <string.h>
diff --git a/log.c b/log.c
@@ -1,5 +1,3 @@
-#include <sys/time.h>
-
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
diff --git a/netpkt.c b/netpkt.c
@@ -1,8 +1,7 @@
-#include <sys/time.h>
-
#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
+#include <unistd.h>
#if defined(__linux__)
#include <bsd/stdlib.h>
@@ -10,71 +9,173 @@
#include "stun.h"
-static unsigned char *outpkt;
-static unsigned char *inpkt;
-static size_t maxpktlen;
+enum {
+ STATEINITIAL,
+ STATENONCE,
+ STATEHDR,
+ STATEPAYLOAD,
+ STATETAG,
+ STATEOPEN,
+ STATEDISCARD
+};
+
+static unsigned char *wbuf;
+static unsigned char *rbuf;
+static size_t rbuftotal, rbufrem;
+static size_t maxbuflen;
static size_t noncelen;
static size_t taglen;
+static int state = STATEINITIAL;
int
-netwrite(int fd, unsigned char *pt, int ptlen)
+netwrite(int fd, unsigned char *pt, size_t ptlen, size_t *outlen)
{
- size_t pktlen = noncelen + HDRLEN + ptlen + taglen;
- size_t outlen;
-
- if (pktlen > maxpktlen)
- return -1;
+ unsigned char *p = wbuf;
+ size_t buflen = noncelen + HDRLEN + ptlen + taglen;
+ int n, total = 0;
- arc4random_buf(outpkt, noncelen);
- pack16(&outpkt[noncelen], ptlen);
- if (!cryptoseal(&outpkt[noncelen + HDRLEN], &outlen,
- ptlen + taglen, outpkt, noncelen,
- pt, ptlen, &outpkt[noncelen], HDRLEN)) {
+ arc4random_buf(wbuf, noncelen);
+ pack16(&wbuf[noncelen], ptlen);
+ if (!cryptoseal(&wbuf[noncelen + HDRLEN], outlen,
+ ptlen + taglen, wbuf, noncelen,
+ pt, ptlen, &wbuf[noncelen], HDRLEN)) {
logwarn("cryptoseal failed");
return -1;
}
- return writeall(fd, outpkt, pktlen);
+ *outlen = ptlen;
+
+ while (buflen > 0) {
+ n = write(fd, p + total, buflen);
+ if (n == 0) {
+ return PKTFAILED;
+ } else if (n < 0) {
+ if (errno == EWOULDBLOCK)
+ continue;
+ return PKTFAILED;
+ }
+ total += n;
+ buflen -= n;
+ }
+
+ return PKTCOMPLETE;
}
-/*
- * Read one complete packet off the network. If the payload
- * length has been tampered with the tag will either not match
- * or the read will timeout after RCVTIMEO ms. Timing out is
- * necessary to make sure the two ends synchronize again.
- */
int
-netread(int fd, unsigned char *pt, int ptlen)
+netread(int fd, unsigned char *pt, size_t ptlen, size_t *outlen)
{
- size_t pktlen = noncelen + HDRLEN + ptlen + taglen;
- size_t outlen;
- int n, ctlen;
-
- if (pktlen > maxpktlen)
- return -1;
+ int n;
- if ((n = readall(fd, inpkt, noncelen)) <= 0)
- goto err;
- if ((n = readall(fd, &inpkt[noncelen], HDRLEN)) <= 0)
- goto err;
- /* if payload len is bogus cap it */
- if ((ctlen = unpack16(&inpkt[noncelen])) > ptlen)
- ctlen = ptlen;
- if ((n = readall(fd, &inpkt[noncelen + HDRLEN], ctlen + taglen)) <= 0)
- goto err;
-
- if (!cryptoopen(pt, &outlen, ptlen, inpkt, noncelen,
- &inpkt[noncelen + HDRLEN], ctlen + taglen,
- &inpkt[noncelen], HDRLEN)) {
- logwarn("cryptoopen failed");
- return BADPKT;
+ for (;;) {
+ switch (state) {
+ case STATEINITIAL:
+ rbuftotal = 0;
+ rbufrem = noncelen;
+ state = STATENONCE;
+ break;
+ case STATENONCE:
+ while (rbufrem > 0) {
+ n = read(fd, rbuf + rbuftotal, rbufrem);
+ if (n == 0) {
+ return PKTFAILED;
+ } else if (n < 0) {
+ if (errno == EWOULDBLOCK)
+ return PKTPARTIAL;
+ return PKTFAILED;
+ }
+ rbuftotal += n;
+ rbufrem -= n;
+ }
+ if (rbufrem == 0) {
+ rbufrem = HDRLEN;
+ state = STATEHDR;
+ }
+ break;
+ case STATEHDR:
+ while (rbufrem > 0) {
+ n = read(fd, rbuf + rbuftotal, rbufrem);
+ if (n == 0) {
+ return PKTFAILED;
+ } else if (n < 0) {
+ if (errno == EWOULDBLOCK);
+ return PKTPARTIAL;
+ return PKTFAILED;
+ }
+ rbuftotal += n;
+ rbufrem -= n;
+ }
+ if (rbufrem == 0) {
+ n = unpack16(&rbuf[noncelen]);
+ if (n > ptlen) {
+ rbufrem = MAXPAYLOADLEN;
+ state = STATEDISCARD;
+ } else {
+ rbufrem = n;
+ state = STATEPAYLOAD;
+ }
+ }
+ break;
+ case STATEPAYLOAD:
+ while (rbufrem > 0) {
+ n = read(fd, rbuf + rbuftotal, rbufrem);
+ if (n == 0) {
+ return PKTFAILED;
+ } else if (n < 0) {
+ if (errno == EWOULDBLOCK);
+ return PKTPARTIAL;
+ return PKTFAILED;
+ }
+ rbuftotal += n;
+ rbufrem -= n;
+ }
+ if (rbufrem == 0) {
+ rbufrem = taglen;
+ state = STATETAG;
+ }
+ break;
+ case STATETAG:
+ while (rbufrem > 0) {
+ n = read(fd, rbuf + rbuftotal, rbufrem);
+ if (n == 0) {
+ return PKTFAILED;
+ } else if (n < 0) {
+ if (errno == EWOULDBLOCK);
+ return PKTPARTIAL;
+ return PKTFAILED;
+ }
+ rbuftotal += n;
+ rbufrem -= n;
+ }
+ if (rbufrem == 0) {
+ rbufrem = taglen;
+ state = STATEOPEN;
+ }
+ break;
+ case STATEOPEN:
+ if (!cryptoopen(pt, outlen, ptlen, rbuf, noncelen,
+ &rbuf[noncelen + HDRLEN],
+ rbuftotal - noncelen - HDRLEN,
+ &rbuf[noncelen], HDRLEN)) {
+ logwarn("cryptoopen failed");
+ return PKTFAILED;
+ }
+ state = STATEINITIAL;
+ return PKTCOMPLETE;
+ case STATEDISCARD:
+ for (;;) {
+ n = read(fd, rbuf, rbufrem);
+ if (n == 0) {
+ return PKTFAILED;
+ } else if (n < 0) {
+ if (errno == EWOULDBLOCK)
+ break;
+ return PKTFAILED;
+ }
+ }
+ state = STATEINITIAL;
+ break;
+ }
}
- return outlen;
-err:
- if (n == 0)
- return 0;
- if (errno != EWOULDBLOCK)
- return -1;
- return BADPKT;
+ return PKTFAILED;
}
void
@@ -82,9 +183,9 @@ netinit(void)
{
noncelen = cryptononcelen();
taglen = cryptotaglen();
- maxpktlen = noncelen + HDRLEN + MAXPAYLOADLEN + taglen;
- if (!(outpkt = malloc(maxpktlen)))
+ maxbuflen = noncelen + HDRLEN + MAXPAYLOADLEN + taglen;
+ if (!(wbuf = malloc(maxbuflen)))
logerr("oom");
- if (!(inpkt = malloc(maxpktlen)))
+ if (!(rbuf = malloc(maxbuflen)))
logerr("oom");
}
diff --git a/stun.c b/stun.c
@@ -58,7 +58,6 @@
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
-#include <time.h>
#include <unistd.h>
#include "stun.h"
@@ -76,6 +75,7 @@ int
tunnel(int netfd, int devfd)
{
unsigned char buf[MAXPAYLOADLEN];
+ size_t outlen;
struct pollfd pfd[2];
int n;
@@ -88,22 +88,18 @@ tunnel(int netfd, int devfd)
logerr("poll failed");
if (pfd[0].revents & (POLLIN | POLLHUP)) {
- n = netread(netfd, buf, MAXPAYLOADLEN);
- if (n <= 0)
+ n = netread(netfd, buf, sizeof(buf), &outlen);
+ if (n == PKTFAILED)
return -1;
- else if (n == BADPKT)
- logwarn("bad packet");
- else
- devwrite(devfd, buf, n);
+ else if (n == PKTCOMPLETE)
+ devwrite(devfd, buf, outlen);
}
if (pfd[1].revents & (POLLIN | POLLHUP)) {
n = devread(devfd, buf, MAXPAYLOADLEN);
- if (n > 0) {
- n = netwrite(netfd, buf, n);
- if (n <= 0)
+ if (n > 0)
+ if (netwrite(netfd, buf, n, &outlen) == PKTFAILED)
return -1;
- }
}
}
return 0;
@@ -114,7 +110,6 @@ serversetup(int devfd)
{
struct addrinfo hints, *ai, *p;
struct sockaddr_in remote;
- struct timeval tv;
int ret, netfd, listenfd;
memset(&hints, 0, sizeof(hints));
@@ -162,8 +157,7 @@ serversetup(int devfd)
logdbg("remote peer connected: %s",
inet_ntoa(remote.sin_addr));
- ms2tv(&tv, RCVTIMEO);
- setsockopt(netfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
+ setnonblock(netfd, 1);
setsockopt(netfd, SOL_SOCKET, SO_KEEPALIVE, (int []){1}, sizeof(int));
setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int));
@@ -185,7 +179,6 @@ int
clientsetup(int devfd)
{
struct addrinfo hints, *ai, *p;
- struct timeval tv;
int ret, netfd;
memset(&hints, 0, sizeof(hints));
@@ -212,8 +205,7 @@ clientsetup(int devfd)
if (debug)
logdbg("connected to %s:%s", host, port);
- ms2tv(&tv, RCVTIMEO);
- setsockopt(netfd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
+ setnonblock(netfd, 1);
setsockopt(netfd, SOL_SOCKET, SO_KEEPALIVE, (int []){1}, sizeof(int));
setsockopt(netfd, IPPROTO_TCP, TCP_NODELAY, (int []){1}, sizeof(int));
diff --git a/stun.h b/stun.h
@@ -10,6 +10,12 @@
#define DEFCIPHER "chacha20-poly1305"
enum {
+ PKTFAILED,
+ PKTPARTIAL,
+ PKTCOMPLETE
+};
+
+enum {
TUNDEV,
TAPDEV
};
@@ -46,8 +52,8 @@ void logwarn(char *, ...);
void logerr(char *, ...);
/* netpkt.c */
-int netwrite(int, unsigned char *, int);
-int netread(int, unsigned char *, int);
+int netwrite(int, unsigned char *, size_t, size_t *);
+int netread(int, unsigned char *, size_t, size_t *);
void netinit(void);
/* util.c */
@@ -55,7 +61,5 @@ void pack16(unsigned char *, uint16_t);
uint16_t unpack16(unsigned char *);
void pack64(unsigned char *, uint64_t);
uint64_t unpack64(unsigned char *);
-int writeall(int, void *, int);
-int readall(int, void *, int);
-void ms2tv(struct timeval *, long);
void revokeprivs(void);
+int setnonblock(int, int);
diff --git a/util.c b/util.c
@@ -1,6 +1,6 @@
#include <sys/types.h>
-#include <sys/time.h>
+#include <fcntl.h>
#include <grp.h>
#include <pwd.h>
#include <stdint.h>
@@ -47,45 +47,6 @@ unpack64(unsigned char *buf)
(uint64_t)buf[7];
}
-int
-writeall(int fd, void *buf, int len)
-{
- unsigned char *p = buf;
- int n, total = 0;
-
- while (len > 0) {
- n = write(fd, p + total, len);
- if (n <= 0)
- break;
- total += n;
- len -= n;
- }
- return total;
-}
-
-int
-readall(int fd, void *buf, int len)
-{
- unsigned char *p = buf;
- int n, total = 0;
-
- while (len > 0) {
- n = read(fd, p + total, len);
- if (n <= 0)
- break;
- total += n;
- len -= n;
- }
- return total;
-}
-
-void
-ms2tv(struct timeval *tv, long ms)
-{
- tv->tv_sec = ms / 1000;
- tv->tv_usec = (ms % 1000) * 1000;
-}
-
void
revokeprivs(void)
{
@@ -98,3 +59,16 @@ revokeprivs(void)
setresuid(pw->pw_uid, pw->pw_uid, pw->pw_uid) < 0)
logerr("failed to revoke privs");
}
+
+int
+setnonblock(int fd, int mode)
+{
+ int flags;
+
+ flags = fcntl(fd, F_GETFL);
+ if (mode)
+ flags |= O_NONBLOCK;
+ else
+ flags &= ~O_NONBLOCK;
+ return fcntl(fd, F_SETFL, flags);
+}