在scala中使用for表达式做monad运算

来源:互联网 时间:1970-01-01

在haskell中,我们有语法糖‘do’帮助表达monad运算。scala中我们也有相应语法糖‘for’。

for表达式会被scala compiler做一些变换,简单的例子如下:

for {

a <- foo

b <- bar

} yield (a + b)

===>

foo.flatMap((a) => {

bar.map((b) => {

a + b

})

})

所以我们需要实现两个方法 flatMap和map。

还是用前面的state monad作为例子, 我们给类型State加上flatMap和map。

case class State[S, A](runState: S => (S, A))(implicit m : Monad[({type M[a] = State[S, a]})#M]) {

def map[B](f: A => B) : State[S, B] = m.bind(this, (a: A) => m.ret(f(a)))

def flatMap[B](f: A => State[S, B]) : State[S, B] = m.bind(this, f)

}

这里我们使用了一个隐式参数,然后我们可以直接使用ret和bind。

同时加一个helper简化Monad[({type M[a] = State[S, a]})#M].ret

def ret[S, A](a: A) : State[S, A] = Monad[({type M[a] = State[S, a]})#M].ret(a)

好了,我们可以使用for表达式了,例子如下:

object Main {

import StateMonad._

def main(args: Array[String]) {

val r = for {

a <- ret[Int, Int](3)

b <- ret[Int, Int](4)

} yield (a+b)

println(r.runState(1))

}

}


相关阅读:
Top