compiler: constants are now put directly inside the program's global table where possible (i.e. whenever the global value would be constant-ish anyway).

This commit is contained in:
2026-05-17 20:05:47 +03:00
parent 1ca4ac2b79
commit df1fad751f
2 changed files with 116 additions and 48 deletions
+59 -10
View File
@@ -5,6 +5,9 @@ module SymbolTable = Scope_analysis.SymbolTable
type instr = Vm.Types.instr type instr = Vm.Types.instr
type pre_global =
| Global of Vm.Types.value
| BackPatchClosure
type pre_instr = type pre_instr =
| Instr of instr | Instr of instr
| BackPatchMkClosure of int | BackPatchMkClosure of int
@@ -13,11 +16,13 @@ type pre_instr =
type program = { type program = {
instrs : pre_instr Dynarray.t; instrs : pre_instr Dynarray.t;
constants : Vm.Types.value Dynarray.t; constants : Vm.Types.value Dynarray.t;
globals : pre_global Dynarray.t;
sym_table : int SymbolTable.t; sym_table : int SymbolTable.t;
(* This array holds the lambda bodies that we have to compiler later, and (* This array holds the lambda bodies that we have to compiler later, and
the index we have to patch the address back into. the index we have to patch the address back into.
*) *)
backpatch : (int * expression) Queue.t; backpatch : (int * expression) Queue.t;
backpatch_const_q : (int * int * expression) Queue.t;
} }
let ( let* ) = Result.bind let ( let* ) = Result.bind
@@ -93,6 +98,8 @@ let rec compile_one p = function
let* _ = compile_one p e1 in let* _ = compile_one p e1 in
let* _ = emit_instr p Vm.Types.Pop in let* _ = emit_instr p Vm.Types.Pop in
compile_one p (Begin (e2 :: rest)) compile_one p (Begin (e2 :: rest))
| Native i ->
emit_constant p (Vm.Types.Native i)
and compile_all p exprs = and compile_all p exprs =
Util.traverse Util.traverse
@@ -109,19 +116,33 @@ and compile_all_no_pop p exprs =
lambdas - that should be fine, they'll just get added to the end lambdas - that should be fine, they'll just get added to the end
of the backpatch queue. of the backpatch queue.
*) *)
let backpatch_one p (i, b) = let backpatch_one_instr p (i, b) =
match Dynarray.get p.instrs i with match Dynarray.get p.instrs i with
| BackPatchMkClosure arg_count -> | BackPatchMkClosure arg_count ->
Dynarray.set p.instrs i (Instr (MakeClosure (arg_count, current_index p))); Dynarray.set p.instrs i (Instr (MakeClosure (arg_count, current_index p)));
let* _ = compile_one p b in let* _ = compile_one p b in
emit_instr p End emit_instr p End
| _ -> failwith "Can't backpatch anything other than a MakeClosure after compilation" | _ -> failwith "Can't backpatch anything other than a MakeClosure after compilation"
let rec backpatch p = let rec backpatch_instrs p =
if Queue.is_empty p.backpatch then if Queue.is_empty p.backpatch then
Ok () Ok ()
else else
(let* _ = backpatch_one p (Queue.pop p.backpatch) in (let* _ = backpatch_one_instr p (Queue.pop p.backpatch) in
backpatch p) backpatch_instrs p)
let backpatch_one_const p (i, arg_count, b) =
let instr_loc = Dynarray.length p.instrs in
let* _ = compile_one p b in
let* _ = emit_instr p End in
Ok (Dynarray.set p.globals i (Global (Vm.Types.Closure (arg_count, instr_loc, []))))
let rec backpatch_consts p =
if Queue.is_empty p.backpatch_const_q then
Ok ()
else
(let* _ = backpatch_one_const p (Queue.pop p.backpatch_const_q) in
backpatch_consts p)
let backpatch p =
let* () = backpatch_instrs p in
backpatch_consts p
let print_instr = function let print_instr = function
@@ -131,24 +152,52 @@ let print_instr = function
let print_instrs = let print_instrs =
Array.mapi_inplace (fun i ins -> Array.mapi_inplace (fun i ins ->
print_endline (Printf.sprintf "%d: %s" i (print_instr ins)); ins) print_endline (Printf.sprintf "%d: %s" i (print_instr ins)); ins)
let smooth_one = function let smooth_one_instr = function
| Instr i -> i | Instr i -> i
| _ -> failwith "backpatching process was not complete!" | _ -> failwith "backpatching process was not complete! (instrs)"
let smooth_instrs p = let smooth_instrs p =
Dynarray.to_array (Dynarray.map smooth_one p.instrs) Dynarray.to_array (Dynarray.map smooth_one_instr p.instrs)
let smooth_one_global = function
| Global c -> c
| _ -> failwith "backpatching process was not complete! (consts)"
let smooth_globals p =
Dynarray.to_array (Dynarray.map smooth_one_global p.globals)
let compile (exprs : expression list) (tbl : int SymbolTable.t) = let rec constantify = function
| Core_ast.Nil -> Vm.Types.Nil
| Core_ast.Int x -> Vm.Types.Int x
| Core_ast.String s -> Vm.Types.String s
| Core_ast.Double x -> Vm.Types.Double x
| Core_ast.Cons (a, b) -> Vm.Types.Cons (constantify a, constantify b)
| Core_ast.Symbol s -> Vm.Types.Symbol s
let mk_constants (tbl : (int * expression) SymbolTable.t) =
let constants = Dynarray.make ((SymbolTable.cardinal tbl) + 1) (Global Vm.Types.Nil) in
let to_backpatch = Queue.create () in
let () = SymbolTable.iter (fun _ (i, v) -> Dynarray.set constants i (match v with
| Scope_analysis.Lambda (a, b) -> Queue.add (i, a, b) to_backpatch; BackPatchClosure
| Scope_analysis.Literal l -> Global (constantify l)
| Native i -> Global (Vm.Types.Native i)
| _ -> Global Vm.Types.Nil)) tbl in
(constants, to_backpatch)
let compile (exprs : expression list) (tbl : (int * expression) SymbolTable.t) =
let (globals, backpatch_const_q) = mk_constants tbl in
let program = { let program = {
instrs=Dynarray.create (); instrs=Dynarray.create ();
constants=Dynarray.create(); constants=Dynarray.create();
sym_table=tbl; globals=globals;
sym_table=SymbolTable.map (fun (a, _) -> a) tbl;
backpatch=Queue.create (); backpatch=Queue.create ();
backpatch_const_q=backpatch_const_q;
} in } in
let* _ = compile_all program exprs in let* _ = compile_all program exprs in
let* _ = emit_instr program End in let* _ = emit_instr program End in
let* _ = backpatch program in let* _ = backpatch program in
let final_instrs = smooth_instrs program in let final_instrs = smooth_instrs program in
Ok (Vm.make_vm final_instrs (Dynarray.to_array program.constants) ((SymbolTable.cardinal tbl) + 1)) let final_globals = smooth_globals program in
let () = print_endline "constants:"; Array.iter (fun v -> print_endline(Vm.Types.print_value v)) final_globals in
Ok (Vm.make_vm final_instrs (Dynarray.to_array program.constants) final_globals) (*((SymbolTable.cardinal tbl) + 1))*)
let compile_src src = let compile_src src =
let* (exprs, tbl) = Scope_analysis.of_src src in let* (exprs, tbl) = Scope_analysis.of_src src in
+56 -37
View File
@@ -38,6 +38,9 @@ type expression =
| If of expression * expression * expression | If of expression * expression * expression
| Set of variable * expression | Set of variable * expression
| Begin of expression list | Begin of expression list
| Native of int
(* Native is effectively a VM primitive. Emitted here for convenience. *)
(* IMPORTANT: (* IMPORTANT:
This is a predefined global table. This is a predefined global table.
@@ -54,8 +57,8 @@ type expression =
*) *)
let default_global_table = let default_global_table =
SymbolTable.of_list [ SymbolTable.of_list [
("print", 0); ("print", (0, Native 0));
("add", 1) ("add", (1, Native 1))
] ]
(* extract all defined global symbols, given the top-level expressions (* extract all defined global symbols, given the top-level expressions
@@ -72,7 +75,7 @@ let extract_globals (top : Core_ast.top_level list) =
let rec aux tbl = function let rec aux tbl = function
| [] -> tbl | [] -> tbl
| Core_ast.Define (sym, _) :: rest -> | Core_ast.Define (sym, _) :: rest ->
aux (SymbolTable.add sym (id ()) tbl) rest aux (SymbolTable.add sym ((id ()), Literal Nil) tbl) rest
| Expr _ :: rest -> | Expr _ :: rest ->
aux tbl rest aux tbl rest
in aux default_global_table top in aux default_global_table top
@@ -86,7 +89,7 @@ let extract_globals (top : Core_ast.top_level list) =
let resolve_global tbl sym = let resolve_global tbl sym =
match SymbolTable.find_opt sym tbl with match SymbolTable.find_opt sym tbl with
| Some x -> Ok (Global x) | Some (x, _) -> Ok (Global x)
| None -> Error ("symbol " ^ sym ^ " is not defined!") | None -> Error ("symbol " ^ sym ^ " is not defined!")
(* First we try to resolve it to a local symbol, then look it up in the (* First we try to resolve it to a local symbol, then look it up in the
@@ -118,6 +121,39 @@ let extract_functions exprs =
let fs = List.map Option.get fs in let fs = List.map Option.get fs in
List.fold_left (fun t (s, args, rest) -> SymbolTable.add s (args, rest) t) SymbolTable.empty fs List.fold_left (fun t (s, args, rest) -> SymbolTable.add s (args, rest) t) SymbolTable.empty fs
let rec analyze global_tbl =
let rec aux tbl current = function
| Core_ast.Literal s -> Ok (Literal s)
| Var sym -> resolve_var tbl current sym
| Set (sym, expr) ->
let* inner = analyze global_tbl tbl current expr in
resolve_set tbl current sym inner
| Lambda (args, rest, body) ->
let args = (match rest with
| Some s -> List.append args [s]
| None -> args) in
let* body = (aux global_tbl (args :: current) body) in
Ok (Lambda (List.length args, body))
| Apply (f, es) ->
let* f = aux tbl current f in
let* e = Util.traverse (aux tbl current) es in
Ok (Apply (f, e))
| If (test, pos, neg) ->
let* test = aux tbl current test in
let* pos = aux tbl current pos in
let* neg = aux tbl current neg in
Ok (If (test, pos, neg))
| Begin el ->
let* body = traverse (aux tbl current) el in
Ok (Begin body)
in aux
let is_constantish = function
| Literal _ -> true
| Lambda _ -> true
| Native _ -> true
| _ -> false
(* We need to do some more sophisticated analysis to detect cases where (* We need to do some more sophisticated analysis to detect cases where
a symbol is accessed before it is defined. a symbol is accessed before it is defined.
If a symbol is accessed in a lambda body, that is fine, since that computation If a symbol is accessed in a lambda body, that is fine, since that computation
@@ -138,41 +174,24 @@ let extract_functions exprs =
I may consider adding special support for let forms, as this is pretty annoying. I may consider adding special support for let forms, as this is pretty annoying.
*) *)
let convert program = let convert program =
let global_tbl = extract_globals program in let global_tbl = ref (extract_globals program) in
let rec analyze tbl current = function let rec aux tbl = function
| Core_ast.Literal s -> Ok (Literal s) | [] -> Ok []
| Var sym -> resolve_var tbl current sym | (Core_ast.Expr e) :: rest ->
| Set (sym, expr) -> let* analysis = (analyze !global_tbl tbl [] e) in
let* inner = analyze tbl current expr in let* rest = aux tbl rest in
resolve_set tbl current sym inner Ok (analysis :: rest)
| Lambda (args, rest, body) ->
let args = (match rest with
| Some s -> List.append args [s]
| None -> args) in
let* body = (analyze global_tbl (args :: current) body) in
Ok (Lambda (List.length args, body))
| Apply (f, es) ->
let* f = analyze tbl current f in
let* e = Util.traverse (analyze tbl current) es in
Ok (Apply (f, e))
| If (test, pos, neg) ->
let* test = analyze tbl current test in
let* pos = analyze tbl current pos in
let* neg = analyze tbl current neg in
Ok (If (test, pos, neg))
| Begin el ->
let* body = traverse (analyze tbl current) el in
Ok (Begin body)
in
let[@tail_mod_cons] rec aux tbl = function
| [] -> []
| (Core_ast.Expr e) :: rest -> (analyze tbl [] e) :: (aux tbl rest)
| (Define (s, e)) :: rest -> | (Define (s, e)) :: rest ->
let tbl = SymbolTable.add s (SymbolTable.find s global_tbl) tbl in let (id, _) = SymbolTable.find s !global_tbl in
(analyze tbl [] (Set (s, e))) :: (aux tbl rest) let* analysis = analyze !global_tbl tbl [] e in
global_tbl := SymbolTable.remove s !global_tbl;
global_tbl := SymbolTable.add s (id, analysis) !global_tbl;
let tbl = SymbolTable.add s (SymbolTable.find s !global_tbl) tbl in
let* rest = aux tbl rest in
if is_constantish analysis then Ok (rest) else Ok (analysis :: rest)
in in
let* program = traverse (fun x -> x) (aux default_global_table program) in let* program = (aux default_global_table program) in
Ok (program, global_tbl) Ok (program, !global_tbl)
let of_src src = let of_src src =
let* core = (Core_ast.of_src src) in let* core = (Core_ast.of_src src) in