commit f9b1884ede4720f0d2e929784a0a4a37de05709e
parent 3125bcdb402be89f1847de69f5f491b832e31cd4
Author: Roberto E. Vargas Caballero <k0ga@shike2.com>
Date: Tue, 1 Sep 2015 16:02:48 +0200
Rewrite constant folding
There were millions of errors in the code of simplify().
This new version removes the uglyness of the macros and
uses specialized functions instead.
Diffstat:
M | cc1/cc1.h | | | 43 | ++++++++++++++++++++++++++++--------------- |
M | cc1/expr.c | | | 35 | +++++++++++++++++++---------------- |
M | cc1/fold.c | | | 576 | +++++++++++++++++++++++++++++++++++++++++++++---------------------------------- |
M | cc1/types.c | | | 89 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
4 files changed, 468 insertions(+), 275 deletions(-)
diff --git a/cc1/cc1.h b/cc1/cc1.h
@@ -15,6 +15,19 @@ typedef struct caselist Caselist;
typedef struct node Node;
typedef struct input Input;
+struct limits {
+ union {
+ TINT i;
+ TUINT u;
+ TFLOAT f;
+ } max;
+ union {
+ TINT i;
+ TUINT u;
+ TFLOAT f;
+ } min;
+};
+
/*
* TODO: Some of the data stored in type is shared with
* cc2, so it should be stored in a table shared
@@ -236,7 +249,6 @@ enum tokens {
/* operations */
enum op {
- OPTR,
OADD,
OMUL,
OSUB,
@@ -249,6 +261,16 @@ enum op {
OBAND,
OBXOR,
OBOR,
+ ONEG,
+ OCPL,
+ OAND,
+ OOR,
+ OEQ,
+ ONE,
+ OLT,
+ OGE,
+ OLE,
+ OGT,
OASSIGN,
OA_MUL,
OA_DIV,
@@ -261,11 +283,9 @@ enum op {
OA_XOR,
OA_OR,
OADDR,
- ONEG,
- OCPL,
- OEXC,
OCOMMA,
OCAST,
+ OPTR,
OSYM,
OASK,
OCOLON,
@@ -285,15 +305,7 @@ enum op {
ORET,
ODECL,
OSWITCH,
- OSWITCHT,
- OAND,
- OOR,
- OEQ,
- ONE,
- OLT,
- OGE,
- OLE,
- OGT
+ OSWITCHT
};
/* error.c */
@@ -308,6 +320,7 @@ extern bool eqtype(Type *tp1, Type *tp2);
extern Type *ctype(unsigned type, unsigned sign, unsigned size);
extern Type *mktype(Type *tp, unsigned op, short nelem, Type *data[]);
extern Type *duptype(Type *base);
+extern struct limits *getlimits(Type *tp);
/* symbol.c */
extern void dumpstab(char *msg);
@@ -348,8 +361,7 @@ extern void freetree(Node *np);
#define BTYPE(np) ((np)->type->op)
/* fold.c */
-extern Node *simplify(unsigned char op, Type *tp, Node *lp, Node *rp);
-extern Node *usimplify(unsigned char op, Type *tp, Node *np);
+extern Node *simplify(int op, Type *tp, Node *lp, Node *rp);
extern Node *constconv(Node *np, Type *newtp);
/* expr.c */
@@ -358,6 +370,7 @@ extern Node *convert(Node *np, Type *tp1, char iscast);
extern Node *eval(Node *np), *iconstexpr(void), *condition(void);
extern Node *exp2cond(Node *np, char neg);
extern bool isnodecmp(int op);
+extern int negop(int op);
/* cpp.c */
extern void icpp(void);
diff --git a/cc1/expr.c b/cc1/expr.c
@@ -132,7 +132,7 @@ numericaluop(char op, Node *np)
case FLOAT:
if (op == OADD)
return np;
- return usimplify(op, np->type, np);
+ return simplify(op, np->type, np, NULL);
default:
error("unary operator requires integer operand");
}
@@ -144,7 +144,7 @@ integeruop(char op, Node *np)
np = eval(np);
if (BTYPE(np) != INT)
error("unary operator requires integer operand");
- return usimplify(op, np->type, np);
+ return simplify(op, np->type, np, NULL);
}
Node *
@@ -312,23 +312,26 @@ compare(char op, Node *lp, Node *rp)
return simplify(op, inttype, lp, rp);
}
+int
+negop(int op)
+{
+ switch (op) {
+ case OAND: return OOR;
+ case OOR: return OAND;
+ case OEQ: return ONE;
+ case ONE: return OEQ;
+ case OLT: return OGE;
+ case OGE: return OLT;
+ case OLE: return OGT;
+ case OGT: return OLE;
+ }
+ return op;
+}
+
Node *
negate(Node *np)
{
- unsigned op;
-
- switch (np->op) {
- case OAND: op = OOR; break;
- case OOR: op = OAND; break;
- case OEQ: op = ONE; break;
- case ONE: op = OEQ; break;
- case OLT: op = OGE; break;
- case OGE: op = OLT; break;
- case OLE: op = OGT; break;
- case OGT: op = OLE; break;
- default: return np;
- }
- np->op = op;
+ np->op = negop(np->op);
return np;
}
diff --git a/cc1/fold.c b/cc1/fold.c
@@ -4,287 +4,375 @@
#include "../inc/cc.h"
#include "cc1.h"
+static bool
+addi(TINT l, TINT r, Type *tp)
+{
+ struct limits *lim = getlimits(tp);
+ TINT max = lim->max.i, min = lim->min.i;
-#define SYMICMP(sym, val) (((sym)->type->sign) ? \
- (sym)->u.i == (val) : (sym)->u.u == (val))
+ if (l < 0 && r < 0 && l >= min - r ||
+ l == 0 ||
+ r == 0 ||
+ l < 0 && r > 0 ||
+ l > 0 && r < 0 ||
+ l > 0 && r > 0 && l <= max - r) {
+ return 1;
+ }
+ warn("overflow in constant expression");
+ return 0;
+}
-#define FOLDINT(sym, ls, rs, op) (((sym)->type->sign) ? \
- ((sym)->u.i = ((ls)->u.i op (rs)->u.i)) : \
- ((sym)->u.u = ((ls)->u.u op (rs)->u.u)))
+static bool
+subi(TINT l, TINT r, Type *tp)
+{
+ return addi(l, -r, tp);
+}
-#define CMPISYM(sym, ls, rs, op) (((sym)->type->sign) ? \
- ((ls)->u.i op (rs)->u.i) : ((ls)->u.u op (rs)->u.u))
+static bool
+muli(TINT l, TINT r, Type *tp)
+{
+ struct limits *lim = getlimits(tp);
+ TINT max = lim->max.i, min = lim->min.i;
-Node *
-simplify(unsigned char op, Type *tp, Node *lp, Node *rp)
+ if (l > -1 && l <= 1 ||
+ r > -1 && r <= 1 ||
+ l < 0 && r < 0 && -l <= max/-r ||
+ l < 0 && r > 0 && l >= min/r ||
+ l > 0 && r < 0 && r >= min/l ||
+ l > 0 && r > 0 && l <= max/r) {
+ return 1;
+ }
+ warn("overflow in constant expression");
+ return 0;
+}
+
+static bool
+divi(TINT l, TINT r, Type *tp)
{
- Symbol *sym, *ls, *rs, aux;
- int iszero, isone, noconst = 0;
-
- if (!lp->constant && !rp->constant)
- goto no_simplify;
- if (!rp->constant) {
- Node *np;
- np = lp;
- lp = rp;
- rp = np;
+ struct limits *lim = getlimits(tp);
+
+ if (r == 0) {
+ warn("division by 0");
+ return 0;
}
- if (!lp->constant)
- noconst = 1;
-
- ls = lp->sym, rs = rp->sym;
- aux.type = tp;
-
- /* TODO: Add overflow checkings */
-
- if (isnodecmp(op)) {
- /*
- * Comparision nodes have integer type
- * but the operands can have different
- * type.
- */
- switch (BTYPE(lp)) {
- case INT: goto cmp_integers;
- case FLOAT: goto cmp_floats;
- default: goto no_simplify;
- }
+ if (l == lim->min.i && r == -1) {
+ warn("overflow in constant expression");
+ return 0;
}
+ return 1;
+}
- switch (tp->op) {
- case PTR:
+static bool
+lshi(TINT l, TINT r, Type *tp)
+{
+ if (r < 0 || r >= tp->size * 8) {
+ warn("shifting %d bits is undefined", r);
+ return 0;
+ }
+ return muli(l, 1 << r, tp);
+}
+
+static bool
+rshi(TINT l, TINT r, Type *tp)
+{
+ if (r < 0 || r >= tp->size * 8) {
+ warn("shifting %d bits is undefined", r);
+ return 0;
+ }
+ return 1;
+}
+
+static bool
+foldint(int op, Symbol *res, TINT l, TINT r)
+{
+ TINT i;
+ bool (*validate)(TINT, TINT, Type *tp);
+
+ switch (op) {
+ case OADD: validate = addi; break;
+ case OSUB: validate = subi; break;
+ case OMUL: validate = muli; break;
+ case ODIV: validate = divi; break;
+ case OSHL: validate = lshi; break;
+ case OSHR: validate = rshi; break;
+ case OMOD: validate = divi; break;
+ default: validate = NULL; break;
+ }
+
+ if (validate && !(*validate)(l, r, res->type))
+ return 0;
+
+ switch (op) {
+ case OADD: i = l + r; break;
+ case OSUB: i = l - r; break;
+ case OMUL: i = l * r; break;
+ case ODIV: i = l / r; break;
+ case OMOD: i = l % r; break;
+ case OSHL: i = l << r; break;
+ case OSHR: i = l >> r; break;
+ case OBAND: i = l & r; break;
+ case OBXOR: i = l ^ r; break;
+ case OBOR: i = l | r; break;
+ case OAND: i = l && r; break;
+ case OOR: i = l || r; break;
+ case OLT: i = l < r; break;
+ case OGT: i = l > r; break;
+ case OGE: i = l >= r; break;
+ case OLE: i = l <= r; break;
+ case OEQ: i = l == r; break;
+ case ONE: i = l != r; break;
+ case ONEG: i = -l; break;
+ case OCPL: i = ~l; break;
+ }
+ res->u.i = i;
+ return 1;
+}
+
+static bool
+folduint(int op, Symbol *res, TINT l, TINT r)
+{
+ TINT i;
+ TUINT u;
+
+ switch (op) {
+ case OADD: u = l + r; break;
+ case OSUB: u = l - r; break;
+ case OMUL: u = l * r; break;
+ case ODIV: u = l / r; break;
+ case OMOD: u = l % r; break;
+ case OSHL: u = l << r; break;
+ case OSHR: u = l >> r; break;
+ case OBAND: u = l & r; break;
+ case OBXOR: u = l ^ r; break;
+ case OBOR: u = l | r; break;
+ case ONEG: u = -l; break;
+ case OCPL: u = ~l; break;
+ case OAND: i = l && r; goto unsign;
+ case OOR: i = l || r; goto unsign;
+ case OLT: i = l < r; goto unsign;
+ case OGT: i = l > r; goto unsign;
+ case OGE: i = l >= r; goto unsign;
+ case OLE: i = l <= r; goto unsign;
+ case OEQ: i = l == r; goto unsign;
+ case ONE: i = l != r; goto unsign;
+ }
+
+sign:
+ res->u.u = u;
+ return 1;
+
+unsign:
+ res->u.i = i;
+ return 1;
+}
+
+static bool
+foldfloat(int op, Symbol *res, TFLOAT l, TFLOAT r)
+{
+ TFLOAT f;
+ TINT i;
+
+ switch (op) {
+ case OADD: f = l + r; break;
+ case OSUB: f = l - r; break;
+ case OMUL: f = l * r; break;
+ case ODIV: f = l / r; break;
+ case OLT: i = l < r; goto comparision;
+ case OGT: i = l > r; goto comparision;
+ case OGE: i = l >= r; goto comparision;
+ case OLE: i = l <= r; goto comparision;
+ case OEQ: i = l == r; goto comparision;
+ case ONE: i = l != r; goto comparision;
+ default: return 0;
+ }
+ res->u.f = f;
+ return 1;
+
+comparision:
+ res->u.i = i;
+ return 1;
+}
+
+static Node *
+foldconst(int type, int op, Type *tp, Symbol *ls, Symbol *rs)
+{
+ Symbol *sym, aux;
+ TINT i;
+ TUINT u;
+ TFLOAT f;
+
+ aux.type = ls->type;
+ switch (type) {
case INT:
- cmp_integers:
- iszero = SYMICMP(rs, 0);
- isone = SYMICMP(rs, 1);
- switch (op) {
- case OADD:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, +);
- break;
- case OSUB:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, -);
- break;
- case OMUL:
- if (isone)
- return lp;
- if (iszero)
- return constnode(zero);
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, *);
- break;
- case ODIV:
- if (isone)
- return lp;
- if (iszero)
- goto division_by_0;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, /);
- break;
- case OMOD:
- if (iszero)
- goto division_by_0;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, %);
- break;
- case OSHL:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, <<);
- break;
- case OSHR:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, >>);
- break;
- case OBAND:
- if (SYMICMP(rs, ~0))
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, &);
- break;
- case OBXOR:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, ^);
- break;
- case OBOR:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, |);
- break;
- case OAND:
- if (!iszero)
- return lp;
- /* TODO: What happens with something like f(0) && 0? */
- if (noconst)
- goto no_simplify;
- FOLDINT(&aux, ls, rs, &&);
- break;
- case OOR:
- if (iszero)
- return lp;
- if (noconst)
- goto no_simplify;
- /* TODO: What happens with something like f(0) || 1? */
- FOLDINT(&aux, ls, rs, ||);
- break;
- case OLT:
- /* TODO: what happens with signess? */
- if (noconst)
- goto no_simplify;
- aux.u.i = CMPISYM(&aux, ls, rs, <);
- break;
- case OGT:
- /* TODO: what happens with signess? */
- if (noconst)
- goto no_simplify;
- aux.u.i = CMPISYM(&aux, ls, rs, >);
- break;
- case OGE:
- /* TODO: what happens with signess? */
- if (noconst)
- goto no_simplify;
- aux.u.i = CMPISYM(&aux, ls, rs, >=);
- break;
- case OLE:
- /* TODO: what happens with signess? */
- if (noconst)
- goto no_simplify;
- aux.u.i = CMPISYM(&aux, ls, rs, <=);
- break;
- case OEQ:
- /* TODO: what happens with signess? */
- if (noconst)
- goto no_simplify;
- aux.u.i = CMPISYM(&aux, ls, rs, ==);
- break;
- case ONE:
- /* TODO: what happens with signess? */
- if (noconst)
- goto no_simplify;
- aux.u.i = CMPISYM(&aux, ls, rs, !=);
- break;
- }
+ i = (rs) ? rs->u.i : 0;
+ if (!foldint(op, &aux, ls->u.i, i))
+ return NULL;
+ break;
+ case UNSIGNED:
+ u = (rs) ? rs->u.u : 0u;
+ if (!folduint(op, &aux, ls->u.u, u))
+ return NULL;
break;
case FLOAT:
- cmp_floats:
- /* TODO: Add algebraic reductions for floats */
- switch (op) {
- case OADD:
- aux.u.f = ls->u.f + rs->u.f;
- break;
- case OSUB:
- aux.u.f = ls->u.f - rs->u.f;
- break;
- case OMUL:
- aux.u.f = ls->u.f * rs->u.f;
- break;
- case ODIV:
- if (rs->u.f == 0.0)
- goto division_by_0;
- aux.u.f = ls->u.f / rs->u.f;
- break;
- case OLT:
- aux.u.i = ls->u.f < rs->u.f;
- break;
- case OGT:
- aux.u.i = ls->u.f > rs->u.f;
- break;
- case OGE:
- aux.u.i = ls->u.f >= rs->u.f;
- break;
- case OLE:
- aux.u.i = ls->u.f <= rs->u.f;
- break;
- case OEQ:
- aux.u.i = ls->u.f == rs->u.f;
- break;
- case ONE:
- aux.u.i = ls->u.f != rs->u.f;
- break;
- }
+ f = (rs) ? rs->u.f : 0.0;
+ if (!foldfloat(op, &aux, ls->u.f, f))
+ return NULL;
break;
- default:
- goto no_simplify;
}
-
sym = newsym(NS_IDEN);
sym->type = tp;
sym->u = aux.u;
return constnode(sym);
+}
-division_by_0:
- warn("division by 0");
+static Node *
+fold(int op, Type *tp, Node *lp, Node *rp)
+{
+ Symbol *rs, *ls;
+ Node *np;
+ int type;
-no_simplify:
- return node(op, tp, lp, rp);
+ if (!lp->constant || rp && !rp->constant)
+ return NULL;
+ ls = lp->sym;
+ rs = (rp) ? rp->sym : NULL;
+
+ /*
+ * Comparision nodes have integer type
+ * but the operands can have different
+ * type.
+ */
+ type = (isnodecmp(op)) ? BTYPE(lp) : tp->op;
+ switch (type) {
+ case PTR:
+ case INT:
+ type = (tp->sign) ? INT : UNSIGNED;
+ break;
+ case FLOAT:
+ type = FLOAT;
+ break;
+ default:
+ return NULL;
+ }
+
+ if ((np = foldconst(type, op, tp, ls, rs)) == NULL)
+ return NULL;
+
+ freetree(lp);
+ freetree(rp);
+ return np;
}
-#define UFOLDINT(sym, ls, op) (((sym)->type->sign) ? \
- ((sym)->u.i = (op (ls)->u.i)) : \
- ((sym)->u.u = (op (ls)->u.u)))
+static void
+commutative(int *op, Node **lp, Node **rp)
+{
+ Node *l = *lp, *r = *rp, *aux;
-Node *
-usimplify(unsigned char op, Type *tp, Node *np)
+ if (r == NULL || r->constant || !l->constant)
+ return;
+
+ switch (*op) {
+ case OLT:
+ case OGT:
+ case OGE:
+ case OLE:
+ case OEQ:
+ case ONE:
+ *op = negop(*op);
+ case OADD:
+ case OMUL:
+ case OBAND:
+ case OBXOR:
+ case OBOR:
+ aux = l;
+ l = r;
+ r = aux;
+ *rp = r;
+ *lp = l;
+ break;
+ default:
+ return;
+ }
+}
+
+static bool
+cmp(Node *np, int val)
{
- Symbol *sym, *ns, aux;
+ Symbol *sym;
+ Type *tp;
if (!np->constant)
- goto no_simplify;
- ns = np->sym;
- aux.type = tp;
+ return 0;
+ sym = np->sym;
+ tp = sym->type;
switch (tp->op) {
+ case PTR:
case INT:
- switch (op) {
- case ONEG:
- UFOLDINT(&aux, ns, -);
- break;
- case OCPL:
- UFOLDINT(&aux, ns, ~);
- break;
- default:
- goto no_simplify;
- }
- break;
+ return ((tp->sign) ? sym->u.i : sym->u.u) == val;
case FLOAT:
- if (op != ONEG)
- goto no_simplify;
- aux.u.f = -ns->u.f;
+ return sym->u.f == val;
+ }
+ return 0;
+}
+
+static TUINT
+ones(int n)
+{
+ TUINT v;
+
+ for (v = 1; n--; v |= 1)
+ v <<= 1;
+ return v;
+}
+
+static bool
+identity(int op, Node *lp, Node *rp)
+{
+ int val;
+
+ switch (op) {
+ case OSHL:
+ case OSHR:
+ case OBXOR:
+ case OADD:
+ case OSUB:
+ val = 0;
+ break;
+ case ODIV:
+ case OMOD:
+ case OMUL:
+ case OBOR:
+ val = 1;
break;
+ case OBAND:
+ if (cmp(lp, ones(lp->type->size * 8)))
+ goto free_right;
default:
- goto no_simplify;
+ return 0;
}
+ if (!cmp(rp, val))
+ return 0;
+free_right:
+ freetree(rp);
+ return 1;
+}
- sym = newsym(NS_IDEN);
- sym->type = tp;
- sym->u = aux.u;
- return constnode(sym);
+Node *
+simplify(int op, Type *tp, Node *lp, Node *rp)
+{
+ Node *np;
-no_simplify:
- return node(op, tp, np, NULL);
+ if ((np = fold(op, tp, lp, rp)) != NULL)
+ return np;
+ commutative(&op, &lp, &rp);
+ if (identity(op, lp, rp))
+ return lp;
+ return node(op, tp, lp, rp);
}
/* TODO: check validity of types */
+/* TODO: Integrate it with simplify */
Node *
constconv(Node *np, Type *newtp)
diff --git a/cc1/types.c b/cc1/types.c
@@ -11,6 +11,67 @@
#define NR_TYPE_HASH 16
/*
+ * Compiler can generate warnings here if the ranges of TINT,
+ * TUINT and TFLOAT are smaller than any of the constants in this
+ * array. Ignore them if you know that the target types are correct
+ */
+static struct limits limits[][4] = {
+ {
+ { /* 0 = signed 1 byte */
+ .min.i = -127,
+ .max.i = 127
+ },
+ { /* 1 = signed 2 byte */
+ .min.i = -32767,
+ .max.i = 327677
+ },
+ { /* 2 = signed 4 byte */
+ .min.i = -2147483647L,
+ .max.i = 2147483647L
+ },
+ { /* 3 = signed 8 byte */
+ .min.i = -9223372036854775807LL,
+ .max.i = 9223372036854775807LL,
+ }
+ },
+ {
+ { /* 0 = unsigned 1 byte */
+ .min.u = 0,
+ .max.u = 255
+ },
+ { /* 1 = unsigned 2 bytes */
+ .min.u = 0,
+ .max.u = 65535u
+ },
+ { /* 2 = unsigned 4 bytes */
+ .min.u = 0,
+ .max.u = 4294967295u
+ },
+ { /* 3 = unsigned 4 bytes */
+ .min.u = 0,
+ .max.u = 18446744073709551615u
+ }
+ },
+ {
+ {
+ /* 0 = float 4 bytes */
+ .min.f = -1,
+ .max.f = 2
+ },
+ {
+ /* 1 = float 8 bytes */
+ .min.f = -1,
+ .max.f = 2,
+ },
+ {
+ /* 2 = float 16 bytes */
+ .min.f = -1,
+ .max.f = 2,
+ }
+ }
+};
+
+/*
* Initializaion of type pointers were done with
* a C99 initilizator '... = &(Type) {...', but
* c compiler in Plan9 gives error with this
@@ -197,6 +258,34 @@ static Symbol dummy0 = {.u.i = 0, .type = &types[9]},
Symbol *zero = &dummy0, *one = &dummy1;
+struct limits *
+getlimits(Type *tp)
+{
+ int ntable, ntype;
+
+ switch (tp->op) {
+ case INT:
+ ntable = tp->sign;
+ switch (tp->size) {
+ case 1: ntype = 0; break;
+ case 2: ntype = 1; break;
+ case 4: ntype = 2; break;
+ case 8: ntype = 3; break;
+ }
+ break;
+ case FLOAT:
+ ntable = 2;
+ switch (tp->size) {
+ case 4: ntype = 0; break;
+ case 8: ntype = 1; break;
+ case 16: ntype = 2; break;
+ }
+ break;
+ }
+
+ return &limits[ntable][ntype];
+}
+
Type *
ctype(unsigned type, unsigned sign, unsigned size)
{