commit a1f1de8150759ccf09dfabc1b02620462891a5bf
parent ac13c3d064f18ac41b1d8fd0a3f330811d5eba97
Author: sin <sin@2f30.org>
Date: Tue, 12 Apr 2016 12:28:35 +0100
rework netpkt code
Diffstat:
6 files changed, 59 insertions(+), 53 deletions(-)
diff --git a/auth.c b/auth.c
@@ -20,7 +20,7 @@ challenge(int netfd)
arc4random_buf(&n, sizeof(uint64_t));
pack64(buf, n);
- if (writenet(netfd, buf, sizeof(uint64_t)) <= 0)
+ if (netwrite(netfd, buf, sizeof(uint64_t)) <= 0)
return -1;
pfd[0].fd = netfd;
@@ -35,7 +35,7 @@ challenge(int netfd)
}
if (pfd[0].revents & (POLLIN | POLLHUP)) {
- ret = readnet(netfd, buf, sizeof(uint64_t));
+ ret = netread(netfd, buf, sizeof(uint64_t));
if (ret <= 0) {
return -1;
} else if (ret == BADPKT) {
@@ -56,7 +56,7 @@ response(int netfd)
uint64_t reply;
int ret;
- ret = readnet(netfd, buf, sizeof(uint64_t));
+ ret = netread(netfd, buf, sizeof(uint64_t));
if (ret <= 0) {
return -1;
} else if (ret == BADPKT) {
@@ -65,7 +65,7 @@ response(int netfd)
}
reply = unpack64(buf);
pack64(buf, reply + 1);
- if (writenet(netfd, buf, sizeof(uint64_t)) <= 0)
+ if (netwrite(netfd, buf, sizeof(uint64_t)) <= 0)
return -1;
return 0;
}
diff --git a/dev_bsd.c b/dev_bsd.c
@@ -18,7 +18,7 @@
#include "stun.h"
int
-opendev(char *dev)
+devopen(char *dev)
{
struct tuninfo ti;
int fd;
@@ -44,7 +44,7 @@ opendev(char *dev)
}
int
-writedev(int fd, unsigned char *buf, int len)
+devwrite(int fd, unsigned char *buf, int len)
{
struct iovec iov[2];
uint32_t type = htonl(AF_INET);
@@ -67,7 +67,7 @@ writedev(int fd, unsigned char *buf, int len)
}
int
-readdev(int fd, unsigned char *buf, int len)
+devread(int fd, unsigned char *buf, int len)
{
struct iovec iov[2];
uint32_t type;
diff --git a/dev_linux.c b/dev_linux.c
@@ -13,7 +13,7 @@
#include "stun.h"
int
-opendev(char *dev)
+devopen(char *dev)
{
struct ifreq ifr;
int fd, s;
@@ -43,13 +43,13 @@ opendev(char *dev)
}
int
-writedev(int fd, unsigned char *buf, int len)
+devwrite(int fd, unsigned char *buf, int len)
{
return write(fd, buf, len);
}
int
-readdev(int fd, unsigned char *buf, int len)
+devread(int fd, unsigned char *buf, int len)
{
return read(fd, buf, len);
}
diff --git a/netpkt.c b/netpkt.c
@@ -10,32 +10,30 @@
#include "stun.h"
+static unsigned char *outpkt;
+static unsigned char *inpkt;
+static size_t maxpktlen;
+static size_t noncelen;
+static size_t taglen;
+
int
-writenet(int fd, unsigned char *pt, int ptlen)
+netwrite(int fd, unsigned char *pt, int ptlen)
{
- unsigned char *pkt;
- size_t noncelen = cryptononcelen();
- size_t taglen = cryptotaglen();
size_t pktlen = noncelen + HDRLEN + ptlen + taglen;
size_t outlen;
- int n;
- if (!(pkt = malloc(pktlen)))
+ if (pktlen > maxpktlen)
return -1;
- arc4random_buf(pkt, noncelen);
- pack16(&pkt[noncelen], ptlen);
- if (!cryptoseal(&pkt[noncelen + HDRLEN], &outlen,
- ptlen + taglen, pkt, noncelen,
- pt, ptlen, &pkt[noncelen], HDRLEN)) {
- free(pkt);
+ arc4random_buf(outpkt, noncelen);
+ pack16(&outpkt[noncelen], ptlen);
+ if (!cryptoseal(&outpkt[noncelen + HDRLEN], &outlen,
+ ptlen + taglen, outpkt, noncelen,
+ pt, ptlen, &outpkt[noncelen], HDRLEN)) {
logwarn("cryptoseal failed");
return -1;
}
-
- n = writeall(fd, pkt, pktlen);
- free(pkt);
- return n;
+ return writeall(fd, outpkt, pktlen);
}
/*
@@ -45,43 +43,48 @@ writenet(int fd, unsigned char *pt, int ptlen)
* necessary to make sure the two ends synchronize again.
*/
int
-readnet(int fd, unsigned char *pt, int ptlen)
+netread(int fd, unsigned char *pt, int ptlen)
{
- unsigned char *pkt;
- size_t noncelen = cryptononcelen();
- size_t taglen = cryptotaglen();
size_t pktlen = noncelen + HDRLEN + ptlen + taglen;
size_t outlen;
int n, ctlen;
- if (!(pkt = malloc(pktlen)))
+ if (pktlen > maxpktlen)
return -1;
- if ((n = readall(fd, pkt, noncelen)) <= 0)
+ if ((n = readall(fd, inpkt, noncelen)) <= 0)
goto err;
- if ((n = readall(fd, &pkt[noncelen], HDRLEN)) <= 0)
+ if ((n = readall(fd, &inpkt[noncelen], HDRLEN)) <= 0)
goto err;
/* if payload len is bogus cap it */
- if ((ctlen = unpack16(&pkt[noncelen])) > ptlen)
+ if ((ctlen = unpack16(&inpkt[noncelen])) > ptlen)
ctlen = ptlen;
- if ((n = readall(fd, &pkt[noncelen + HDRLEN], ctlen + taglen)) <= 0)
+ if ((n = readall(fd, &inpkt[noncelen + HDRLEN], ctlen + taglen)) <= 0)
goto err;
- if (!cryptoopen(pt, &outlen, ptlen, pkt, noncelen,
- &pkt[noncelen + HDRLEN], ctlen + taglen,
- &pkt[noncelen], HDRLEN)) {
- free(pkt);
+ if (!cryptoopen(pt, &outlen, ptlen, inpkt, noncelen,
+ &inpkt[noncelen + HDRLEN], ctlen + taglen,
+ &inpkt[noncelen], HDRLEN)) {
logwarn("cryptoopen failed");
return BADPKT;
}
-
- free(pkt);
return outlen;
err:
- free(pkt);
if (n == 0)
return 0;
if (errno != EWOULDBLOCK)
return -1;
return BADPKT;
}
+
+void
+netinit(void)
+{
+ noncelen = cryptononcelen();
+ taglen = cryptotaglen();
+ maxpktlen = noncelen + HDRLEN + MAXPAYLOADLEN + taglen;
+ if (!(outpkt = malloc(maxpktlen)))
+ logerr("oom");
+ if (!(inpkt = malloc(maxpktlen)))
+ logerr("oom");
+}
diff --git a/stun.c b/stun.c
@@ -88,19 +88,19 @@ tunnel(int netfd, int devfd)
logerr("poll failed");
if (pfd[0].revents & (POLLIN | POLLHUP)) {
- n = readnet(netfd, buf, MAXPAYLOADLEN);
+ n = netread(netfd, buf, MAXPAYLOADLEN);
if (n <= 0)
return -1;
else if (n == BADPKT)
logwarn("bad packet");
else
- writedev(devfd, buf, n);
+ devwrite(devfd, buf, n);
}
if (pfd[1].revents & (POLLIN | POLLHUP)) {
- n = readdev(devfd, buf, MAXPAYLOADLEN);
+ n = devread(devfd, buf, MAXPAYLOADLEN);
if (n > 0) {
- n = writenet(netfd, buf, n);
+ n = netwrite(netfd, buf, n);
if (n <= 0)
return -1;
}
@@ -283,7 +283,7 @@ main(int argc, char *argv[])
if (!debug)
daemon(0, 0);
loginit("stun");
- devfd = opendev(argv[0]);
+ devfd = devopen(argv[0]);
/* disable core dumps as memory contains the pre-shared key */
rlim.rlim_cur = rlim.rlim_max = 0;
@@ -295,6 +295,8 @@ main(int argc, char *argv[])
cryptoinit(pw);
bzero(pw, strlen(pw));
+ netinit();
+
if (sflag)
return serversetup(devfd);
diff --git a/stun.h b/stun.h
@@ -35,9 +35,9 @@ int cryptoopen(unsigned char *, size_t *, size_t, const unsigned char *,
size_t);
/* dev_*.c */
-int opendev(char *);
-int writedev(int, unsigned char *, int);
-int readdev(int, unsigned char *, int);
+int devopen(char *);
+int devwrite(int, unsigned char *, int);
+int devread(int, unsigned char *, int);
/* log.c */
void loginit(char *);
@@ -45,9 +45,10 @@ void logdbg(char *, ...);
void logwarn(char *, ...);
void logerr(char *, ...);
-/* net.c */
-int writenet(int, unsigned char *, int);
-int readnet(int, unsigned char *, int);
+/* netpkt.c */
+int netwrite(int, unsigned char *, int);
+int netread(int, unsigned char *, int);
+void netinit(void);
/* util.c */
void pack16(unsigned char *, uint16_t);