-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSchool.hs
More file actions
65 lines (58 loc) · 2.43 KB
/
School.hs
File metadata and controls
65 lines (58 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- | The Gelman and Hill [8-schools case study](https://cran.r-project.org/web/packages/rstan/vignettes/rstan.html),
which quantifies the effect of coaching programs from 8 different schools on students' SAT-V scores.
-}
module School where
import Control.Monad (replicateM, zipWithM_)
import Data.Kind (Constraint)

import Control.Algebra (Has)

import Env (Assign ((:=)), nil, Observable, get,
                              Observables, (<:>), Env)
import Inference.MH as MH (mhRaw)
import Model (Model, deterministic, halfNormal', normal,
                              normal')
import Sampler (Sampler)
-- | Environment of observable variables used by 'schoolModel'.
type SchEnv =
  '[ "mu"    ':= Double    -- mean coaching effect shared by all schools (μ)
   , "theta" ':= [Double]  -- per-school coaching effects, θ_j = μ + τ * η_j
   , "y"     ':= Double    -- observed effect estimate for one school
   ]
-- | The hierarchical 8-schools model, written in non-centred form.
--
-- The model draws a shared mean @μ@ from @normal 0 10@ (labelled @#mu@), a
-- scale @τ@ from @halfNormal' 10@, and one standard-normal offset @η_j@ per
-- school from @normal' 0 1@. Each school's effect is @θ_j = μ + τ * η_j@,
-- recorded deterministically under @#theta@. Each observed estimate is then
-- drawn from @normal θ_j σ_j@ under @#y@. When the two lists differ in length,
-- only the first @min n_schools (length σs)@ schools get an observation (as
-- with 'zipWith').
schoolModel :: (Observables env '["mu", "y"] Double, Observable env "theta" [Double])
  => Int
  -- ^ number of schools
  -> [Double]
  -- ^ standard error of each school's observed estimate
  -> Model env sig m [Double]
  -- ^ the per-school effects θ
schoolModel n_schools σs = do
  μ  <- normal 0 10 #mu
  τ  <- halfNormal' 10
  ηs <- replicateM n_schools (normal' 0 1)
  θs <- deterministic (map (\η -> μ + τ * η) ηs) #theta
  -- The likelihood terms matter only for their effect on the model; the
  -- sampled values are not needed, so discard them instead of binding a list.
  zipWithM_ (\θ σ -> normal θ σ #y) θs σs
  return θs
-- | Run Metropolis-Hastings on the 8-schools data for 10000 iterations.
--
-- Returns @(mus, thetas)@: the values of @mu@ and the @theta@ vectors,
-- each concatenated across all environments returned by 'MH.mhRaw'.
mhSchool :: Sampler ([Double], [[Double]])
mhSchool = do
  -- Observed per-school effect estimates and their standard errors
  let ys     = [28 :: Double, 8, -3, 7, -1, 1, 18, 12]
      sigmas = [15, 10, 16, 11, 9, 11, 10, 18]
      -- Derive the school count from the data so the two cannot drift apart
      n_schools = length sigmas
      -- Supply the observed ys; mu and theta are given no values (empty lists)
      env :: Env SchEnv
      env = #mu := [] <:> #theta := [] <:> #y := ys <:> nil
  -- Run MH inference for 10000 iterations
  env_mh_out <- MH.mhRaw 10000 (schoolModel n_schools sigmas) env nil (#mu <:> #theta <:> nil)
  -- Retrieve and return the traces of the model parameters mu and theta
  let mus    = concatMap (get #mu) env_mh_out
      thetas = concatMap (get #theta) env_mh_out
  return (mus, thetas)