Skip to content

Commit beeb921

Browse files
author
DigitalCodeCrafter
committed
added type environment for future type checking
1 parent 76a611f commit beeb921

3 files changed

Lines changed: 209 additions & 3 deletions

File tree

src/compiler/resolver.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use std::collections::HashMap;
22
use crate::compiler::{CompilerError, ast::{AST, NodeId, NodeKind}};
33

4-
type ScopeId = usize;
5-
type SymbolId = usize;
4+
pub type ScopeId = usize;
5+
pub type SymbolId = usize;
66

77
#[derive(Debug, Clone)]
88
pub enum ResolveError {
@@ -250,6 +250,7 @@ impl<'a> Resolver<'a> {
250250

251251
// add parameters
252252
for (param_name, _) in params.iter() {
253+
// FIXME: params shadow function in binding table due to same node_id
253254
let param_sym = Symbol {
254255
name: param_name.clone(),
255256
kind: SymbolKind::Variable,

src/compiler/type_checker.rs

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
2+
pub type TypeId = usize;
3+
4+
#[derive(Debug)]
5+
pub enum UnifyError {
6+
Mismatch(TypeId, TypeId),
7+
}
8+
9+
mod env {
10+
use std::collections::HashMap;
11+
use crate::compiler::resolver::SymbolId;
12+
use super::{TypeId, UnifyError};
13+
14+
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
15+
pub enum TypeKind {
16+
/// Built-in primitive type
17+
Primitive(String),
18+
19+
// User-defined type (structs, enums, aliases)
20+
User(SymbolId),
21+
22+
/// Type variable (for inferrence)
23+
Var(u32),
24+
25+
/// Function type: (params) -> return
26+
Function {
27+
params: Vec<TypeId>,
28+
ret: TypeId
29+
},
30+
31+
/// Tuple type: (T1, T2, ...)
32+
Tuple(Vec<TypeId>),
33+
34+
/// Array type: [T; N]
35+
Array {
36+
elem: TypeId,
37+
len: Option<u32>
38+
},
39+
40+
/// Generic instantiation: e.g. Option<T>, Vec<T>
41+
Generic {
42+
base: TypeId,
43+
args: Vec<TypeId>
44+
},
45+
}
46+
47+
pub struct TypeEnv {
48+
arena: Vec<TypeKind>,
49+
subst: HashMap<u32, TypeId>, // type var -> resolved type
50+
builtins: HashMap<String, TypeId>, // "i32" -> TypeId
51+
interned: HashMap<TypeKind, TypeId>, // type interning
52+
next_var: u32,
53+
}
54+
55+
impl TypeEnv {
56+
pub fn new() -> Self {
57+
let mut env = TypeEnv {
58+
arena: Vec::new(),
59+
subst: HashMap::new(),
60+
builtins: HashMap::new(),
61+
interned: HashMap::new(),
62+
next_var: 0,
63+
};
64+
65+
for name in ["i32", "f64", "bool", "str", "unit"] {
66+
env.intern_builtin(name);
67+
}
68+
69+
env
70+
}
71+
72+
fn intern_builtin(&mut self, name: &str) -> TypeId {
73+
let ty = TypeKind::Primitive(name.to_string());
74+
let id = self.intern(ty.clone());
75+
self.builtins.insert(name.to_string(), id);
76+
id
77+
}
78+
79+
pub fn intern(&mut self, ty: TypeKind) -> TypeId {
80+
if let Some(&id) = self.interned.get(&ty) {
81+
return id;
82+
}
83+
let id = self.arena.len();
84+
self.arena.push(ty.clone());
85+
self.interned.insert(ty, id);
86+
id
87+
}
88+
89+
pub fn new_var(&mut self) -> TypeId {
90+
let id = self.next_var;
91+
self.next_var += 1;
92+
self.intern(TypeKind::Var(id))
93+
}
94+
95+
pub fn get(&self, id: TypeId) -> &TypeKind {
96+
&self.arena[id]
97+
}
98+
99+
pub fn get_builtin(&self, name: &str) -> Option<TypeId> {
100+
self.builtins.get(name).copied()
101+
}
102+
103+
pub fn normalize(&mut self, ty: TypeId) -> TypeId {
104+
match &self.get(ty) {
105+
TypeKind::Var(v) => {
106+
let v = *v;
107+
if let Some(resolved) = self.subst.get(&v) {
108+
let normalized = self.normalize(*resolved);
109+
110+
self.subst.insert(v, normalized);
111+
normalized
112+
} else {
113+
ty
114+
}
115+
}
116+
_ => ty,
117+
}
118+
}
119+
120+
pub fn unify(&mut self, a: TypeId, b: TypeId) -> Result<TypeId, UnifyError> {
121+
let a = self.normalize(a);
122+
let b = self.normalize(b);
123+
124+
if a == b {
125+
return Ok(a);
126+
}
127+
128+
match (self.get(a).clone(), self.get(b).clone()) {
129+
// unify type variables
130+
(TypeKind::Var(va), _) => {
131+
self.subst.insert(va, b);
132+
Ok(b)
133+
}
134+
(_, TypeKind::Var(vb)) => {
135+
self.subst.insert(vb, a);
136+
Ok(a)
137+
}
138+
139+
// unify functions
140+
(
141+
TypeKind::Function { params: pa, ret: ra },
142+
TypeKind::Function { params: pb, ret: rb }
143+
) => {
144+
if pa.len() != pb.len() {
145+
return Err(UnifyError::Mismatch(a, b));
146+
}
147+
for (ta, tb) in pa.iter().zip(pb.iter()) {
148+
self.unify(*ta, *tb)?;
149+
}
150+
self.unify(ra, rb)
151+
}
152+
153+
// unify tuples
154+
(TypeKind::Tuple(ta), TypeKind::Tuple(tb)) => {
155+
if ta.len() != tb.len() {
156+
return Err(UnifyError::Mismatch(a, b));
157+
}
158+
for (xa, xb) in ta.iter().zip(tb.iter()) {
159+
self.unify(*xa, *xb)?;
160+
}
161+
Ok(a)
162+
}
163+
164+
// unify arrays
165+
(
166+
TypeKind::Array { elem: ea, len: la },
167+
TypeKind::Array { elem: eb, len: lb },
168+
) if la == lb => self.unify(ea, eb),
169+
170+
// identical user types
171+
(TypeKind::User(sa), TypeKind::User(sb)) if sa == sb => Ok(a),
172+
173+
// indetical primitives (failsafe)
174+
(TypeKind::Primitive(na), TypeKind::Primitive(nb)) if na == nb => Ok(a),
175+
176+
_ => Err(UnifyError::Mismatch(a, b))
177+
}
178+
}
179+
}
180+
181+
#[cfg(test)]
182+
mod tests {
183+
use super::*;
184+
185+
#[test]
186+
fn simple_inferrence() {
187+
let mut tenv = TypeEnv::new();
188+
189+
let i32_t = tenv.get_builtin("i32").unwrap();
190+
191+
let mut tvar_a = tenv.new_var();
192+
let mut tvar_b = tenv.new_var();
193+
194+
// a = i32
195+
tenv.unify(tvar_a, i32_t).unwrap();
196+
// a = b
197+
tenv.unify(tvar_b, tvar_a).unwrap();
198+
199+
tvar_a = tenv.normalize(tvar_a);
200+
tvar_b = tenv.normalize(tvar_b);
201+
202+
assert_eq!(tenv.get(tvar_a), tenv.get(tvar_b))
203+
}
204+
}
205+
}

src/kep_grammar.ebnf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Statement = ";"
2828
| LetStatement
2929
| ExpressionStatement ;
3030
31-
LetStatement = "let", [ "mut" ], Identifier, [ ":", Type ], "=", Expression, ";" ;
31+
LetStatement = "let", [ "mut" ], Identifier, [ ":", Type ], [ "=", Expression ], ";";
3232
3333
ExpressionStatement = Expression, ";" ;
3434

0 commit comments

Comments
 (0)