型推論

プログラミング言語 Scala Wiki型推論実習(?)を参考に、というかほとんど写経して、Haskell型推論プログラムを書いてみた。

-- TypeInfer.hs

import Control.Monad.State
import Data.Function
import Data.List

import Debug.Trace

data Term =
    Var String
  | Lam String Term
  | App Term Term
  | Let String Term Term
  | Letrec String Term Term
  deriving Eq

instance Show Term where
  show (Var id)     = id
  show (Lam x v) = "(" ++ x ++ " -> " ++ show v ++ ")"
  show (App v u)  = "(" ++ show v ++ " " ++ show u ++ ")"
  show (Let x v u) = "(let " ++ x ++ " = " ++ show v ++ " in " ++ show u ++ ")"
  show (Letrec x v u) = "(letrec " ++ x ++ " = " ++ show v ++ " in " ++ show u ++ ")"


data Type =
    TVar Int
  | TArrow Type Type
  | TCon String [Type]
  deriving Eq

instance Show Type where
  show (TVar id)    = show id
  show (TArrow t u) = "(" ++ show t ++ " -> " ++ show u ++ ")"
  show (TCon k ts)  = "(" ++ k ++ concatMap ((" "++).show) ts ++ ")"


type Subst = [(Int,Type)]

apply :: Subst -> Type -> Type
apply sb t@(TVar n) = case lookup n sb of
                         Just u | t == u -> t
                                | otherwise -> apply sb u
                         Nothing -> t

apply sb (TArrow t1 t2) = TArrow (apply sb t1) (apply sb t2)
apply sb (TCon k ts)    = TCon k (map (apply sb) ts)


type Infer = State Int
newTVar = do
  n <- get
  put (n+1)
  return $ TVar (n+1)

data TypeScheme = Sch [Int] Type
  deriving Show

newInstance :: TypeScheme -> Infer Type
newInstance (Sch vars t) = do
  sb <- forM vars $ \var -> fmap ((,) var) newTVar
  return $ apply sb t

type Env = [(String, TypeScheme)]

forall :: Env -> Subst -> Type -> TypeScheme
forall env sb t = Sch (tyvars u \\ envvars env) u
  where
    u = apply sb t
    schvars :: TypeScheme -> [Int]
    schvars (Sch vars _) = vars

    envvars :: Env -> [Int]
    envvars env = concatMap (schvars.snd) env

exist :: Subst -> Type -> TypeScheme
exist sb t = Sch [] (apply sb t)

tyvars :: Type -> [Int]
tyvars (TVar n)     = [n]
tyvars (TArrow t u) = tyvars t ++ tyvars u
tyvars (TCon _ ts)  = concatMap tyvars ts

unify :: Subst -> (Type, Type) -> Subst
unify sb (t, u) =
  case ((apply sb t), (apply sb u)) of
    (TVar n, TVar m) | n == m                 -> (n,u):sb
    (TVar n, _)      | n `notElem` (tyvars u) -> (n,u):sb
    (_, TVar _)                               -> unify sb (u, t)
    (TArrow t1 u1, TArrow t2 u2)
        -> unify (unify sb (t1, t2)) (u1, u2)

    (TCon k1 ts1, TCon k2 ts2)
      | k1 == k2 && ((==) `on` length) ts1 ts2
        -> foldl unify sb $ zip ts1 ts2

    _   -> error $ show t ++ " and " ++ show u
                   ++ "can't unify!"

typ :: Term -> Type -> Subst -> Env -> Infer Subst
typ tr@(Var name) expect sb env = do
  case lookup name env of
    Just sch -> do u <- newInstance sch
                   return $ unify sb (expect,u)
    Nothing  -> error $ name ++ " is not in scope"

typ tr@(Lam x term) expect sb env = do
  a <- newTVar
  b <- newTVar
  let nsb  = unify sb (expect, (TArrow a b))
      nenv = (x, exist sb a) : env
  typ term b nsb nenv

typ tr@(App term1 term2) expect sb env = do
  a   <- newTVar
  nsb <- typ term1 (TArrow a expect) sb env
  typ term2 a nsb env

typ tr@(Let name term1 term2) expect sb env = do
  a   <- newTVar
  nsb <- typ term1 a sb env
  typ term2 expect nsb $ (name, forall env nsb a):env

typ tr@(Letrec name term1 term2) expect sb env = do
  a <- newTVar
  let nenv =(name, exist sb a) : env
  nsb <- typ term1 a sb nenv
  typ term2 expect nsb $ (name, forall nenv nsb a):nenv

infer :: Env -> Term -> Infer Type
infer env term = do
  a  <- newTVar
  sb <- typ term a [] env
  return $ apply sb a

test term = flip evalState 0 $ do

  let booleanType   = TCon "Boolean" []
      intType       = TCon "Int" []
      listType t    = TCon "List" [t]

  a <- newTVar

  let sch' t = forall [] [] t
      env = [("true", sch' booleanType),
             ("false", sch' booleanType),
             ("if", sch' (TArrow booleanType (TArrow a (TArrow a a)))),
             ("zero", sch' intType),
             ("succ", sch' (TArrow intType intType)),
             ("add", sch' (TArrow (TArrow intType intType) intType)),
             ("nil",  sch' (listType a)),
             ("cons", sch' (TArrow a (TArrow (listType a) (listType a)))),
             ("isEmpty", sch' (TArrow (listType a) booleanType)),
             ("head", sch' (TArrow (listType a) a)),
             ("tail", sch' (TArrow (listType a) a)),
             ("fix",  sch' (TArrow (TArrow a a) a))]

  infer env term

いちおう、課題のletrecも実装し、再帰関数lengthの型を推論してみる:

sh$ ghci TypeInfer.hs
...
*Main> test (Letrec "length" (Lam "xs" (App (App (App (Var "if") (App (Var "isEmpty") (Var "xs"))) (Var "zero")) (App (Var "succ") (App (Var "tail") (Var "xs"))))) (Var "length"))
((List (Int)) -> (Int))
*Main> 

うまく動いている。型推論について大雑把な理解が得られたと思う。このアルゴリズムは、ひとつの見方としては:

  1. すべての部分式に型変項を仮に割り当てる*1
  2. それぞれの型変項からまた別の型変項へ、あるいは型定項(?)への巨大な有方向グラフを作る
    1. このとき、二つの(部分式の)型変項が同じ型を表すべきであることが分かれば、同じ項を指すように有方向グラフを成長させる
    2. ある型変項がある具体的な型を表すべきだと分かれば、その型変項がその型定項を指すように有方向グラフを成長させる

という作業を再帰的に繰り返すことによって、型推論を行っている。

うーん… 1パラメータの型クラスの導入はそう難しくはないかな。

*1:厳密にはちょっと違う。