Scalaで再帰プログラミング

Table of Contents

1 概要

Scalaでの再帰プログラミングについて学ぶ.

1.1 注意

本Webページの作成には Emacs org-mode を用い, 数式等の表示は MathJax を用いています. IEでは正しく表示されないことがあるため, Firefox, Safari等のWebブラウザでJavaScriptを有効にしてお使いください. また org-info.js を利用しており, 「m」キーをタイプするとinfoモードでの表示になります. 利用できるショートカットは「?」で表示されます.

2 再帰的定義

2.1 階乗

数学では,定義の中に自分自身が現れることがある. たとえば n の階乗は次のように定義される. \begin{eqnarray*} n! & = & \left\{ \begin{array}{ll} 1 & (n=0) \\ n\times (n-1)! & (n\geq1) \end{array} \right. \end{eqnarray*} このように定義の中に自分自身が現れる定義を, 再帰的定義 (recursive definition)と呼ぶ.

Scalaでは,階乗を求める関数 fact は次のように再帰的に定義できる. なお,Scalaで関数を再帰的定義する場合は,返り値のデータ型を指定する必要がある. 下では BigInt としている.

scala> def fact(n: Int): BigInt = if (n == 0) 1 else n * fact(n - 1)

実行してみると以下のようになる.

scala> fact(30)
res: BigInt = 265252859812191058636308480000000

2.2 練習問題

  1. 上の階乗の関数はいくつの階乗まで計算できるだろうか. 大きな数の階乗を計算した場合,どのようなエラーが表示されるだろうか.
  2. \(\sum_{i=1}^n 1/i\) をDoubleで計算する関数 g(n) を再帰的に定義せよ.
    (解答例)
    scala> def g(n: Int): Double = if (n == 0) 0 else 1.0/n + g(n - 1)
    

    注意: 上の定義は,大きい値から和を計算しているため誤差が大きくなる可能性がある. \(g(n) - \ln n\) は \(n \rightarrow \infty\) でおよそ 0.57721 の値に収束する (オイラー定数).

    scala> g(1000) - math.log(1000)
    res: Double = 0.5777155815682065
    

2.3 フィボナッチ数

以下の漸化式で定義されるフィボナッチ数を計算する関数を考えよう (参考: Project Eulerに挑戦: 問題2). \begin{eqnarray*} fib(n) & = & \left\{ \begin{array}{ll} n & (n<2) \\ fib(n-1)+fib(n-2) & (n\geq 2) \end{array} \right. \end{eqnarray*} 関数 fib は次のように再帰的に定義できる.

scala> def fib(n: Int): BigInt = if (n < 2) n else fib(n - 1) + fib(n - 2)

実行してみると以下のようになる.

scala> fib(30)
res: BigInt = 832040

2.4 練習問題

  1. fib(50) を実行してみよ.
  2. フィボナッチ数の一般項はどうなるか.
    (解答例)

    Project Eulerに挑戦: 問題2 を参照.

  3. fib(n) は,nについて指数的な計算時間を必要とすることを示せ.
    (解答例)

    fib(n) の再帰呼出しの回数を表す漸化式を考え, その一般項を求めることでわかる.

2.5 リストの総和

与えられた整数リストの要素の総和を計算する関数を考えよう. すなわち以下のように動作する関数sumを実装する.

scala> val list = List(3,1,4,2)
list: List[Int] = List(3, 1, 4, 2)
scala> sum(list)
res: Int = 10

整数リストxsの要素の総和を計算する関数sum(xs)は,以下のように再帰的に定義できる.

  1. xs が空リストの場合,0 を返す.
  2. xs が空リストでない場合,xs.tail に対して再帰的に総和を求めた値に xs.head を加えた和を返す.

これをそのままScalaで記述したもの,および実行結果は以下のようになる.

scala> def sum(xs: List[Int]): Int = if (xs.isEmpty) 0 else xs.head + sum(xs.tail)
sum: (xs: List[Int])Int
scala> sum(list)
res: Int = 10

xs.tail に対して再帰的に処理するのではなく, 以下のように xs.init に対して再帰的に処理することも可能だ.

def sum(xs: List[Int]): Int = if (xs.isEmpty) 0 else xs.last + sum(xs.init)

しかしリストのデータ構造は,xs.head と xs.tail が効率良く求められるようになっている. したがって xs.tail に対して再帰的に処理するほうが自然だし効率が良い.

2.6 練習問題

  1. 与えられた整数リストの要素の積を計算する関数prodを再帰的に定義せよ.
    (解答例)
    scala> def prod(xs: List[Int]): Int = if (xs.isEmpty) 1 else xs.head * prod(xs.tail)
    prod: (xs: List[Int])Int
    scala> prod(list)
    res: Int = 24
    

2.7 リストの要素のインクリメント

与えられた整数リストの各要素をインクリメントする関数を考えよう. すなわち以下のように動作する関数mapIncを実装する.

scala> val list = List(3,1,4,2)
list: List[Int] = List(3, 1, 4, 2)
scala> mapInc(list)
res: List[Int] = List(4, 2, 5, 3)

関数mapInc(xs)は,以下のように再帰的に定義できる.

  1. xs が空リストの場合,空リストを返す.
  2. xs が空リストでない場合,xs.tail に対して再帰的に処理した結果のリストの先頭に xs.head+1 を加えたリストを返す.

これをそのままScalaで記述したもの,および実行結果は以下のようになる.

scala> def mapInc(xs: List[Int]): List[Int] =
         if (xs.isEmpty) Nil else xs.head + 1 :: mapInc(xs.tail)
mapInc: (xs: List[Int])List[Int]
scala> mapInc(list)
res: List[Int] = List(4, 2, 5, 3)

2.8 練習問題

  1. 与えられた整数リストの偶数要素だけからなるリストを求める関数filterEvenを再帰的に定義せよ.
    (解答例)
    scala> def filterEven(xs: List[Int]): List[Int] =
             if (xs.isEmpty) Nil
             else if (xs.head % 2 == 0) xs.head :: filterEven(xs.tail)
             else filterEven(xs.tail)
    filterEven: (xs: List[Int])List[Int]
    scala> filterEven(list)
    res: List[Int] = List(4, 2)
    

2.9 ファイルに保存しての実行

上の関数をファイルに保存して実行してみる. まず,以下の内容を記述した Ex.scala ファイルを作成し保存する.

 1:  object Ex01 {
 2:    def fact(n: Int): BigInt =
 3:      if (n == 0) 1 else n * fact(n - 1)
 4:    def fib(n: Int): BigInt =
 5:      if (n < 2) n else fib(n - 1) + fib(n - 2)
 6:    def sum(xs: List[Int]): Int =
 7:      if (xs.isEmpty) 0 else xs.head + sum(xs.tail)
 8:    def mapInc(xs: List[Int]): List[Int] =
 9:      if (xs.isEmpty) Nil else xs.head + 1 :: mapInc(xs.tail)
10:  }

Scala REPLからロードして実行するには以下のようにする.

$ scala
scala> :load Ex.scala
Loading Ex.scala...
defined module Ex01
scala> Ex01.fact(10)
res: BigInt = 3628800

Scalaコンパイラでコンパイルして実行するには以下のようにする.

$ scalac Ex.scala
$ scala
scala> Ex01.fact(10)
res: BigInt = 3628800

3 match構文

3.1 階乗

Scalaのmatch構文を用いると場合分けの処理をわかりやすく記述できる.

def fact(n: Int): BigInt =
  n match {
    case 0 => 1
    case _ => n * fact(n - 1)
  }

matchの前の n が場合分けの対象で, match以降の case の後が場合分けのパターンを表す. 特に _ は残りのすべての場合に対応する (Javaのswitch構文のdefaultに対応する).

3.2 練習問題

  1. match構文を用いて,フィボナッチ数を計算する関数 fib を再帰的に定義せよ.
    (解答例)
    def fib(n: Int): BigInt =
      n match {
        case 0 | 1 => n
        case _ => fib(n - 1) + fib(n - 2)
      }
    

    caseのパターンは | で複数の場合を記述できる.

3.3 リストの総和

与えられた整数リストの要素の総和を計算する関数sumをmatch構文で定義しよう.

def sum(xs: List[Int]): Int =
  xs match {
    case Nil => 0
    case x :: xs1 => x + sum(xs1)
  }

リストの先頭要素 x と残りのリスト xs1 をパターンマッチで取り出している点に注意する.

3.4 練習問題

  1. match構文を用いて,与えられた整数リストの各要素をインクリメントする関数mapIncを再帰的に定義せよ.
    (解答例)
    def mapInc(xs: List[Int]): List[Int] =
      xs match {
        case Nil => Nil
        case x :: xs1 => x + 1 :: mapInc(xs1)
      }
    
  2. match構文を用いて,与えられた整数リストの偶数要素だけからなるリストを求める関数filterEvenを再帰的に定義せよ.
    (解答例)
    def filterEven(xs: List[Int]): List[Int] =
      xs match {
        case Nil => Nil
        case x :: xs1 if x % 2 == 0 => x :: filterEven(xs1)
        case x :: xs1 => filterEven(xs1)
      }
    

    リストの先頭要素 x と残りのリスト xs1 をパターンマッチで取り出し, x の条件を if でチェックしている点に注意する.

4 末尾再帰

4.1 階乗

上で定義した階乗関数 fact は,大きな n については スタックオーバーフローのエラーが表示される.

scala> fact(10000)
java.lang.StackOverflowError

以下のようにしてJVMのスタックサイズを大きくして scala を起動すれば実行できる (デフォールトのサイズは 512KB).

$ JAVA_OPTS="-Xss2M" scala
scala> fact(10000)
res: BigInt = 284625968...

しかし fact のプログラムを書き換えて, スタックの消費量を減らせば,JVMのスタックサイズを大きくする必要がなくなり, メモリの有効利用が図れる.

では,どのようにすればスタックの消費量を減らすことができるのだろうか. fact のプログラムの再帰呼出しの部分を見てみると n * fact(n - 1) のようになっている. したがって fact(10000) の計算では fact(9999) が再帰的に呼び出され, その実行が終了すると fact(10000) の処理に戻ってきて 10000 倍の計算が行われる.

再帰呼出しから戻ってきてから積を計算するのでなく, 再帰呼出し過程で積を累積して計算する方法に変更をしてみる. 具体的には,以下のように2引数の関数 fact を定義する.

def fact(n: Int, f: BigInt): BigInt =
  n match {
    case 0 => f
    case _ => fact(n - 1, n * f)
  }

n の階乗は fact(n, 1) を実行する.

scala> fact(10000, 1)
res: BigInt = 284625968...

このプログラムでの再帰呼出しの部分は fact(n - 1, n * f) となっており, 再帰呼出しから戻って来た場合,そのまま上位に戻るだけである. このような場合,処理系はスタックを消費せずに再帰呼出しを処理できる.

このようなプログラムは 末尾再帰的 (tail recursive)なプログラムと呼ばれ, 処理系は末尾再帰呼出しの最適化を行うことが可能になる.

fact(n) は fact(n, 1) として定義すれば良いが, 後者は前者からしか呼び出されることがないため, 以下のように前者の内部に定義すれば良い.

def fact(n: Int): BigInt = {
  def fact(n: Int, f: BigInt): BigInt =
    n match {
      case 0 => f
      case _ => fact(n - 1, n * f)
    }
  fact(n, 1)
}

4.2 練習問題

  1. フィボナッチ数を計算する関数も,引数を2つ追加すれば末尾再帰にできる. どのようにすれば良いか.
    (解答例)
    def fib(n: Int): BigInt = {
      def fib(n: Int, f0: BigInt, f1: BigInt): BigInt =
        n match {
          case 0 => f0
          case _ => fib(n - 1, f1, f0 + f1)
        }
      fib(n, 0, 1)
    }
    

4.3 リストの総和

与えられた整数リストの要素の総和を計算する関数sumも末尾再帰にできる.

def sum(xs: List[Int]): Int = {
  def sum(xs: List[Int], s: Int): Int =
    xs match {
      case Nil => s
      case x :: xs1 => sum(xs1, x + s)
    }
  sum(xs, 0)
}

4.4 練習問題

  1. 与えられた整数リストの各要素をインクリメントする関数mapIncを末尾再帰的に定義せよ. ヒント: 引数を一つ追加し,再帰呼出しの際にその最後尾にインクリメントした値を追加する.
    (解答例)
    def mapInc(xs: List[Int]): List[Int] = {
      def mapInc(xs: List[Int], ys: List[Int]): List[Int] =
        xs match {
          case Nil => ys
          case x :: xs1 => mapInc(xs1, ys :+ x + 1)
        }
      mapInc(xs, Nil)
    }
    

    上の定義では :+ により要素をリストの最後尾に追加している. リスト構造を用いた場合,この方法の効率は悪い. x + 1 :: ys で先頭に追加し,xsが空リストの場合に ys.reverse を返すようにするか, 最後尾への追加が効率良く行えるデータ構造を用いるのが良い.

  2. 与えられた整数リストの偶数要素だけからなるリストを求める関数filterEvenを末尾再帰的に定義せよ. ヒント: 引数を一つ追加し,再帰呼出しの際にその最後尾に偶数要素を追加する.
    (解答例)
    def filterEven(xs: List[Int]): List[Int] = {
      def filterEven(xs: List[Int], ys: List[Int]): List[Int] =
        xs match {
          case Nil => ys
          case x :: xs1 if x % 2 == 0 => filterEven(xs1, ys :+ x)
          case x :: xs1 => filterEven(xs1, ys)
        }
      filterEven(xs, Nil)
    }
    

    上の定義では :+ により要素をリストの最後尾に追加している. リスト構造を用いた場合,この方法の効率は悪い. x + 1 :: ys で先頭に追加し,xsが空リストの場合に ys.reverse を返すようにするか, 最後尾への追加が効率良く行えるデータ構造を用いるのが良い.

5 高階関数

Scalaでは関数をデータとして利用できる. 関数を引数とする関数や,関数を返す関数を 高階関数 (higher-order function)という.

5.1 map関数

与えられた整数リストの各要素をインクリメントする関数mapIncを抽象化し, 各要素に適用できる関数を引数として渡せるようにしよう. すなわち以下のように動作する高階関数mapを実装する.

scala> val list = List(3,1,4,2)
list: List[Int] = List(3, 1, 4, 2)
scala> map(list, x => x+1)
res: List[Int] = List(4, 2, 5, 3)

map の2番目の引数は整数を受け取り整数を返す関数だから そのデータ型は Int => Int となる. このことに注意すると map 関数は以下のように定義できる.

def map(xs: List[Int], f: Int => Int): List[Int] =
  xs match {
    case Nil => Nil
    case x :: xs1 => f(x) :: map(xs1, f)
  }

5.2 練習問題

  1. 上のmap関数を末尾再帰にせよ.
    (解答例)
    def map(xs: List[Int], f: Int => Int): List[Int] = {
      def map(xs: List[Int], ys: List[Int]): List[Int] =
        xs match {
          case Nil => ys.reverse
          case x :: xs1 => map(xs1, f(x) :: ys)
        }
      map(xs, Nil)
    }
    

    引数 f は内部メソッド実行中には変化しないので, 引数として渡す必要がない点に注意する.

5.3 filter関数

次に,与えられた整数リストの偶数要素だけからなるリストを求める関数filterEvenを抽象化し, 各要素に適用できる関数を引数として渡せるようにする. すなわち以下のように動作する高階関数filterを実装する.

scala> filter(list, x => x % 2 == 0)
res: List[Int] = List(4, 2, 5, 3)

2番目の引数のデータ型が Int => Boolean となることに注意すれば, filter 関数の定義は以下のようになる.

def filter(xs: List[Int], f: Int => Boolean): List[Int] =
  xs match {
    case Nil => Nil
    case x :: xs1 if f(x) => x :: filter(xs1, f)
    case x :: xs1 => filter(xs1, f)
  }

5.4 練習問題

  1. 上のfilter関数を末尾再帰にせよ.
    (解答例)
    def filter(xs: List[Int], f: Int => Boolean): List[Int] = {
      def filter(xs: List[Int], ys: List[Int]): List[Int] =
        xs match {
          case Nil => ys.reverse
          case x :: xs1 if f(x) => filter(xs1, x :: ys)
          case x :: xs1 => filter(xs1, ys)
        }
      filter(xs, Nil)
    }
    

5.5 fold関数

次に,与えられた整数リストの要素の総和を計算する関数sumを抽象化し, 要素に適用する関数を引数として渡せるようにする. すなわち以下のように動作する高階関数foldを実装する.

scala> fold(list, 0, (x,y) => x + y)
res: Int = 10

2番目の引数の 0 は,リストが空リストの場合の値を指定している. 3番目の引数のデータ型が (Int,Int) => Int となることに注意すれば, fold 関数の定義は以下のようになる.

def fold(xs: List[Int], z: Int, f: (Int,Int) => Int): Int =
  xs match {
    case Nil => z
    case x :: xs1 => f(x, fold(xs1, z, f))
  }

5.6 練習問題

  1. 上のfold関数は,リストのメソッドのfoldLeftとfoldRightのどちらになっているだろうか.
    (解答例)

    上のfold関数は,リスト x1, x2, …, xn について f(x1, f(x2, …f(xn, z)…)) を計算している. したがってfoldRightに相当する.

  2. 上のfold関数を末尾再帰にせよ. ただし引数として与えられる関数 f は結合的とする. すなわち任意の x, y, z について f(x,f(y,z)) = f(f(x,y),z) とする.
    (解答例)
    def fold(xs: List[Int], z: Int, f: (Int,Int) => Int): Int = {
      def fold(xs: List[Int], s: Int): Int =
        xs match {
          case Nil => s
          case x :: xs1 => fold(xs1, f(s, x))
        }
      fold(xs, z)
    }
    

    このfold関数は,リスト x1, x2, …, xn について f(…f(f(z, x1), x2)…, xn) を計算している. したがってfoldLeftに相当する. 適用する関数が結合的な場合,foldLeftの値とfoldRightの値は同一になる.

6 総称関数

ここまではリストについては,整数を要素とするリストを考えてきたが, 任意のデータ型のリストについて動作する関数を定義する.

6.1 リストの長さ

まず,整数リストの長さを求める関数 size を定義してみよう.

def size(xs: List[Int]): Int =
  xs match {
    case Nil => 0
    case _ :: xs1 => 1 + size(xs1)
  }

整数リストに対しては動作するが,文字列リストに対してはエラーとなる.

scala> size(List(3,1,4,2))
res: Int = 4

scala> size(List("p","i"))
<console>:9: error: type mismatch;

Javaでは class ClassName<T> のようにして, 型パラメータ T を持つ総称型 (generic type)を定義できた. Scalaでも同様に総称型を定義できるだけでなく, 関数についても以下のように型パラメータを与えた 総称関数 (generic function)を定義できる.

def size[A](xs: List[A]): Int =
  xs match {
    case Nil => 0
    case _ :: xs1 => 1 + size(xs1)
  }

ここで size の直後の [A] が型パラメータの指定で, A が型パラメータ名である. 引数 xs のデータ型は,型パラメータを用いて List[A] と指定されている.

この関数は,整数リストに対しても文字列リストに対しても動作する.

scala> size(List("p","i"))
res: Int = 2

scala> size(List(3,1,4,2))
res: Int = 4

関数呼び出し時は size[Int](List(3,1,4,2)) のように データ型を指定しなくても良い点に注意する(もちろん指定しても良い). これは,Scala処理系が型推論によりデータ型を推論しているためである.

6.2 練習問題

  1. 上のsize関数を末尾再帰にせよ.
    (解答例)
    def size[A](xs: List[A]): Int = {
      def size(xs: List[A], s: Int): Int =
        xs match {
          case Nil => s
          case _ :: xs1 => size(xs1, s + 1)
        }
      size(xs, 0)
    }
    
  2. リストの最後の要素を求める関数lastを定義せよ.
    (解答例)
    def last[A](xs: List[A]): A =
      xs match {
        case x :: Nil => x
        case _ :: xs1 => last(xs1)
      }
    

    この定義ではmatch構文中に xs が Nil の場合が含まれていないため, コンパイル時あるいはロード時に "warning: match is not exhaustive!" というwarningが表示される. そのままでも問題なく実行できるが, warningの表示を抑制するには match の直前の xs(xs: @unchecked) に変更すれば良い.

6.3 map関数

次に,以前に定義したmap関数を総称化しよう. 元の要素のデータ型を A とし,結果の要素のデータ型を B とすると 以下のように定義できる.

def map[A, B](xs: List[A], f: A => B): List[B] =
  xs match {
    case Nil => Nil
    case x :: xs1 => f(x) :: map(xs1, f)
  }

x => x + 1 の関数を適用してみるとエラーとなってしまう (型推論できても良いと思うが…).

scala> map(List(3,1,4), x => x + 1)
<console>:9: error: missing parameter type

関数の引数のデータ型を与えれば正しく実行できる.

scala> map(List(3,1,4), (x: Int) => x + 1)
res: List[Int] = List(4, 2, 5)

scala> map(List(3,1,4), (_: Int) + 1)
res: List[Int] = List(4, 2, 5)

6.4 練習問題

  1. 上のmap関数を末尾再帰にせよ.
    (解答例)
    def map[A, B](xs: List[A], f: A => B): List[B] = {
      def map(xs: List[A], ys: List[B]): List[B] =
        xs match {
          case Nil => ys.reverse
          case x :: xs1 => map(xs1, f(x) :: ys)
        }
      map(xs, Nil)
    }
    
  2. 以前に定義したfilter関数を総称化せよ.
    (解答例)
    def filter[A](xs: List[A], f: A => Boolean): List[A] =
      xs match {
        case Nil => Nil
        case x :: xs1 if f(x) => x :: filter(xs1, f)
        case x :: xs1 => filter(xs1, f)
      }
    
  3. 以前に定義したfold関数を総称化せよ.
    (解答例)
    def fold[A, B](xs: List[A], z: B, f: (A,B) => B): B =
      xs match {
        case Nil => z
        case x :: xs1 => f(x, fold(xs1, z, f))
      }
    

7 Seqの利用

ここまで List クラスを用いたプログラムを提示してきたが, Seq クラスを用いたほうがより汎用的になる.

たとえば List に対して定義した以下のsum関数は,Seq, Vector, Range 等に対して適用できない.

scala> def sum(xs: List[Int]): Int = if (xs.isEmpty) 0 else xs.head + sum(xs.tail)
sum: (xs: List[Int])Int
scala> sum(List(3,1,4,2))
res: Int = 10
scala> sum(Seq(3,1,4,2))
<console>:9: error: type mismatch;
scala> sum(1 to 10)
<console>:9: error: type mismatch;

一方,Seq は List, Vector, Range 等の上位クラスであるので, いずれにも適用できる.

scala> def sum(xs: Seq[Int]): Int = if (xs.isEmpty) 0 else xs.head + sum(xs.tail)
sum: (xs: Seq[Int])Int
scala> sum(List(3,1,4,2))
res: Int = 10
scala> sum(Seq(3,1,4,2))
res: Int = 10
scala> sum(1 to 10)
res: Int = 55

ただし Seq では :: で先頭要素を追加することはできない. 代わりに +: を用いる必要がある.

scala> 0 :: Seq(3,1,4)
<console>:8: error: value :: is not a member of Seq[Int]
scala> 0 +: Seq(3,1,4)
res: Seq[Int] = List(0, 3, 1, 4)

また :: をパターンマッチに用いることはできない. 代わりに以下のように記述する必要がある.

def sum(xs: Seq[Int]): Int =
  xs match {
    case Seq() => 0
    case Seq(x, xs1 @ _*) => x + sum(xs1)
  }

ここで Seq() は空の Seq のパターンを表し, Seq(x, xs1 @ _*) は先頭が x で残りが xs1 となっている Seq のパターンを表す. 正確には _* が残りを表すパターンで, @ の前の xs1 がパターンに一致する部分の名前を表している.

7.1 練習問題

  1. Seqに対するlast関数を定義せよ.
    (解答例)
    def last[A](xs: Seq[A]): A =
      xs match {
        case Seq(x) => x
        case Seq(_, xs1 @ _*) => last(xs1)
      }
    

Date: 2020-12-15 00:30:57 JST

Author: 田村直之

Validate XHTML 1.0