Simple State Monad in Haskell with Examples

2021-09-11
haskell
ko

This post is based on this video from Graham Hutton's playlist on advanced functional programming. I tried to write this post without rewatching the video, to deliver my own understanding of the concept.

Defining State

State를 표현하는 어떤 값이 있을 때, 이 state를 변경시키는 어떤 black box가 있다고 하자. 그 black box는 이렇게 표현할 수 있을 것 같다.

blackBox :: a -> a -- where a denotes the state's type

이때 black box에서 input으로 들어온 state를 이용해 어떤 값을 계산한다고 가정한다면, 현재 함수의 타입 정보로는 해당 결과를 잘 표현하기 어렵다. 따라서 output을 두 개 돌려주는 다음과 같은 타입이 더 편리하다.

blackBox :: a -> (b, a) -- where a denotes the state's type, and b the return value's type

이러한 black box에 해당하는 새로운 타입을 다음과 같이 작성해보자. Record에 대해 매번 pattern match를 하는 것을 방지하기 위해, State와 state 값이 주어졌을 때, transformer를 호출해주는 helper 또한 정의하자.

newtype State a b = S {getState :: a -> (b, a)} runState :: State a b -> a -> (b, a) runState S {getState = f} = f

0부터 시작하는 id를 생성하는 함수가 있다고 하자. 이러한 상황에서는 ST를 다음과 같이 표현하여 id를 뽑아낼 수 있을 것이다.

-- Plain.hs next :: State Int Int next = S $ \n -> (n, n + 1) main :: IO () main = do let n0 = 0 let (id, n) = runState next n0 let (id', n') = runState next n let (id'', n'') = runState next n' print id -- 0 print id' -- 1 print id'' -- 2 print n'' -- 3

next 라는 state transformer를 사용하면 id를 하나 생성하고 state를 하나 증가시키는 것을 확인할 수 있다. 따라서 id, id', id''next ST에 새로운 state를 input으로 넣어 생성한 id들은 각각 0, 1, 2, 3의 값을 가지게 된다.

기존 방식의 문제점과 bind 함수

위 코드에서 문제점이 있다고 할 수 있는 부분은 composition이 불편하다는 점이다. runState 는 새로운 state를 만들어 내기 때문에, 이 함수를 여러 번 호출해야 하는 상황에서는 최신 state를 잘 넘겨주도록 유의해야 한다. 이를 위해 매번 pattern match를 해야 하고, 매번 state를 explicit하게 넘겨주어야 한다. 흐름 상으로만 보면 반복되는 코드인데 필요는 하기 때문에 일종의 boilerplate인 것이다. 이는 코드를 작성할 때 error prone하고, 읽을 때도 불필요한 디테일이다. 다음과 같은 문제점을 monad를 통해 해결해보자!

먼저 bind 함수를 정의해보자. bind 의 타입은 다음과 같다.

bind :: m a -> (a -> m b) -> m b

이때 m 은 type constructor으로, 현재 상황의 경우 State s 에 해당한다.

bind :: State s a -> (a -> State s b) -> State s b

state transformation을 두 번 시행함으로써 구현할 수 있다.

bind :: State s a -> (a -> State s b) -> State s b bind sa f = S $ \s -> let (a, s') = runState sa s in runState (f a) s'

이 함수만 있어도 composition이 더 깔끔해진다. 위에 있는 main 함수를 리팩토링해보자.

-- Bind.hs next :: State Int Int next = S $ \n -> (n, n + 1) main :: IO () main = fst $ runState ( next `bind` \id -> next `bind` \id' -> next `bind` \id'' -> S $ \s -> ( do print id print id' print id'', s )) 0

(표준 출력을 위한 IO 조작 코드가 있는데...무시하고 보자)

runState 의 첫 번째 인자로 넘겨진 State 코드와, 위에 있는 리팩토링 이전 코드를 비교했을 때, trivial한 state 관리 코드가 아예 사라진 것을 확인할 수 있다. composition을 위해 매번 새로운 state를 다음 호출에 넘겨주던 패턴을 bind 라는 함수 안에 넣어두었기 때문에 핵심 코드가 더욱 중점적으로 드러나게 되었다.

bind 라는 함수는 Monad 타입클래스에 (>>=) 라는 함수의 형태로 존재한다. StateMonad 타입클래스의 인스턴스로 만들면 do notation의 syntactic sugar 또한 사용할 수 있게 된다. 다만 Monad 타입클래스의 인스턴스로 만들기 위해서는 그 superclass인 Applicative 까지 구현을 해야 하고, Applicative 의 superclass인 Functor 까지 구현해야 한다. 구현 코드는 다음과 같이 작성할 수 있다.

import Control.Applicative -- a type that contains a state transformer -- the state transformer accepts a state, and returns a state with some return value newtype State a b = S {getState :: a -> (b, a)} runState :: State a b -> a -> (b, a) runState S {getState = f} = f instance Functor (State a) where -- fmap :: (a -> b) -> State s a -> State s b fmap f (S transform) = S (\s -> let (r, s') = transform s in (f r, s')) instance Applicative (State a) where -- pure :: a -> State s a pure x = S $ \s -> (x, s) -- (<*>) :: State s (a->b) -> State s a -> State s b sf <*> sa = S $ \s -> let (f', s') = runState sf s in runState (f' `fmap` sa) s' instance Monad (State a) where -- (>>=) :: State s a -> (a -> State s b) -> State s b (S f) >>= t = S $ \s -> let (a, s') = f s in runState (t a) s'

이제 do notation을 이용해 코드를 2차로 리팩토링해보자.

next :: State Int Int next = S $ \n -> (n, n + 1) main :: IO () main = fst $ runState ( do id <- next id' <- next id'' <- next S $ \s -> ( do print id print id' print id'', s ) ) 0

bind 호출이 없으니 조금 더 간결해졌다. 마지막으로, State monad를 사용하는 코드를 분리하여 리팩토링을 해보자.

-- Main.hs next :: State Int Int next = S $ \n -> (n, n + 1) nextThree :: State Int (Int, Int, Int) nextThree = do id <- next id' <- next id'' <- next return (id, id', id'') getThreeIds :: Int -> (Int, Int, Int) getThreeIds start = fst $ runState nextThree start main :: IO () main = do let (id, id', id'') = getThreeIds 0 print id print id' print id''

특히 처음에 비하면 코드가 훨씬 간결해진 것을 확인할 수 있다.

Tree Labeling

해당 monad를 또 다른 예시에 적용해보자. 이는 위에 언급된 동영상에 있는 예시이기도 하다.

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving Show

다음과 같은 트리 자료구조가 있을 때, 각각 leaf마다 숫자 id를 붙여 label하는 함수를 작성해보자. 먼저 monad 없이 매뉴얼하게 작성해보자.

label :: Tree a -> Int -> (Tree Int, Int) label (Leaf v) n = (Leaf n, n + 1) label (Node l r) n = (Node l' r', n'') where (l', n') = label l n (r', n'') = label r n'

이 함수는 state를 관리한다. 그 다음에 사용할 id인 n이다. 함수를 pure하게 유지하기 위해서는 바늘구멍에 실 궤매듯이 이 state를 계속 들고 다녀야 한다.그럼 위에서 만든 State 를 이용하여 같은 코드를 작성해보자.

mlabel :: Tree a -> State Int (Tree Int) mlabel (Leaf v) = do n <- next return (Leaf n) mlabel (Node l r) = do l' <- mlabel l r' <- mlabel r return (Node l' r')

실 꿰매듯 state를 운반하는 코드는 더이상 보이지 않는다. 이 함수는 이렇게 사용할 수 있을 것이다.

-- Tree.hs testTree :: Tree String testTree = Node (Leaf "foo") (Node (Leaf "bar") (Leaf "baz")) main :: IO () main = do let labeled = fst $ runState (mlabel testTree) 0 print labeled

(runState 랑 0이 불편하면 다른 함수로 만들어두면 된다...여기선 생략한다)

이로써 순수 함수형 언어인 haskell에서 monad를 통해 state를 깔끔하게 처리하는 방법을 알아보았다! 글을 쓰며 나 역시 많이 배운 것 같다. 이 게시글에서 사용한 코드는 이 레포에 업로드되어있다!