1+
2+ pub type TypeId = usize ;
3+
4+ #[ derive( Debug ) ]
5+ pub enum UnifyError {
6+ Mismatch ( TypeId , TypeId ) ,
7+ }
8+
9+ mod env {
10+ use std:: collections:: HashMap ;
11+ use crate :: compiler:: resolver:: SymbolId ;
12+ use super :: { TypeId , UnifyError } ;
13+
14+ #[ derive( Clone , Debug , PartialEq , Eq , Hash ) ]
15+ pub enum TypeKind {
16+ /// Built-in primitive type
17+ Primitive ( String ) ,
18+
19+ // User-defined type (structs, enums, aliases)
20+ User ( SymbolId ) ,
21+
22+ /// Type variable (for inferrence)
23+ Var ( u32 ) ,
24+
25+ /// Function type: (params) -> return
26+ Function {
27+ params : Vec < TypeId > ,
28+ ret : TypeId
29+ } ,
30+
31+ /// Tuple type: (T1, T2, ...)
32+ Tuple ( Vec < TypeId > ) ,
33+
34+ /// Array type: [T; N]
35+ Array {
36+ elem : TypeId ,
37+ len : Option < u32 >
38+ } ,
39+
40+ /// Generic instantiation: e.g. Option<T>, Vec<T>
41+ Generic {
42+ base : TypeId ,
43+ args : Vec < TypeId >
44+ } ,
45+ }
46+
47+ pub struct TypeEnv {
48+ arena : Vec < TypeKind > ,
49+ subst : HashMap < u32 , TypeId > , // type var -> resolved type
50+ builtins : HashMap < String , TypeId > , // "i32" -> TypeId
51+ interned : HashMap < TypeKind , TypeId > , // type interning
52+ next_var : u32 ,
53+ }
54+
55+ impl TypeEnv {
56+ pub fn new ( ) -> Self {
57+ let mut env = TypeEnv {
58+ arena : Vec :: new ( ) ,
59+ subst : HashMap :: new ( ) ,
60+ builtins : HashMap :: new ( ) ,
61+ interned : HashMap :: new ( ) ,
62+ next_var : 0 ,
63+ } ;
64+
65+ for name in [ "i32" , "f64" , "bool" , "str" , "unit" ] {
66+ env. intern_builtin ( name) ;
67+ }
68+
69+ env
70+ }
71+
72+ fn intern_builtin ( & mut self , name : & str ) -> TypeId {
73+ let ty = TypeKind :: Primitive ( name. to_string ( ) ) ;
74+ let id = self . intern ( ty. clone ( ) ) ;
75+ self . builtins . insert ( name. to_string ( ) , id) ;
76+ id
77+ }
78+
79+ pub fn intern ( & mut self , ty : TypeKind ) -> TypeId {
80+ if let Some ( & id) = self . interned . get ( & ty) {
81+ return id;
82+ }
83+ let id = self . arena . len ( ) ;
84+ self . arena . push ( ty. clone ( ) ) ;
85+ self . interned . insert ( ty, id) ;
86+ id
87+ }
88+
89+ pub fn new_var ( & mut self ) -> TypeId {
90+ let id = self . next_var ;
91+ self . next_var += 1 ;
92+ self . intern ( TypeKind :: Var ( id) )
93+ }
94+
95+ pub fn get ( & self , id : TypeId ) -> & TypeKind {
96+ & self . arena [ id]
97+ }
98+
99+ pub fn get_builtin ( & self , name : & str ) -> Option < TypeId > {
100+ self . builtins . get ( name) . copied ( )
101+ }
102+
103+ pub fn normalize ( & mut self , ty : TypeId ) -> TypeId {
104+ match & self . get ( ty) {
105+ TypeKind :: Var ( v) => {
106+ let v = * v;
107+ if let Some ( resolved) = self . subst . get ( & v) {
108+ let normalized = self . normalize ( * resolved) ;
109+
110+ self . subst . insert ( v, normalized) ;
111+ normalized
112+ } else {
113+ ty
114+ }
115+ }
116+ _ => ty,
117+ }
118+ }
119+
120+ pub fn unify ( & mut self , a : TypeId , b : TypeId ) -> Result < TypeId , UnifyError > {
121+ let a = self . normalize ( a) ;
122+ let b = self . normalize ( b) ;
123+
124+ if a == b {
125+ return Ok ( a) ;
126+ }
127+
128+ match ( self . get ( a) . clone ( ) , self . get ( b) . clone ( ) ) {
129+ // unify type variables
130+ ( TypeKind :: Var ( va) , _) => {
131+ self . subst . insert ( va, b) ;
132+ Ok ( b)
133+ }
134+ ( _, TypeKind :: Var ( vb) ) => {
135+ self . subst . insert ( vb, a) ;
136+ Ok ( a)
137+ }
138+
139+ // unify functions
140+ (
141+ TypeKind :: Function { params : pa, ret : ra } ,
142+ TypeKind :: Function { params : pb, ret : rb }
143+ ) => {
144+ if pa. len ( ) != pb. len ( ) {
145+ return Err ( UnifyError :: Mismatch ( a, b) ) ;
146+ }
147+ for ( ta, tb) in pa. iter ( ) . zip ( pb. iter ( ) ) {
148+ self . unify ( * ta, * tb) ?;
149+ }
150+ self . unify ( ra, rb)
151+ }
152+
153+ // unify tuples
154+ ( TypeKind :: Tuple ( ta) , TypeKind :: Tuple ( tb) ) => {
155+ if ta. len ( ) != tb. len ( ) {
156+ return Err ( UnifyError :: Mismatch ( a, b) ) ;
157+ }
158+ for ( xa, xb) in ta. iter ( ) . zip ( tb. iter ( ) ) {
159+ self . unify ( * xa, * xb) ?;
160+ }
161+ Ok ( a)
162+ }
163+
164+ // unify arrays
165+ (
166+ TypeKind :: Array { elem : ea, len : la } ,
167+ TypeKind :: Array { elem : eb, len : lb } ,
168+ ) if la == lb => self . unify ( ea, eb) ,
169+
170+ // identical user types
171+ ( TypeKind :: User ( sa) , TypeKind :: User ( sb) ) if sa == sb => Ok ( a) ,
172+
173+ // indetical primitives (failsafe)
174+ ( TypeKind :: Primitive ( na) , TypeKind :: Primitive ( nb) ) if na == nb => Ok ( a) ,
175+
176+ _ => Err ( UnifyError :: Mismatch ( a, b) )
177+ }
178+ }
179+ }
180+
181+ #[ cfg( test) ]
182+ mod tests {
183+ use super :: * ;
184+
185+ #[ test]
186+ fn simple_inferrence ( ) {
187+ let mut tenv = TypeEnv :: new ( ) ;
188+
189+ let i32_t = tenv. get_builtin ( "i32" ) . unwrap ( ) ;
190+
191+ let mut tvar_a = tenv. new_var ( ) ;
192+ let mut tvar_b = tenv. new_var ( ) ;
193+
194+ // a = i32
195+ tenv. unify ( tvar_a, i32_t) . unwrap ( ) ;
196+ // a = b
197+ tenv. unify ( tvar_b, tvar_a) . unwrap ( ) ;
198+
199+ tvar_a = tenv. normalize ( tvar_a) ;
200+ tvar_b = tenv. normalize ( tvar_b) ;
201+
202+ assert_eq ! ( tenv. get( tvar_a) , tenv. get( tvar_b) )
203+ }
204+ }
205+ }
0 commit comments