scc

simple C compiler
git clone git://git.2f30.org/scc
Log | Files | Refs | README | LICENSE

commit 8e8ce7a93939caeab7fc0e5e149470349746a43f
parent 253f3681886a5e397379a8a2cf9df9884c0926b6
Author: Roberto E. Vargas Caballero <Roberto E. Vargas Caballero>
Date:   Sun,  1 May 2016 21:32:15 +0200

[cc1] Fix use-affer-free bug in switches

Switches build a list of cases which were emitted after the
end of the switch statement, and since the switch statement is
usually a compound statement, it meant that the symbols created
in the cases were freed when they were emitted. To solve this
problem this patch simplifies everything emitting the cases
on the fly, letting the job of creating the switch table to the
backend.

Diffstat:
Mcc1/cc1.h | 22++++++----------------
Mcc1/code.c | 32++++----------------------------
Mcc1/stmt.c | 98++++++++++++++++++++++++++++++++++++++-----------------------------------------
3 files changed, 57 insertions(+), 95 deletions(-)

diff --git a/cc1/cc1.h b/cc1/cc1.h @@ -10,7 +10,7 @@ */ typedef struct type Type; typedef struct symbol Symbol; -typedef struct caselist Caselist; +typedef struct swtch Switch; typedef struct node Node; typedef struct input Input; @@ -87,19 +87,9 @@ struct node { struct node *left, *right; }; -struct scase { - Symbol *label; - Node *expr; - struct scase *next; -}; - -struct caselist { +struct swtch { short nr; - Symbol *deflabel; - Symbol *ltable; - Symbol *lbreak; - Node *expr; - struct scase *head; + char hasdef; }; struct yystype { @@ -335,8 +325,8 @@ enum op { OCALL, ORET, ODECL, - OSWITCH, - OSWITCHT, + OBSWITCH, + OESWITCH, OINIT }; @@ -368,7 +358,7 @@ extern void keywords(struct keyword *key, int ns); extern Symbol *newstring(char *s, size_t len); /* stmt.c */ -extern void compound(Symbol *lbreak, Symbol *lcont, Caselist *lswitch); +extern void compound(Symbol *lbreak, Symbol *lcont, Switch *sw); /* decl.c */ extern Type *typename(void); diff --git a/cc1/code.c b/cc1/code.c @@ -60,6 +60,8 @@ char *optxt[] = { [OCOMMA] = ",", [OLABEL] = "L%d\n", [ODEFAULT] = "\tf\tL%d\n", + [OBSWITCH] = "\ts", + [OESWITCH] = "\tk\n", [OCASE] = "\tv\tL%d", [OJUMP] = "\tj\tL%d\n", [OBRANCH] = "\ty\tL%d", @@ -126,8 +128,8 @@ void (*opcode[])(unsigned, void *) = { [OFUN] = emitfun, [ORET] = emittext, [ODECL] = emitdcl, - [OSWITCH] = emitswitch, - [OSWITCHT] = emitswitcht, + [OBSWITCH] = emittext, + [OESWITCH] = emittext, [OPAR] = emitbin, [OCALL] = emitbin, [OINIT] = emitinit @@ -457,32 +459,6 @@ emitsymid(unsigned op, void *arg) printf(optxt[op], sym->id); } -static void -emitswitch(unsigned op, void *arg) -{ - Caselist *lcase = arg; - - printf("\ts\tL%u", lcase->ltable->id); - emitexp(OEXPR, lcase->expr); -} - -static void -emitswitcht(unsigned op, void *arg) -{ - Caselist *lcase = arg; - struct scase *p, *next; - - printf("\tt\t#%c%0x\n", sizettype->letter, lcase->nr); - for (p = lcase->head; p; p = next) { - emitsymid(OCASE, p->label); - emitexp(OEXPR, p->expr); - next = p->next; - free(p); - } - if (lcase->deflabel) - emitsymid(ODEFAULT, lcase->deflabel); -} - Node * node(unsigned op, Type *tp, Node *lp, Node *rp) { diff --git a/cc1/stmt.c b/cc1/stmt.c @@ -10,7 +10,7 @@ Symbol *curfun; -static void stmt(Symbol *lbreak, Symbol *lcont, Caselist *lswitch); +static void stmt(Symbol *lbreak, Symbol *lcont, Switch *lswitch); static void label(void) @@ -36,7 +36,7 @@ label(void) } static void -stmtexp(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +stmtexp(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { if (accept(';')) return; @@ -62,7 +62,7 @@ condition(void) } static void -While(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +While(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { Symbol *begin, *cond, *end; Node *np; @@ -85,7 +85,7 @@ While(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -For(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +For(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { Symbol *begin, *cond, *end; Node *econd, *einc, *einit; @@ -117,7 +117,7 @@ For(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -Dowhile(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Dowhile(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { Symbol *begin, *end; Node *np; @@ -137,7 +137,7 @@ Dowhile(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -Return(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Return(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { Node *np; Type *tp = curfun->type->type; @@ -160,7 +160,7 @@ Return(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -Break(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Break(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { expect(BREAK); if (!lbreak) { @@ -172,7 +172,7 @@ Break(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -Continue(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Continue(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { expect(CONTINUE); if (!lcont) { @@ -184,7 +184,7 @@ Continue(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -Goto(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Goto(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { Symbol *sym; @@ -204,74 +204,70 @@ Goto(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -Switch(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Swtch(Symbol *obr, Symbol *lcont, Switch *osw) { - Caselist lcase = {.nr = 0, .head = NULL, .deflabel = NULL}; + Switch sw = {0}; Node *cond; + Symbol *lbreak; expect(SWITCH); - expect ('('); + expect ('('); if ((cond = convert(expr(), inttype, 0)) == NULL) { errorp("incorrect type in switch statement"); cond = constnode(zero); } expect (')'); - lcase.expr = cond; - lcase.lbreak = newlabel(); - lcase.ltable = newlabel(); - - emit(OSWITCH, &lcase); - stmt(lbreak, lcont, &lcase); - emit(OJUMP, lcase.lbreak); - emit(OLABEL, lcase.ltable); - emit(OSWITCHT, &lcase); - emit(OLABEL, lcase.lbreak); + lbreak = newlabel(); + emit(OBSWITCH, NULL); + emit(OEXPR, cond); + stmt(lbreak, lcont, &sw); + emit(OESWITCH, NULL); + emit(OLABEL, lbreak); } static void -Case(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Case(Symbol *lbreak, Symbol *lcont, Switch *sw) { Node *np; - struct scase *pcase; + Symbol *label; expect(CASE); - if (!lswitch) - errorp("case label not within a switch statement"); if ((np = iconstexpr()) == NULL) errorp("case label does not reduce to an integer constant"); - expect(':'); - if (lswitch && lswitch->nr >= 0) { - if (++lswitch->nr == NR_SWITCH) { - errorp("too case labels for a switch statement"); - lswitch->nr = -1; - } else { - pcase = xmalloc(sizeof(*pcase)); - pcase->expr = np; - pcase->next = lswitch->head; - emit(OLABEL, pcase->label = newlabel()); - lswitch->head = pcase; - } + if (!sw) { + errorp("case label not within a switch statement"); + } else if (sw->nr >= 0 && ++sw->nr == NR_SWITCH) { + errorp("too case labels for a switch statement"); + sw->nr = -1; } - stmt(lbreak, lcont, lswitch); + expect(':'); + + label = newlabel(); + emit(OCASE, label); + emit(OEXPR, np); + emit(OLABEL, label); + stmt(lbreak, lcont, sw); } static void -Default(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +Default(Symbol *lbreak, Symbol *lcont, Switch *sw) { - Symbol *ldefault = newlabel(); + Symbol *label = newlabel(); + if (sw->hasdef) + errorp("multiple default labels in one switch"); + sw->hasdef = 1; expect(DEFAULT); expect(':'); - emit(OLABEL, ldefault); - lswitch->deflabel = ldefault; - ++lswitch->nr; - stmt(lbreak, lcont, lswitch); + emit(ODEFAULT, label); + emit(OLABEL, label); + stmt(lbreak, lcont, sw); } static void -If(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +If(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { Symbol *end, *lelse; Node *np; @@ -294,7 +290,7 @@ If(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -blockit(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +blockit(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { switch (yytoken) { case TYPEIDEN: @@ -313,7 +309,7 @@ blockit(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } void -compound(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +compound(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { static int nested; @@ -342,9 +338,9 @@ compound(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) } static void -stmt(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) +stmt(Symbol *lbreak, Symbol *lcont, Switch *lswitch) { - void (*fun)(Symbol *, Symbol *, Caselist *); + void (*fun)(Symbol *, Symbol *, Switch *); switch (yytoken) { case '{': fun = compound; break; @@ -356,7 +352,7 @@ stmt(Symbol *lbreak, Symbol *lcont, Caselist *lswitch) case BREAK: fun = Break; break; case CONTINUE: fun = Continue; break; case GOTO: fun = Goto; break; - case SWITCH: fun = Switch; break; + case SWITCH: fun = Swtch; break; case CASE: fun = Case; break; case DEFAULT: fun = Default; break; default: fun = stmtexp; break;