Skip to content

Commit 6154af9

Browse files
committed
Added support for binary addition
1 parent 2fd0f16 commit 6154af9

4 files changed

Lines changed: 65 additions & 52 deletions

File tree

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,11 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
455455
)(true, false)))
456456
else
457457
subTerm_nonTail(arg2): ar2 =>
458-
k(Call(Value.Ref(sym).withLocOf(ref), Arg(N, ar1) :: Arg(N, ar2) :: Nil)(true, false))
458+
val targetPath =
459+
if config.target == CompilationTarget.Wasm && sym.nme == "+"
460+
then Value.Ref(State.wasmSymbol).selN(Tree.Ident("plus_impl"))
461+
else Value.Ref(sym).withLocOf(ref)
462+
k(Call(targetPath, Arg(N, ar1) :: Arg(N, ar2) :: Nil)(true, false))
459463
case _ => fail:
460464
ErrorReport(
461465
msg"Unexpected arguments for builtin symbol '${sym.nme}'" -> arg.toLoc :: Nil, S(arg),
@@ -484,7 +488,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
484488
case t if t.resolvedSym.isDefined && (t.resolvedSym.get is ctx.builtins.js.try_catch) =>
485489
conclude(Value.Ref(State.runtimeSymbol).selN(Tree.Ident("try_catch")))
486490
case t if t.resolvedSym.exists(_ is ctx.builtins.wasm.plus_impl) =>
487-
conclude(Value.Ref(State.runtimeSymbol).selN(Tree.Ident("plus_impl")))
491+
conclude(Value.Ref(State.wasmSymbol).selN(Tree.Ident("plus_impl")))
488492
case t if t.resolvedSym.exists(_ is ctx.builtins.Int31) =>
489493
conclude(Value.Ref(State.runtimeSymbol).selN(Tree.Ident("Int31")))
490494
case t if t.resolvedSym.isDefined && (t.resolvedSym.get is ctx.builtins.debug.printStack) =>

hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/WatBuilder.scala

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -156,42 +156,19 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
156156
ref.func(funcIdx, RefType(ctx.getFuncInfo_!(l).typeIdx, nullable = false))
157157
case N => getVar(l, r.toLoc)
158158

159+
case c @ Call(fun, lhs :: rhs :: Nil) if isWasmIntrinsic(fun, "plus_impl") =>
160+
compilePlusIntrinsic(lhs, rhs)
161+
159162
case Call(Value.Ref(l: BuiltinSymbol), lhs :: rhs :: Nil) if !l.functionLike =>
160163
if l.binary then
161164
l.nme match
162165
case "+" =>
163-
// TODO(Derppening): Refactor to lower to `Call(plus_impl, ...)`
164-
def castOperand(expr: Expr, opSide: Str): Expr =
165-
expr.resultType match
166-
case S(RefType(HeapType.Any, _)) => `if`(
167-
ref.test(expr, RefType.i31ref),
168-
ifTrue = castOperand(ref.cast(expr, RefType.i31ref), opSide),
169-
ifFalse = S(unreachable),
170-
resultTypes = Seq(Result(I32Type))
171-
)
172-
case S(RefType(HeapType.I31, _)) => i31.get(expr, true)
173-
case S(I32Type) => expr
174-
case ty =>
175-
errExpr(
176-
Ls(
177-
msg"WatBuilder::result for binary builtin symbol '${l.nme.toString}' ($opSide.type=${ty.fold("(none)")(_.toWat.mkString())}) not implemented yet" -> r.toLoc
178-
),
179-
extraInfo = S(r.toString)
180-
)
181-
182-
val lhsOp = castOperand(operand(lhs), "lhs")
183-
val rhsOp = castOperand(operand(rhs), "rhs")
184-
185-
(lhsOp.resultType, rhsOp.resultType) match
186-
case (S(I32Type), S(I32Type)) =>
187-
ref.i31(i32.add(lhsOp, rhsOp))
188-
case (lhsType, rhsType) =>
189-
errExpr(
190-
Ls(
191-
msg"WatBuilder::result for binary builtin symbol '${l.nme.toString}' for (${lhsType.fold("(none)")(_.toWat.mkString())}, ${rhsType.fold("(none)")(_.toWat.mkString())}) not implemented yet" -> r.toLoc
192-
),
193-
extraInfo = S(r.toString)
194-
)
166+
errExpr(
167+
Ls(
168+
msg"WatBuilder::result encountered builtin '+' which should be lowered to wasm.plus_impl" -> r.toLoc
169+
),
170+
extraInfo = S(r.toString)
171+
)
195172
case lNme =>
196173
errExpr(
197174
Ls(
@@ -287,6 +264,25 @@ class WatBuilder(using TraceLogger, State) extends CodeBuilder:
287264
)
288265
end result
289266

267+
private def isWasmIntrinsic(path: Path, name: Str): Bool = path match
268+
case Select(Value.Ref(sym), ident) =>
269+
(sym eq State.wasmSymbol) && ident.name == name
270+
case _ => false
271+
272+
private def compilePlusIntrinsic(lhs: Arg, rhs: Arg)(using
273+
Ctx,
274+
Raise,
275+
Scope
276+
): Expr =
277+
val lhsCast = ref.cast(operand(lhs), RefType.i31ref)
278+
val rhsCast = ref.cast(operand(rhs), RefType.i31ref)
279+
ref.i31(
280+
i32.add(
281+
i31.get(lhsCast, true),
282+
i31.get(rhsCast, true)
283+
)
284+
)
285+
290286
def returningTerm(t: Block)(using Ctx, Raise, Scope): Expr = t match
291287
case _: HandleBlock =>
292288
errExpr(

hkmc2/shared/src/test/mlscript/wasm/Basics.mls

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,12 @@ foo() + foo()
7272
//│ (ref.func $foo)))
7373
//│ (ref.i31
7474
//│ (i32.add
75-
//│ (if (result i32)
76-
//│ (ref.test (ref null i31)
77-
//│ (local.get $tmp))
78-
//│ (then
79-
//│ (i31.get_s
80-
//│ (ref.cast (ref null i31)
81-
//│ (local.get $tmp))))
82-
//│ (else
83-
//│ (unreachable)))
84-
//│ (if (result i32)
85-
//│ (ref.test (ref null i31)
86-
//│ (local.get $tmp1))
87-
//│ (then
88-
//│ (i31.get_s
89-
//│ (ref.cast (ref null i31)
90-
//│ (local.get $tmp1))))
91-
//│ (else
92-
//│ (unreachable)))))))))
75+
//│ (i31.get_s
76+
//│ (ref.cast (ref null i31)
77+
//│ (local.get $tmp)))
78+
//│ (i31.get_s
79+
//│ (ref.cast (ref null i31)
80+
//│ (local.get $tmp1)))))))))
9381
//│ (export "entry" (func $entry))
9482
//│ (elem declare func $entry))
9583
//│ Wasm result:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
:global
3+
:wasm
4+
:wat
5+
6+
1 + 3
7+
//│ Wat:
8+
//│ (module
9+
//│ (type (func (result (ref null any))))
10+
//│ (func $entry (type 0) (result (ref null any))
11+
//│ (ref.i31
12+
//│ (i32.add
13+
//│ (i31.get_s
14+
//│ (ref.cast (ref null i31)
15+
//│ (ref.i31
16+
//│ (i32.const 1))))
17+
//│ (i31.get_s
18+
//│ (ref.cast (ref null i31)
19+
//│ (ref.i31
20+
//│ (i32.const 3)))))))
21+
//│ (export "entry" (func $entry))
22+
//│ (elem declare func $entry))
23+
//│ Wasm result:
24+
//│ = 4
25+
//│

0 commit comments

Comments
 (0)