From 2e6376aef7ba4613a90f5edd9916c9013f01c6dd Mon Sep 17 00:00:00 2001 From: bakhtiyartemirov Date: Sat, 1 Nov 2025 18:24:34 +0800 Subject: [PATCH 1/2] Added support for built-in operators --- .../codegen/wasm/text/Instructions.scala | 104 ++++ .../hkmc2/codegen/wasm/text/WatBuilder.scala | 180 ++++-- .../src/test/mlscript/wasm/Operators.mls | 517 ++++++++++++++++++ 3 files changed, 768 insertions(+), 33 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript/wasm/Operators.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala index 1d6b1abe12..501f0361ac 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala @@ -115,6 +115,110 @@ object Instructions: stackargs = Seq(lhs, rhs), resultType = S(I32Type) ) + + /** Creates an `i32.sub` instruction. */ + def sub(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.sub", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.mul` instruction. */ + def mul(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.mul", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.div_s` instruction. */ + def div_s(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.div_s", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.rem_s` instruction. */ + def rem_s(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.rem_s", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.and` instruction. */ + def and(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.and", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.or` instruction. */ + def or(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.or", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.eq` instruction. */ + def eq(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.eq", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.ne` instruction. */ + def ne(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.ne", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.lt_s` instruction. */ + def lt_s(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.lt_s", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.le_s` instruction. */ + def le_s(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.le_s", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.gt_s` instruction. */ + def gt_s(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.gt_s", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.ge_s` instruction. */ + def ge_s(lhs: Expr, rhs: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.ge_s", + instrargs = Seq.empty, + stackargs = Seq(lhs, rhs), + resultType = S(I32Type) + ) + + /** Creates an `i32.eqz` instruction. */ + def eqz(value: Expr): FoldedInstr = FoldedInstr( + mnemonic = "i32.eqz", + instrargs = Seq.empty, + stackargs = Seq(value), + resultType = S(I32Type) + ) end i32 object ref: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 967bbee669..54a437b544 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -107,6 +107,98 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: def operand(a: Arg)(using Ctx, Raise, Scope): Expr = if a.spread.nonEmpty then die else subexpression(a.value) + private def opTempPrefix(op: Str): Str = op match + case "+" => "plus" + case "-" => "minus" + case "*" => "mul" + case "/" => "div" + case "%" => "mod" + case "==" => "eq" + case "!=" => "ne" + case "<" => "lt" + case "<=" => "le" + case ">" => "gt" + case ">=" => "ge" + case "&&" => "and" + case "||" => "or" + case "!" => "not" + case other => s"op_${other.flatMap(_.toString)}" + + private def binaryI31Op( + lhs: Arg, + rhs: Arg, + opName: Str + )( + compute: (FoldedInstr, FoldedInstr) => FoldedInstr, + wrapResult: FoldedInstr => Expr = ref.i31 + )(using Ctx, Raise, Scope): Expr = + val lhsExpr = operand(lhs) + val rhsExpr = operand(rhs) + + val prefix = opTempPrefix(opName) + val lhsTmp = TempSymbol(N, s"${prefix}_lhs") + val rhsTmp = TempSymbol(N, s"${prefix}_rhs") + val lhsIdx = ctx.addLocal(lhsTmp) + val rhsIdx = ctx.addLocal(rhsTmp) + scope.allocateName(lhsTmp) + scope.allocateName(rhsTmp) + + val bothI31 = i32.and( + ref.test(local.get(lhsIdx, RefType.anyref), RefType.i31ref), + ref.test(local.get(rhsIdx, RefType.anyref), RefType.i31ref) + ) + + val lhsI32 = i31.get(ref.cast(local.get(lhsIdx, RefType.anyref), RefType.i31ref), true) + val rhsI32 = i31.get(ref.cast(local.get(rhsIdx, RefType.anyref), RefType.i31ref), true) + val resultExpr = wrapResult(compute(lhsI32, rhsI32)) + + Instructions.block( + label = N, + children = Seq( + local.set(lhsIdx, lhsExpr), + local.set(rhsIdx, rhsExpr), + `if`( + condition = bothI31, + ifTrue = resultExpr, + ifFalse = S(unreachable), + resultTypes = Seq(Result(RefType.i31ref)) + ) + ), + resultTypes = Seq(Result(RefType.i31ref)) + ) + + private def unaryI31Op( + arg: Arg, + opName: Str + )( + compute: (Expr, FoldedInstr) => Expr + )(using Ctx, Raise, Scope): Expr = + val argExpr = operand(arg) + val prefix = opTempPrefix(opName) + val argTmp = TempSymbol(N, s"${prefix}_arg") + val argIdx = ctx.addLocal(argTmp) + scope.allocateName(argTmp) + + val isI31 = ref.test(local.get(argIdx, RefType.anyref), RefType.i31ref) + val casted = ref.cast(local.get(argIdx, RefType.anyref), RefType.i31ref) + val argI32 = i31.get(casted, true) + val resultExpr = compute(casted, argI32) + val resultTypes = Seq(Result(RefType.i31ref)) + + Instructions.block( + label = N, + children = Seq( + local.set(argIdx, argExpr), + `if`( + condition = isI31, + ifTrue = resultExpr, + ifFalse = S(unreachable), + resultTypes = resultTypes + ) + ), + resultTypes = resultTypes + ) + def subexpression(r: codegen.Result)(using Ctx, Raise, Scope): Expr = r match case r: Lambda => errExpr( @@ -160,38 +252,31 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: if l.binary then l.nme match case "+" => - // TODO(Derppening): Refactor to lower to `Call(plus_impl, ...)` - def castOperand(expr: Expr, opSide: Str): Expr = - expr.resultType match - case S(RefType(HeapType.Any, _)) => `if`( - ref.test(expr, RefType.i31ref), - ifTrue = castOperand(ref.cast(expr, RefType.i31ref), opSide), - ifFalse = S(unreachable), - resultTypes = Seq(Result(I32Type)) - ) - case S(RefType(HeapType.I31, _)) => i31.get(expr, true) - case S(I32Type) => expr - case ty => - errExpr( - Ls( - msg"WatBuilder::result for binary builtin symbol '${l.nme.toString}' ($opSide.type=${ty.fold("(none)")(_.toWat.mkString())}) not implemented yet" -> r.toLoc - ), - extraInfo = S(r.toString) - ) - - val lhsOp = castOperand(operand(lhs), "lhs") - val rhsOp = castOperand(operand(rhs), "rhs") - - (lhsOp.resultType, rhsOp.resultType) match - case (S(I32Type), S(I32Type)) => - ref.i31(i32.add(lhsOp, rhsOp)) - case (lhsType, rhsType) => - errExpr( - Ls( - msg"WatBuilder::result for binary builtin symbol '${l.nme.toString}' for (${lhsType.fold("(none)")(_.toWat.mkString())}, ${rhsType.fold("(none)")(_.toWat.mkString())}) not implemented yet" -> r.toLoc - ), - extraInfo = S(r.toString) - ) + binaryI31Op(lhs, rhs, "+")((lhsI32, rhsI32) => i32.add(lhsI32, rhsI32)) + case "-" => + binaryI31Op(lhs, rhs, "-")((lhsI32, rhsI32) => i32.sub(lhsI32, rhsI32)) + case "*" => + binaryI31Op(lhs, rhs, "*")((lhsI32, rhsI32) => i32.mul(lhsI32, rhsI32)) + case "/" => + binaryI31Op(lhs, rhs, "/")((lhsI32, rhsI32) => i32.div_s(lhsI32, rhsI32)) + case "%" => + binaryI31Op(lhs, rhs, "%")((lhsI32, rhsI32) => i32.rem_s(lhsI32, rhsI32)) + case "==" => + binaryI31Op(lhs, rhs, "==")((lhsI32, rhsI32) => i32.eq(lhsI32, rhsI32)) + case "!=" => + binaryI31Op(lhs, rhs, "!=")((lhsI32, rhsI32) => i32.ne(lhsI32, rhsI32)) + case "<" => + binaryI31Op(lhs, rhs, "<")((lhsI32, rhsI32) => i32.lt_s(lhsI32, rhsI32)) + case "<=" => + binaryI31Op(lhs, rhs, "<=")((lhsI32, rhsI32) => i32.le_s(lhsI32, rhsI32)) + case ">" => + binaryI31Op(lhs, rhs, ">")((lhsI32, rhsI32) => i32.gt_s(lhsI32, rhsI32)) + case ">=" => + binaryI31Op(lhs, rhs, ">=")((lhsI32, rhsI32) => i32.ge_s(lhsI32, rhsI32)) + case "&&" => + binaryI31Op(lhs, rhs, "&&")((lhsI32, rhsI32) => i32.and(lhsI32, rhsI32)) + case "||" => + binaryI31Op(lhs, rhs, "||")((lhsI32, rhsI32) => i32.or(lhsI32, rhsI32)) case lNme => errExpr( Ls( @@ -204,6 +289,29 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: msg"Cannot call non-binary builtin symbol '${l.nme}'" -> r.toLoc )) + case Call(Value.Ref(l: BuiltinSymbol), arg :: Nil) if !l.functionLike => + if l.unary then + l.nme match + case "-" => + unaryI31Op(arg, "-")((_, value) => + ref.i31(i32.sub(i32.const(0), value)) + ) + case "+" => + unaryI31Op(arg, "+")((casted, _) => casted) + case "!" => + unaryI31Op(arg, "!")((_, value) => ref.i31(i32.eqz(value))) + case lNme => + errExpr( + Ls( + msg"WatBuilder::result for unary builtin symbol '${lNme.toString}' not implemented yet" -> r.toLoc + ), + extraInfo = S(r.toString) + ) + else + errExpr(Ls( + msg"Cannot call non-unary builtin symbol '${l.nme}'" -> r.toLoc + )) + case Call(fun, args) => val base = subexpression(fun) if base.resultTypes.exists(_ is UnreachableType) then return base @@ -630,8 +738,14 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: vars def block(t: Block)(using Ctx, Raise, Scope): (Expr, Seq[Local]) = + val localsBefore = ctx.getAllWasmLocals.headOption.getOrElse(Nil) val locals = blockPreamble(t.definedVars) - (returningTerm(t), locals) + val expr = returningTerm(t) + val localsAfter = ctx.getAllWasmLocals.headOption.getOrElse(Nil) + val beforeSet = localsBefore.toSet + val declaredSet = locals.toSet + val extraLocals = localsAfter.filter(sym => !beforeSet(sym) && !declaredSet(sym)) + (expr, locals ++ extraLocals) def body(t: Block)(using Ctx, Raise, Scope): (Expr, Seq[Local]) = scope.nest givenIn: diff --git a/hkmc2/shared/src/test/mlscript/wasm/Operators.mls b/hkmc2/shared/src/test/mlscript/wasm/Operators.mls new file mode 100644 index 0000000000..4958258511 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/wasm/Operators.mls @@ -0,0 +1,517 @@ +:wasm +:wat +2 + 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $plus_lhs (ref null any)) +//│ (local $plus_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.add +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 3 + +//│ Error: hkmc2.ErrorReport: Import of symbol `/Users/bakhtiyartemirov/Desktop/FYP/mlscript/hkmc2/shared/src/test/mlscript-compile/Predef.mjs` not implemented yet +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (nop)) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) + +2 - 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $minus_lhs (ref null any)) +//│ (local $minus_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.sub +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 1 + +2 * 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $mul_lhs (ref null any)) +//│ (local $mul_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.mul +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 2 + +2 / 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $div_lhs (ref null any)) +//│ (local $div_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.div_s +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 2 + +2 % 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $mod_lhs (ref null any)) +//│ (local $mod_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.rem_s +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 0 + +2 == 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $eq_lhs (ref null any)) +//│ (local $eq_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.eq +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 0 + +2 != 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $ne_lhs (ref null any)) +//│ (local $ne_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.ne +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 1 + +2 < 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $lt_lhs (ref null any)) +//│ (local $lt_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.lt_s +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 0 + +2 <= 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $le_lhs (ref null any)) +//│ (local $le_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.le_s +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 0 + +2 > 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $gt_lhs (ref null any)) +//│ (local $gt_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.gt_s +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 1 + +2 >= 1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $ge_lhs (ref null any)) +//│ (local $ge_rhs (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (local.set 1 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (i32.and +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (ref.test (ref null i31) +//│ (local.get 1))) +//│ (then +//│ (ref.i31 +//│ (i32.ge_s +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 1)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 1 + +-1 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $minus_arg (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 1))) +//│ (if (result (ref null i31)) +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (then +//│ (ref.i31 +//│ (i32.sub +//│ (i32.const 0) +//│ (i31.get_s +//│ (ref.cast (ref null i31) +//│ (local.get 0)))))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = -1 + ++2 +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $plus_arg (ref null any)) +//│ (block (result (ref null i31)) +//│ (local.set 0 +//│ (ref.i31 +//│ (i32.const 2))) +//│ (if (result (ref null i31)) +//│ (ref.test (ref null i31) +//│ (local.get 0)) +//│ (then +//│ (ref.cast (ref null i31) +//│ (local.get 0))) +//│ (else +//│ (unreachable))))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ = 2 + +!false +//│ FAILURE: Unexpected parse error +//│ FAILURE LOCATION: simpleExprImpl (Parser.scala:697) +//│ ╔══[PARSE ERROR] Expected an expression; found dynamic selector instead +//│ ║ l.458: !false +//│ ╙── ^^^^^^ +//│ FAILURE: Unexpected parse error +//│ FAILURE LOCATION: parseAll (Parser.scala:248) +//│ ╔══[PARSE ERROR] Expected end of input; found dynamic selector instead +//│ ║ l.458: !false +//│ ╙── ^^^^^^ +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (nop)) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: mkQuery (WasmDiffMaker.scala:128) +//│ ═══[RUNTIME ERROR] undefined +//│ // Standard Error: +//│ Fatal: 1:54: error: popping from empty stack +//│ +//│ + +true && false +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: errExpr (WatBuilder.scala:516) +//│ FAILURE INFO: FunDefn: +//│ owner = N +//│ sym = member:lambda +//│ params = Ls of +//│ ParamList: +//│ flags = () +//│ params = Nil +//│ restParam = N +//│ body = Return: +//│ res = Lit of BoolLit of false +//│ implct = false +//│ ═══[COMPILATION ERROR] WatBuilder::returningTerm for FunDefn(...) where `!sym.nameIsMeaningful` not implemented yet +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: errExpr (WatBuilder.scala:96) +//│ FAILURE INFO: Block IR: `$runtime` +//│ Scope: Scope(Some(Scope(None,Some(Some(globalThis:globalThis)),HashMap($wasm -> wasm))),None,HashMap(member:lambda -> lambda)) +//│ Wasm Locals: List(List(member:lambda)) +//│ ═══[COMPILATION ERROR] WatBuilder::getVar for TempSymbol (symbol not in top-level scope) not implemented yet +//│ FAILURE: Unexpected exception +//│ /!!!\ Uncaught error: java.lang.Exception: Internal Error: Symbol for Select(...) expression must be resolved +//│ at: mlscript.utils.package$.lastWords(package.scala:230) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.$anonfun$5(WatBuilder.scala:343) +//│ at: scala.Option.getOrElse(Option.scala:201) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.result(WatBuilder.scala:343) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.subexpression(WatBuilder.scala:208) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.result(WatBuilder.scala:316) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.returningTerm(WatBuilder.scala:649) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.returningTerm(WatBuilder.scala:632) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.block(WatBuilder.scala:743) +//│ at: hkmc2.codegen.wasm.text.WatBuilder.program(WatBuilder.scala:707) From 941a2836050adb148832350f9fbf4cfc42dd5638 Mon Sep 17 00:00:00 2001 From: bakhtiyartemirov Date: Sat, 1 Nov 2025 21:03:52 +0800 Subject: [PATCH 2/2] Added support for Label, Break, Continue --- .../codegen/wasm/text/Instructions.scala | 23 ++++ .../hkmc2/codegen/wasm/text/WatBuilder.scala | 64 ++++++++++ .../shared/src/test/mlscript/wasm/ConFlow.mls | 109 ++++++++++++++++++ 3 files changed, 196 insertions(+) create mode 100644 hkmc2/shared/src/test/mlscript/wasm/ConFlow.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala index 501f0361ac..be7cbda361 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Instructions.scala @@ -21,6 +21,21 @@ object Instructions: resultTypes = resultTypes.map(_.valtype) ) + /** Creates a `loop` instruction. */ + def loop( + label: Opt[Str], + children: Seq[Expr], + resultTypes: Seq[Result] + ): FoldedInstr = + val labelWat = label.map(lbl => doc"$$$lbl") + + FoldedInstr( + mnemonic = "loop", + instrargs = labelWat.toSeq ++ resultTypes, + stackargs = children, + resultTypes = resultTypes.map(_.valtype) + ) + /** Creates an `if` instruction. */ def `if`( condition: Expr, @@ -99,6 +114,14 @@ object Instructions: resultType = S(UnreachableType) ) + /** Creates a br (branch) instruction. */ + def br(label: Str): FoldedInstr = FoldedInstr( + mnemonic = "br", + instrargs = Seq(doc"$$$label"), + stackargs = Seq.empty, + resultType = S(UnreachableType) + ) + object i32: /** Creates an `i32.const` instruction. */ def const(value: Int): FoldedInstr = FoldedInstr( diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala index 54a437b544..5612fe7b36 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala @@ -34,6 +34,21 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: type Context = Ctx + private case class LabelContext(symbol: Local, breakTarget: Str, continueTarget: Str) + + private var labelContextStack: List[LabelContext] = Nil + + private def pushLabelContext(ctx: LabelContext): Unit = + labelContextStack = ctx :: labelContextStack + + private def popLabelContext(): Unit = + labelContextStack = labelContextStack match + case _ :: tail => tail + case Nil => Nil + + private def lookupLabelContext(symbol: Local): Opt[LabelContext] = + labelContextStack.find(_.symbol == symbol) + /** * Raises a [[WarningReport]] with the given `warnMsgs` and `extraInfo`, and emits an * `unreachable` instruction. @@ -672,6 +687,55 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder: `return`(S(resWat)) + case Label(label, body, rest) => + val breakTarget = scope.allocateName(label) + val loopSym = TempSymbol(N, "loop") + val continueTarget = scope.allocateName(loopSym) + + pushLabelContext(LabelContext(label, breakTarget, continueTarget)) + val bodyExpr = + try returningTerm(body) + finally popLabelContext() + val restExpr = returningTerm(rest) + + Instructions.block( + label = N, + children = Seq( + Instructions.block( + label = S(breakTarget), + children = Seq( + Instructions.loop( + label = S(continueTarget), + children = Seq( + bodyExpr, + br(breakTarget) + ), + resultTypes = Seq.empty + ) + ), + resultTypes = Seq.empty + ), + restExpr + ), + resultTypes = restExpr.resultTypes.map(ty => Result(ty.asValType_!)) + ) + case Break(label) => + lookupLabelContext(label) match + case S(ctx) => br(ctx.breakTarget) + case N => + errExpr( + Ls(msg"WatBuilder::returningTerm encountered break to unknown label `${label.nme}`" -> label.toLoc), + extraInfo = S(label) + ) + case Continue(label) => + lookupLabelContext(label) match + case S(ctx) => br(ctx.continueTarget) + case N => + errExpr( + Ls(msg"WatBuilder::returningTerm encountered continue to unknown label `${label.nme}`" -> label.toLoc), + extraInfo = S(label) + ) + case End(_) => nop case t => diff --git a/hkmc2/shared/src/test/mlscript/wasm/ConFlow.mls b/hkmc2/shared/src/test/mlscript/wasm/ConFlow.mls new file mode 100644 index 0000000000..7202817882 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/wasm/ConFlow.mls @@ -0,0 +1,109 @@ +:wat +:wasm +mut let i = 0 +while i < 10 do + set i = i + 1 +i +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: subterm (Elaborator.scala:734) +//│ ╔══[ERROR] Illegal position for 'mut' modifier. +//│ ║ l.3: mut let i = 0 +//│ ╙── ^^^ +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: subterm (Elaborator.scala:370) +//│ ╔══[ERROR] Expected a body for let bindings in expression position +//│ ║ l.3: mut let i = 0 +//│ ╙── ^^^^^ +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: subterm (Elaborator.scala:443) +//│ ╔══[ERROR] Name not found: i +//│ ║ l.4: while i < 10 do +//│ ╙── ^ +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: subterm (Elaborator.scala:443) +//│ ╔══[ERROR] Name not found: i +//│ ║ l.5: set i = i + 1 +//│ ╙── ^ +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: subterm (Elaborator.scala:443) +//│ ╔══[ERROR] Name not found: i +//│ ║ l.5: set i = i + 1 +//│ ╙── ^ +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: subterm (Elaborator.scala:443) +//│ ╔══[ERROR] Name not found: i +//│ ║ l.6: i +//│ ╙── ^ +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $i (ref null any)) +//│ (block +//│ (local.set $i +//│ (ref.i31 +//│ (i32.const 0))) +//│ (block +//│ (block $tmp +//│ (loop $loop +//│ (nop) +//│ (br $tmp))) +//│ (nop)))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +//│ Wasm result: +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: mkQuery (WasmDiffMaker.scala:128) +//│ ═══[RUNTIME ERROR] undefined +//│ // Standard Error: +//│ Fatal: 1:54: error: popping from empty stack +//│ +//│ + +//│ Error: hkmc2.ErrorReport: Import of symbol `/Users/bakhtiyartemirov/Desktop/FYP/mlscript/hkmc2/shared/src/test/mlscript-compile/Predef.mjs` not implemented yet +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (nop)) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry)) +while true do + 0 +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: errExpr (WatBuilder.scala:745) +//│ FAILURE INFO: Match: +//│ scrut = Ref of $scrut +//│ arms = Ls of +//│ Tuple2: +//│ _1 = Lit of BoolLit of true +//│ _2 = Assign: +//│ lhs = $tmp +//│ rhs = Lit of IntLit of 0 +//│ rest = Continue of $tmp +//│ dflt = S of Assign: +//│ lhs = $tmp +//│ rhs = Select{object:Unit}: +//│ qual = Ref of $runtime +//│ name = Ident of "Unit" +//│ rest = End of "" +//│ rest = End of "" +//│ ═══[COMPILATION ERROR] WatBuilder::returningTerm for expression not implemented yet +//│ Wat: +//│ (module +//│ (type (func (result (ref null any)))) +//│ (func $entry (type 0) (result (ref null any)) +//│ (local $scrut (ref null any)) +//│ (local $tmp (ref null any)) +//│ (block (result (ref null any)) +//│ (block $tmp1 +//│ (loop $loop +//│ (block (result (ref null any)) +//│ (local.set $scrut +//│ (ref.i31 +//│ (i32.const 1))) +//│ (unreachable)) +//│ (br $tmp1))) +//│ (local.get $tmp))) +//│ (export "entry" (func $entry)) +//│ (elem declare func $entry))