sbase

suckless unix tools
git clone git://git.2f30.org/sbase.git
Log | Files | Refs | README | LICENSE

commit b3ae1a7b4b9832d1978fe4676758e53887248c5e
parent 1822f70d12b85cf4be46c9cd7111a035394b0796
Author: Hiltjo Posthuma <hiltjo@codemadness.org>
Date:   Sun Nov 16 15:19:50 +0100

expr: improvements

- handle divide by zero.
- use eregcomp().
- use emalloc().
- use snprintf() for safety and add a buffer size argument to valstr() just
  to be sure.
- code-style fixes.

Diffstat:
expr.c | 187++++++++++++++++++++++++++++++++++++++++++-------------------------------------
1 file changed, 100 insertions(+), 87 deletions(-)
diff --git a/expr.c b/expr.c @@ -2,10 +2,11 @@ #include <inttypes.h> #include <limits.h> #include <regex.h> -#include <stdio.h> #include <stdint.h> +#include <stdio.h> #include <stdlib.h> #include <string.h> + #include "util.h" enum { @@ -13,7 +14,7 @@ enum { }; typedef struct { - char *s; + char *s; intmax_t n; } Val; @@ -21,7 +22,7 @@ static void doop(int*, int**, Val*, Val**); static Val match(Val, Val); static void num(Val); static int valcmp(Val, Val); -static char *valstr(Val, char*); +static char *valstr(Val, char*, size_t); static int yylex(void); static int yyparse(int); @@ -29,6 +30,13 @@ static char **args; static size_t intlen; static Val yylval; +static void +ezero(intmax_t n) +{ + if(n == 0) + enprintf(2, "division by zero\n"); +} + /* otop points to one past last op * vtop points to one past last val * guaranteed otop != ops @@ -40,43 +48,47 @@ doop(int *ops, int **otop, Val *vals, Val **vtop) Val ret, a, b; int op; - if((*otop)[-1] == '(') + if ((*otop)[-1] == '(') enprintf(2, "syntax error: extra (\n"); - if(*vtop - vals < 2) + if (*vtop - vals < 2) enprintf(2, "syntax error: missing expression or extra operator\n"); - a = (*vtop)[-2]; - b = (*vtop)[-1]; + a = (*vtop)[-2]; + b = (*vtop)[-1]; op = (*otop)[-1]; switch (op) { - case '|': - if ( a.s && *a.s) ret = (Val){ a.s , 0 }; - else if(!a.s && a.n) ret = (Val){ NULL, a.n }; - else if( b.s && *b.s) ret = (Val){ b.s , 0 }; - else ret = (Val){ NULL, b.n }; - break; - - case '&': - if(((a.s && *a.s) || a.n) && - ((b.s && *b.s) || b.n)) ret = a; - else ret = (Val){ NULL, 0 }; - break; - - case '=': ret = (Val){ NULL, valcmp(a, b) == 0 }; break; - case '>': ret = (Val){ NULL, valcmp(a, b) > 0 }; break; - case GE : ret = (Val){ NULL, valcmp(a, b) >= 0 }; break; - case '<': ret = (Val){ NULL, valcmp(a, b) < 0 }; break; - case LE : ret = (Val){ NULL, valcmp(a, b) <= 0 }; break; - case NE : ret = (Val){ NULL, valcmp(a, b) != 0 }; break; - - case '+': num(a); num(b); ret = (Val){ NULL, a.n + b.n }; break; - case '-': num(a); num(b); ret = (Val){ NULL, a.n - b.n }; break; - case '*': num(a); num(b); ret = (Val){ NULL, a.n * b.n }; break; - case '/': num(a); num(b); ret = (Val){ NULL, a.n / b.n }; break; - case '%': num(a); num(b); ret = (Val){ NULL, a.n % b.n }; break; - - case ':': ret = match(a, b); break; + case '|': + if (a.s && *a.s) + ret = (Val){ a.s, 0 }; + else if (!a.s && a.n) + ret = (Val){ NULL, a.n }; + else if (b.s && *b.s) + ret = (Val){ b.s, 0 }; + else + ret = (Val){ NULL, b.n }; + break; + case '&': + if (((a.s && *a.s) || a.n) && + ((b.s && *b.s) || b.n)) + ret = a; + else + ret = (Val){ NULL, 0 }; + break; + case '=': ret = (Val){ NULL, valcmp(a, b) == 0 }; break; + case '>': ret = (Val){ NULL, valcmp(a, b) > 0 }; break; + case GE : ret = (Val){ NULL, valcmp(a, b) >= 0 }; break; + case '<': ret = (Val){ NULL, valcmp(a, b) < 0 }; break; + case LE : ret = (Val){ NULL, valcmp(a, b) <= 0 }; break; + case NE : ret = (Val){ NULL, valcmp(a, b) != 0 }; break; + + case '+': num(a); num(b); ret = (Val){ NULL, a.n + b.n }; break; + case '-': num(a); num(b); ret = (Val){ NULL, a.n - b.n }; break; + case '*': num(a); num(b); ret = (Val){ NULL, a.n * b.n }; break; + case '/': num(a); num(b); ezero(b.n); ret = (Val){ NULL, a.n / b.n }; break; + case '%': num(a); num(b); ezero(b.n); ret = (Val){ NULL, a.n % b.n }; break; + + case ':': ret = match(a, b); break; } (*vtop)[-2] = ret; @@ -87,33 +99,29 @@ doop(int *ops, int **otop, Val *vals, Val **vtop) static Val match(Val vstr, Val vregx) { - char b1[intlen], *str = valstr(vstr , b1); - char b2[intlen], *regx = valstr(vregx, b2); + intmax_t d; + char *ret, *p; + regoff_t len; + char b1[intlen], *str = valstr(vstr, b1, sizeof(b1)); + char b2[intlen], *regx = valstr(vregx, b2, sizeof(b2)); regex_t re; regmatch_t matches[2]; char anchreg[strlen(regx) + 2]; - sprintf(anchreg, "^%s", regx); - - if(regcomp(&re, anchreg, 0)) - enprintf(3, "regcomp failed\n"); + snprintf(anchreg, sizeof(anchreg), "^%s", regx); + enregcomp(3, &re, anchreg, 0); - if(regexec(&re, str, 2, matches, 0)) + if (regexec(&re, str, 2, matches, 0)) return (Val){ (re.re_nsub ? "" : NULL), 0 }; - if(re.re_nsub) { - intmax_t d; - char *ret, *p; - regoff_t len = matches[1].rm_eo - matches[1].rm_so + 1; - - if(!(ret = malloc(len))) // FIXME: free - enprintf(3, "malloc failed\n"); - + if (re.re_nsub) { + len = matches[1].rm_eo - matches[1].rm_so + 1; + ret = emalloc(len); /* TODO: free ret */ d = strtoimax(ret, &p, 10); strlcpy(ret, str + matches[1].rm_so, len); - if(*ret && !*p) + if (*ret && !*p) return (Val){ NULL, d }; return (Val){ ret, 0 }; } @@ -123,27 +131,27 @@ match(Val vstr, Val vregx) static void num(Val v) { - if(v.s) + if (v.s) enprintf(2, "syntax error: expected integer got `%s'\n", v.s); } static int valcmp(Val a, Val b) { - char b1[intlen], *p = valstr(a, b1); - char b2[intlen], *q = valstr(b, b2); + char b1[intlen], *p = valstr(a, b1, sizeof(b1)); + char b2[intlen], *q = valstr(b, b2, sizeof(b2)); - if(!a.s && !b.s) + if (!a.s && !b.s) return (a.n > b.n) - (a.n < b.n); return strcmp(p, q); } static char * -valstr(Val val, char *buf) +valstr(Val val, char *buf, size_t bufsiz) { char *p = val.s; - if(!p) { - sprintf(buf, "%"PRIdMAX, val.n); + if (!p) { + snprintf(buf, bufsiz, "%"PRIdMAX, val.n); p = buf; } return p; @@ -155,21 +163,24 @@ yylex(void) intmax_t d; char *q, *p, *ops = "|&=><+-*/%():"; - if(!(p = *args++)) + if (!(p = *args++)) return 0; d = strtoimax(p, &q, 10); - if(*p && !*q) { + if (*p && !*q) { yylval = (Val){ NULL, d }; return VAL; } - if(*p && !p[1] && strchr(ops, *p)) + if (*p && !p[1] && strchr(ops, *p)) return *p; - if(strcmp(p, ">=") == 0) return GE; - if(strcmp(p, "<=") == 0) return LE; - if(strcmp(p, "!=") == 0) return NE; + if (strcmp(p, ">=") == 0) + return GE; + if (strcmp(p, "<=") == 0) + return LE; + if (strcmp(p, "!=") == 0) + return NE; yylval = (Val){ p, 0 }; return VAL; @@ -192,38 +203,40 @@ yyparse(int argc) while((type = yylex())) { switch (type) { - case VAL: *vtop++ = yylval; break; - case '(': *otop++ = '(' ; break; - case ')': - if(last == '(') - enprintf(2, "syntax error: empty ( )\n"); - while(otop > ops && otop[-1] != '(') - doop(ops, &otop, vals, &vtop); - if(otop == ops) - enprintf(2, "syntax error: extra )\n"); - otop--; - break; - default : - if(prec[last]) - enprintf(2, "syntax error: extra operator\n"); - while(otop > ops && prec[otop[-1]] >= prec[type]) - doop(ops, &otop, vals, &vtop); - *otop++ = type; - break; + case VAL: *vtop++ = yylval; break; + case '(': *otop++ = '(' ; break; + case ')': + if (last == '(') + enprintf(2, "syntax error: empty ( )\n"); + while(otop > ops && otop[-1] != '(') + doop(ops, &otop, vals, &vtop); + if (otop == ops) + enprintf(2, "syntax error: extra )\n"); + otop--; + break; + default : + if (prec[last]) + enprintf(2, "syntax error: extra operator\n"); + while (otop > ops && prec[otop[-1]] >= prec[type]) + doop(ops, &otop, vals, &vtop); + *otop++ = type; + break; } last = type; } while(otop > ops) doop(ops, &otop, vals, &vtop); - if(vtop == vals) + if (vtop == vals) enprintf(2, "syntax error: missing expression\n"); - if(vtop - vals > 1) + if (vtop - vals > 1) enprintf(2, "syntax error: extra expression\n"); vtop--; - if(vtop->s) printf("%s\n" , vtop->s); - else printf("%"PRIdMAX"\n", vtop->n); + if (vtop->s) + printf("%s\n", vtop->s); + else + printf("%"PRIdMAX"\n", vtop->n); return (vtop->s && *vtop->s) || vtop->n; } @@ -231,11 +244,11 @@ yyparse(int argc) int main(int argc, char **argv) { - if(!(intlen = snprintf(NULL, 0, "%"PRIdMAX, INTMAX_MIN) + 1)) + if (!(intlen = snprintf(NULL, 0, "%"PRIdMAX, INTMAX_MIN) + 1)) enprintf(3, "failed to get max digits\n"); args = argv + 1; - if(*args && !strcmp("--", *args)) + if (*args && !strcmp("--", *args)) ++args; return !yyparse(argc);