package coins.backend;

import java.lang.*;
import java.io.*;
import java.util.*;
import coins.backend.*;
import coins.backend.tmd.*;
import coins.backend.util.*;
import coins.backend.sym.*;
import coins.backend.lir.*;
import coins.backend.cfg.*;
import coins.backend.ana.*;

/** Represents a function. */
public class Function implements LirFactory {
  /** Module this function belongs to. */
  public final Module module;

  /** Symbol table entry of this function, taken from the module's global symbol table. */
  public final SymStatic symbol;

  /** Local symbol table; filled from the SYMTAB part of the function description. */
  public final SymTab localSymtab = new SymTab();

  /**
   * Factory used to create LIR nodes. Initially this object itself;
   * kept as a separate field to soften interface change.
   */
  public LirFactory newLir = this;

  /** Control flow graph, built by the constructor from the decoded instruction list. */
  public final FlowGraph flowGraph;

  /** Label table: maps a label name (String) to its Label object. */
  private Map labelTable = new HashMap();

  /** LIR node serial number generator: the next id to hand out, starting at 1. */
  private int idCounter = 1;

  /**
   * Constant hash table: maps each constant node created by iconst/fconst
   * to itself, so that a node equal to an existing one is reused.
   */
  private Map constantTable = new HashMap();

  /** Results of analyses: maps a LocalAnalyzer to the LocalAnalysis it produced. */
  private Map analyses = new HashMap();

  /** Writer for debug output; handed on to register allocation and to re-created functions. */
  private PrintWriter debOut;

  /**
   * Parse S-expression function description and convert to internal form.
   *
   * The first element of the list is skipped, the second is the function
   * name, the third is the symbol table list (headed by the SYMTAB keyword)
   * and every remaining element is one LIR statement. After decoding, the
   * control flow graph is built from the statement list.
   *
   * @param mod    module this function belongs to
   * @param ptr    S-expression describing the function
   * @param output writer used for the optional dump after reading; also kept
   *               for later debug output
   * @throws SyntaxError if the function name is not registered as a static
   *         symbol in the module's global symbol table, if the symbol table
   *         part is missing, or if a statement cannot be decoded
   */
  public Function(Module mod, ImList ptr, PrintWriter output)
    throws SyntaxError {
    module = mod;
    debOut = output;

    // Parse name of the function. Report an unknown name as a syntax error
    // instead of failing later with a NullPointerException/ClassCastException.
    ptr = ptr.next();
    String funcName = (String)ptr.elem();
    Object entry = module.globalSymtab.get(funcName);
    if (!(entry instanceof SymStatic))
      throw new SyntaxError("Function symbol not found in global symbol table: "
                            + funcName);
    symbol = (SymStatic)entry;
    symbol.setBody(this);

    // Parse symbol table
    ptr = ptr.next();
    if (!(ptr.elem() instanceof ImList))
      throw new SyntaxError("SYMTAB expected");
    ImList symp = (ImList)ptr.elem();
    if (symp.elem() != Keyword.SYMTAB)
      throw new SyntaxError("SYMTAB expected");
    for (symp = symp.next(); !symp.atEnd(); symp = symp.next()) {
      localSymtab.addSymbol((ImList)symp.elem());
    }

    // Parse body part: every remaining element is one statement
    BiList instructions = new BiList();
    while (!(ptr = ptr.next()).atEnd()) {
      ImList stmt = (ImList)ptr.elem();
      LirNode lir = decodeLir(stmt);
      instructions.add(lir);
    }

    if (Debug.dumpAfterRead) {
      // Dump function
      output.println("Function Just after read:");
      output.println("(FUNCTION \"" + symbol.name + "\"");
      output.print("  (SYMTAB");
      Iterator it = localSymtab.iterator();
      while (it.hasNext()) {
        output.println();
        output.print("    " + (Symbol)it.next());
      }
      output.println(" )");
      for (BiLink p = instructions.first(); !p.atEnd(); p = p.next())
        output.println("    " + (LirNode)p.elem());
      output.println(")");
    }

    // Make control flow graph
    flowGraph = new FlowGraph(this, instructions);
  }


  /**
   * Convert one S-expression LIR statement or expression to a LirNode tree.
   *
   * Malformed numeric literals anywhere in the tree are reported as
   * SyntaxError rather than leaking an unchecked NumberFormatException.
   *
   * @param stmt list whose first element is the operator name
   * @return the node created through {@link #newLir}
   * @throws SyntaxError if the operator is unknown or a number is malformed
   */
  private LirNode decodeLir(ImList stmt) throws SyntaxError {
    try {
      return decodeNode(stmt);
    } catch (NumberFormatException e) {
      throw new SyntaxError("Malformed numeric literal: " + e.getMessage());
    }
  }

  /** Decode every element of list, in order, into an array of nodes. */
  private LirNode[] decodeOperands(ImList list) throws SyntaxError {
    LirNode[] nodes = new LirNode[list.length()];
    int i = 0;
    for (ImList p = list; !p.atEnd(); p = p.next())
      nodes[i++] = decodeNode((ImList)p.elem());
    return nodes;
  }

  /**
   * Recursive worker of decodeLir: dispatches on the operator code.
   * Operands are decoded strictly left to right so that node ids are
   * assigned in source order.
   */
  private LirNode decodeNode(ImList stmt) throws SyntaxError {
    int code = Op.toCode((String)stmt.elem());
    ImList opt = stmt.scanOpt();

    switch (code) {
      // leaf nodes: (op type value)
    case Op.INTCONST:
      return newLir.iconst(Type.decode((String)stmt.elem2nd()),
                           Long.parseLong((String)stmt.elem3rd()), opt);
    case Op.FLOATCONST:
      return newLir.fconst(Type.decode((String)stmt.elem2nd()),
                           Double.parseDouble((String)stmt.elem3rd()), opt);
    case Op.STATIC:
    case Op.FRAME:
    case Op.REG:
      {
        int type = Type.decode((String)stmt.elem2nd());
        String name = (String)stmt.elem3rd();
        // Local symbols take precedence over global ones.
        Symbol sym = localSymtab.get(name);
        if (sym == null) {
          sym = module.globalSymtab.get(name);
          if (sym == null)
            throw new CantHappenException("Undefined symbol: " + name);
        }
        return newLir.symRef(code, type, sym, opt);
      }

    case Op.SUBREG:
      {
        int type = Type.decode((String)stmt.elem2nd());
        LirNode src = decodeNode((ImList)stmt.elem3rd());
        int pos = Integer.parseInt((String)stmt.elem4th());
        return newLir.subReg(Op.SUBREG, type, src, pos, opt);
      }

      // Label reference/definition
    case Op.LABEL:
      {
        // The label name is the third element; the second is not used here.
        String name = (String)stmt.elem3rd();
        Label label = internLabel(name);
        return newLir.labelRef(Op.LABEL, Type.ADDRESS, label, opt);
      }

    case Op.DEFLABEL:
      {
        String name = (String)stmt.elem2nd();
        Label label = internLabel(name);
        return newLir.operator(code, Type.UNKNOWN, newLir.labelRef(label), opt);
      }

    // unary operators with type
    case Op.NEG:
    case Op.BNOT:
    case Op.CONVSX:
    case Op.CONVZX:
    case Op.CONVIT:
    case Op.CONVFX:
    case Op.CONVFT:
    case Op.CONVFI:
    case Op.CONVSF:
    case Op.CONVUF:
    case Op.MEM:
      return newLir.operator(code, Type.decode((String)stmt.elem2nd()),
                             decodeNode((ImList)stmt.elem3rd()), opt);

    // unary operators w/o type
    case Op.JUMP:
    case Op.USE:
    case Op.CLOBBER:
      return newLir.operator(code, Type.UNKNOWN,
                             decodeNode((ImList)stmt.elem2nd()), opt);

    // binary operators
    case Op.ADD:
    case Op.SUB:
    case Op.MUL:
    case Op.DIVS:
    case Op.DIVU:
    case Op.MODS:
    case Op.MODU:
    case Op.BAND:
    case Op.BOR:
    case Op.BXOR:
    case Op.LSHS:
    case Op.LSHU:
    case Op.RSHS:
    case Op.RSHU:
    case Op.TSTEQ:
    case Op.TSTNE:
    case Op.TSTLTS:
    case Op.TSTLES:
    case Op.TSTGTS:
    case Op.TSTGES:
    case Op.TSTLTU:
    case Op.TSTLEU:
    case Op.TSTGTU:
    case Op.TSTGEU:
    case Op.SET:
      return newLir.operator(code, Type.decode((String)stmt.elem2nd()),
                             decodeNode((ImList)stmt.elem3rd()),
                             decodeNode((ImList)stmt.elem4th()), opt);

    case Op.JUMPC: // ternary operator
      {
        LirNode[] src = new LirNode[3];
        src[0] = decodeNode((ImList)stmt.elem2nd());
        src[1] = decodeNode((ImList)stmt.elem3rd());
        src[2] = decodeNode((ImList)stmt.elem4th());
        return newLir.operator(code, Type.UNKNOWN, src, opt);
      }

    case Op.CALL: // currently with type
      {
        int type = Type.decode((String)stmt.elem2nd());
        LirNode callee = decodeNode((ImList)stmt.elem3rd());
        // The fourth element is the parameter list; it becomes a LIST node.
        LirNode[] param = decodeOperands((ImList)stmt.elem4th());
        return newLir.operator(code, type, callee,
                               newLir.operator(Op.LIST, Type.UNKNOWN, param, null),
                               opt);
      }

    case Op.JUMPN:
      {
        LirNode value = decodeNode((ImList)stmt.elem2nd());
        ImList cases = (ImList)stmt.elem3rd();
        LirNode[] labels = new LirNode[cases.length()];
        int i = 0;
        // Each case is a two-element list, turned into a two-operand LIST node.
        for (ImList p = cases; !p.atEnd(); p = p.next()) {
          ImList c = (ImList)p.elem();
          labels[i++] = newLir.operator(Op.LIST, Type.UNKNOWN,
                                        decodeNode((ImList)c.elem()),
                                        decodeNode((ImList)c.elem2nd()), null);
        }
        // The fourth element is decoded last, after the case LIST is built.
        return newLir.operator(code, Type.UNKNOWN, value,
                               newLir.operator(Op.LIST, Type.UNKNOWN, labels, null),
                               decodeNode((ImList)stmt.elem4th()), opt);
      }

    case Op.PROLOGUE:
    case Op.EPILOGUE:
      {
        int n = stmt.next().length();
        LirNode[] opr = new LirNode[n];
        // First operand: a pair of integers, wrapped as a LIST of two
        // ADDRESS-typed INTCONST nodes.
        ImList frame = (ImList)stmt.elem2nd();
        opr[0] = newLir.operator(Op.LIST, Type.UNKNOWN,
                                 newLir.iconst(Type.ADDRESS,
                                               Integer.parseInt((String)frame.elem()), null),
                                 newLir.iconst(Type.ADDRESS,
                                               Integer.parseInt((String)frame.elem2nd()), null),
                                 null);
        int i = 1;
        for (ImList p = stmt.next().next(); !p.atEnd(); p = p.next())
          opr[i++] = decodeNode((ImList)p.elem());
        return newLir.operator(code, Type.UNKNOWN, opr, opt);
      }

    case Op.PARALLEL:
      return newLir.operator(code, Type.UNKNOWN,
                             decodeOperands(stmt.next()), opt);

    default:
      // Name the offending operator so the input can be located.
      throw new SyntaxError("Unknown opCode: " + stmt.elem());
    }

  }


  /**
   * Return the label named 'name' of this function, creating and
   * registering it in the label table on first use.
   */
  public Label internLabel(String name) {
    Object known = labelTable.get(name);
    if (known != null)
      return (Label)known;

    Label fresh = new Label(name);
    labelTable.put(name, fresh);
    return fresh;
  }


  /**
   * Add a symbol to the local symbol table and return it.
   * All arguments are passed through unchanged to the local symbol table.
   * No reference node is attached to the new symbol here.
   */
  public Symbol addSymbol(String name, int storage, int type,
                          int boundary, int offset, ImList opt) {
    return localSymtab.addSymbol(name, storage, type, boundary, offset, opt);
  }

  
  /* LirFactory interface implementation */

  /**
   * Return upper bound of LirNode id numbers, i.e. the current value
   * of the serial number generator.
   */
  public int idBound() {
    return this.idCounter;
  }


  /**
   * Create FLOATCONST node. If the constant table already holds a node
   * equal to the requested one, that node is returned and no id is used up.
   */
  public LirNode fconst(int type, double value, ImList opt) {
    LirNode candidate = new LirFconst(idCounter, type, value, opt);
    LirNode result = (LirNode)constantTable.get(candidate);
    if (result == null) {
      // First occurrence: register it and consume the id it was given.
      constantTable.put(candidate, candidate);
      idCounter++;
      result = candidate;
    }
    return result;
  }

  /**
   * Create INTCONST node. If the constant table already holds a node
   * equal to the requested one, that node is returned and no id is used up.
   */
  public LirNode iconst(int type, long value, ImList opt) {
    LirNode candidate = new LirIconst(idCounter, type, value, opt);
    LirNode result = (LirNode)constantTable.get(candidate);
    if (result == null) {
      // First occurrence: register it and consume the id it was given.
      constantTable.put(candidate, candidate);
      idCounter++;
      result = candidate;
    }
    return result;
  }

  /**
   * Create STATIC/FRAME/REG node. The node cached on the symbol is reused
   * when its options equal the requested ones; otherwise a new node is
   * created and becomes the symbol's cached reference node.
   * A null opt is treated as the empty list.
   */
  public LirNode symRef(int opCode, int type, Symbol symbol, ImList opt) {
    LirNode cached = symbol.refNode();
    ImList options = (opt == null) ? ImList.Empty : opt;
    if (cached != null && cached.opt.equals(options))
      return cached;

    LirNode fresh = new LirSymRef(idCounter++, opCode, type, symbol, options);
    symbol.setRefNode(fresh);
    return fresh;
  }

  /**
   * Create STATIC/FRAME/REG node, short form. The operator is chosen from
   * the symbol's storage class; a REG reference carries the symbol's own
   * type, FRAME and STATIC references are of type ADDRESS.
   */
  public LirNode symRef(Symbol symbol) {
    int storage = symbol.storage;
    if (storage == Storage.REG)
      return symRef(Op.REG, symbol.type, symbol, ImList.Empty);
    if (storage == Storage.FRAME)
      return symRef(Op.FRAME, Type.ADDRESS, symbol, ImList.Empty);
    if (storage == Storage.STATIC)
      return symRef(Op.STATIC, Type.ADDRESS, symbol, ImList.Empty);
    throw new CantHappenException();
  }

  /** Create SUBREG node; always allocates a new node with a new id. */
  public LirNode subReg(int opCode, int type, LirNode src, int pos, ImList opt) {
    final int id = idCounter++;
    return new LirSubReg(id, opCode, type, src, pos, opt);
  }

  /**
   * Create LABEL node. The node cached on the label is reused when its
   * options equal the requested ones; otherwise a new node is created and
   * becomes the label's cached reference node.
   * A null opt is treated as the empty list.
   */
  public LirNode labelRef(int opCode, int type, Label label, ImList opt) {
    LirNode cached = label.refNode();
    ImList options = (opt == null) ? ImList.Empty : opt;
    if (cached != null && cached.opt.equals(options))
      return cached;

    LirNode fresh = new LirLabelRef(idCounter++, opCode, type, label, options);
    label.setRefNode(fresh);
    return fresh;
  }

  /**
   * Create LABEL node, short form: operator LABEL, type ADDRESS,
   * no options.
   */
  public LirNode labelRef(Label label) {
    final ImList noOptions = ImList.Empty;
    return labelRef(Op.LABEL, Type.ADDRESS, label, noOptions);
  }


  /** Create unary operator node; always allocates a new node with a new id. */
  public LirNode operator(int opCode, int type, LirNode operand, ImList opt) {
    final int id = idCounter++;
    return new LirUnaOp(id, opCode, type, operand, opt);
  }

  /** Create binary operator node; always allocates a new node with a new id. */
  public LirNode operator(int opCode, int type, LirNode operand0, LirNode operand1, ImList opt) {
    final int id = idCounter++;
    return new LirBinOp(id, opCode, type, operand0, operand1, opt);
  }

  /**
   * Create ternary operator node. The three operands are packed into a
   * new array and stored in an n-ary operator node.
   */
  public LirNode operator(int opCode, int type, LirNode operand0,
                          LirNode operand1, LirNode operand2, ImList opt) {
    LirNode[] operands = new LirNode[] { operand0, operand1, operand2 };
    return new LirNaryOp(idCounter++, opCode, type, operands, opt);
  }

  /**
   * Create n-ary operator node; always allocates a new node with a new id.
   * The operand array is handed to the node as is.
   */
  public LirNode operator(int opCode, int type, LirNode operands[], ImList opt) {
    final int id = idCounter++;
    return new LirNaryOp(id, opCode, type, operands, opt);
  }

  /** Make a copy of node, using this function as the node factory. */
  public LirNode makeCopy(LirNode inst) {
    LirNode copy = inst.makeCopy(this);
    return copy;
  }

  /* end LirFactory interface implementation */


  /** Purge former analysis: forget every recorded analysis result. */
  public void purgeAnalysis() {
    this.analyses.clear();
  }

  /**
   * Apply some analysis: run the analyzer on this function, record the
   * result under the analyzer, and return it.
   */
  public LocalAnalysis apply(LocalAnalyzer analyzer) {
    final LocalAnalysis result = analyzer.doIt(this);
    this.analyses.put(analyzer, result);
    return result;
  }

  /**
   * Require analysis. Returns the recorded result of the analyzer if there
   * is one and it reports itself up to date; otherwise runs the analysis.
   */
  public LocalAnalysis require(LocalAnalyzer analyzer) {
    LocalAnalysis recorded = (LocalAnalysis)analyses.get(analyzer);
    if (recorded != null && recorded.isUpToDate())
      return recorded;
    return apply(analyzer);
  }

  /**
   * Apply some transformation/optimization by handing this function
   * to the transformer.
   */
  public void apply(LocalTransform transformer) {
    final Function target = this;
    transformer.doIt(target);
  }



  /**
   * Code Generation. Repeats instruction selection followed by register
   * allocation until the allocation reports success, and returns the
   * function produced by the last instruction selection.
   */
  public Function generateCode(TMD rule) throws SyntaxError, IOException {
    Function current = this;

    while (true) {
      // Instruction selection (or re-selection).
      current = current.doInstSel(rule);

      // Allocate registers; a false result means selection must be redone.
      RegisterAllocation allocator = new RegisterAllocation(current, debOut);
      if (allocator.allocate())
        return current;
    }
  }


  /**
   * Convert to machine dependent form.
   *
   * The function is printed in standard form, passed as text through
   * rule.restra and then rule.instsel, and the resulting text is read
   * back as a new Function of the same module.
   *
   * @param rule provides the restra and instsel text transformations
   * @return the function re-read from the transformed text
   * @throws SyntaxError if the transformed text cannot be parsed back
   * @throws IOException declared for the rule calls and the
   *         S-expression reader
   */
  public Function doInstSel(TMD rule) throws SyntaxError, IOException {
    // Pass this function to InstSel module
    //  (currently implemented by another program)

    // Collect the text directly as characters. Going through a byte buffer
    // would encode and decode with the platform default charset, which can
    // corrupt characters that charset cannot represent.
    StringWriter text = new StringWriter();
    PrintWriter out = new PrintWriter(text);
    printStandardForm(out);
    out.close();

    String rewrittenCode = rule.restra(text.toString());
    String machineDepStr = rule.instsel(rewrittenCode);

    // Reload matched code.
    Object sexp = ImList.readSexp(new PushbackReader
                                    (new StringReader(machineDepStr)));
    if (!(sexp instanceof ImList))
      throw new CantHappenException("readSexp returns null or atom object");

    return new Function(module, (ImList)sexp, debOut);
  }

  /**
   * Print L-function in standard form: a FUNCTION header line with the
   * quoted function name, the local symbol table, the body taken from
   * the control flow graph, and a closing parenthesis.
   */
  public void printStandardForm(PrintWriter out) {
    final String indent = "  ";
    final String header = "(FUNCTION \"" + symbol.name + "\"";

    out.println(header);
    localSymtab.printStandardForm(out, indent);  // symbol table
    flowGraph.printStandardForm(out, indent);    // body
    out.println(")");
  }


  /** Shared empty analysis array, used when no analyses are to be printed. */
  private static final LocalAnalysis[] emptyAnares = new LocalAnalysis[0];

  /** Dump internal data structure of the Function object, without analyses. */
  public void printIt(PrintWriter out) {
    printIt(out, emptyAnares);
  }


  /**
   * Dump internal data structure of the Function with some analyses.
   * Each analyzer's result is obtained through require(); a null array
   * is treated as "no analyses".
   */
  public void printIt(PrintWriter out, LocalAnalyzer[] anals) {
    if (anals == null) {
      printIt(out, emptyAnares);
      return;
    }

    LocalAnalysis[] results = new LocalAnalysis[anals.length];
    for (int k = 0; k < results.length; k++)
      results[k] = require(anals[k]);
    printIt(out, results);
  }


  /**
   * Dump internal data structure of the Function with some analyses.
   * Every analysis prints once before the function, is passed on to the
   * control flow graph dump, and prints once more after the function.
   */
  public void printIt(PrintWriter out, LocalAnalysis[] anals) {
    final int count = anals.length;

    // Heading
    out.println();
    out.println("Function \"" + symbol.name + "\":");

    int k = 0;
    while (k < count) {
      anals[k].printBeforeFunction(out);
      k++;
    }

    // Local symbol table
    out.print(" Local ");
    localSymtab.printIt(out);
    out.println();

    // Body
    out.println("Control Flow Graph:");
    flowGraph.printIt(out, anals);

    k = 0;
    while (k < count) {
      anals[k].printAfterFunction(out);
      k++;
    }
  }

}
