リスト内包表記の活用 - 数独ソルバー

Peter Norvig 氏の Solving Every Sudoku Puzzle というエッセイで、数独の解き方が Python を使って示されています。

ちょうど SRFI-42 (eager comprehension) というライブラリを使ってみたいなと思っていたところに見つけたので、ジェネレータ式というものが多用されているこの Python コードは恰好の題材でした。

ということで、原文の流れに沿いながら Scheme に訳していきたいと思います。


まず用語を紹介しておきます。

9x9マスの縦の列を column、横の列(行)を row と言うのは自明と思いますが、3x3 のまとまりは block、そして各マスを square と呼んでいます。

さらに、row、column、block それぞれ9マスずつのまとまりを unit、1つの square が所属する unit 内の他の square のことを peer と呼びます。

これらの情報を Scheme で表現するにあたって、1つだけ悩んだ点がありました。それは、 Python 版では square や square のリスト (column など) を文字列として表していることです。特に、

for c in "string"

のように、for 構文で文字列から一字ずつ取り出せたりするのは大変便利なんですが、またその文字自体を for で反復できることに戸惑ってしまいました。ここでの c は (文字列型と区別される) 文字型の値ではなく、1字の文字列なんですね?

そんな不慣れな部分もあり、Scheme への置き換えにいきなり躓いたんですが、結局同値・存在チェックのし易さを優先して、square はシンボル、unit はシンボルのリストとして、数字は数値型で表すことにしました。


まずは行と列を作ります:

(define rows
  (list-ec (:char-range c #\A #\I)
           (string->symbol (string c))))

(define cols
  (list-ec (:range i 1 10) i))

(define digits cols)

rows は A から I までのシンボルのリスト、cols は 1 から 9 までの数値リストです。リスト内包の最も基本的な例になっています。

list-ec というのが文字通りリストを作る構文で、コロンから始まるジェネレータというものでリストの中身のデータを生成します。Python とは順番が逆になっているんですが、それ以降の式でフィルタリングなどを (必要ならば) して、最後の式でリスト要素の値を与えます。

:range ジェネレータは指定範囲の整数値を生成するもので、上限値 - 1 までの値を生成します。それに対し :char-range ジェネレータは上限値を含むことに注意してください。

次に、行と列に基づいて square のリストを作ります:

(define (square i j)
  (string->symbol (format "~a~a" i j)))

(define (cross A B)
  (list-ec (:list a A)
           (:list b B)
           (square a b)))

(define squares (cross rows cols))

square は A1、A2…のように行と列のインデックスの組合わせで表し、ハッシュ表によって各 square にアクセスする方法を取っていきます。

(テーブル全体を 2次元ベクタ (配列) で表す数独プログラムが多いと思いますが、それだとコードが煩雑になりがちです。賢明な方法だと思います)

ここでは :list ジェネレータのネストによってリスト A と B の要素の組み合わせを作っています。これで (A1 A2 ... I9) まで、長さ 81 のリストの出来上がりです。

ちなみに map で同じことをするならば

(foldr append
       '()
       (map (lambda (a)
              (map (lambda (b)
                     (square a b))
                   cols))
            rows))

と、ちょっとややこしいことになります。3重対を作ろうと思ったらもっと大変です。内包表記を使うと map のネストが直線的に表現できて非常に便利なわけです (filter のネストに関しても同じことが言えます)。

次に unit の定義です:

(define (split l n)
  (let ((len (length l)))
    (let lp ((r '()) (i 0))
      (if (>= i len)
          (reverse r)
          (lp (cons (take (drop l i) n)
                    r)
              (+ i n))))))

(define unitlist
  (append (list-ec (:list c cols)
                   (cross rows (list c)))
          (list-ec (:list r rows)
                   (cross (list r) cols))
          (list-ec (:list rs (split rows 3))
                   (:list cs (split cols 3))
                   (cross rs cs))))

split はリストを分割する関数です。(split rows 3) は

((A B C) (D E F) (G H I))

を返します (3 は分割の個数ではなく各小リストの要素数です)。手作業で '((A B C)...) と書くのが面倒だったのでわざわざ作りました (変ですか?)。

unitlist として、row, column, block それぞれの unit のリストを作っています。ここから例えば A1 が所属する unit を抜き出してみると

(list-ec (:list u unitlist)
         (if (memq 'A1 u))
         u)
; =>
((A1 B1 C1 D1 E1 F1 G1 H1 I1)
 (A1 A2 A3 A4 A5 A6 A7 A8 A9)
 (A1 A2 A3 B1 B2 B3 C1 C2 C3))

上から1列目、A行目、そして左上のブロックの各 unit となっています。

これに基づき、各 square に対応する unit, peer をハッシュ表に登録していきます:

(define-syntax hash-ec
  (syntax-rules ()
    ((hash-ec e ... (k v))
     (let ((hsh (hash)))
       (do-ec e ... (hsh k v))
       hsh))))

(define units
  (hash-ec (:list s squares)
           (s (list-ec (:list u unitlist)
                       (if (memq s u))
                       u))))

hash-ec として、ハッシュ表を作るための新しい構文を定義しています。Python の dict 関数でジェネレータ式を使っているのを真似たかったんです。

do-ec は副作用のための構文です。最後の式はコマンドとして実行され、値を返しません。

list-ec では if というキーワードを使っています。(if exp) で、exp が偽でない時にその次の式へ進む、という動作をします。Haskellモナドにおける guard 関数みたいな感覚でしょうか。働きとしては filter 関数と同等のものと考えることができます。

同種の論理演算子に not, and, or があります。


peer は重複を省いた1重のリストです:

(define (adjoin x l)
  (if (memq x l) l (cons x l)))

(define-syntax set-ec
  (syntax-rules ()
    ((set-ec e1 e2 ...)
     (reverse (fold-ec '() e1 e2 ... adjoin)))))

(define peers
  (hash-ec (:list s squares)
           (s (set-ec (:list u (units s))
                      (:list s2 u)
                      (not (eq? s2 s))
                      s2))))

次に、問題をパースする関数です:

(define/kw (join l #:optional (sep ""))
  (parameterize ((current-output-port (open-output-string)))
    (let lp ((l l))
      (cond ((null? l)
             (get-output-string (current-output-port)))
            (else
             (display (car l))
             (if (pair? (cdr l))
                 (display sep))
             (lp (cdr l)))))))

(define (parse-grid grid)
  (define oks
    (string->list (join (append '(#\. 0) digits))))
  (let/ec return
    (let ((grid (list-ec (:string c grid)
                         (if (memv c oks))
                         (string->number (string c))))
          (vals (hash-ec (:list s squares)
                         (s digits))))
      (do-ec (:list s+d (zip squares grid))
             (:let s (car s+d))
             (:let d (cadr s+d))
             (if (memv d digits))
             (not (assign vals s d))
             (return #f))
      vals)))

:let が新出ですが、新しい変数を導入する特殊ジェネレータです。

join は文字や文字列、数値などのリストを文字列として連結する関数です。Python ではセパレータ文字に対するメソッドですが、ここではセパレータはオプショナル引数としました。

なお、let/ec は脱出のための継続を捉える PLT Scheme の特殊構文です。

parse-grid で問題の文字列を受け取ります。1 から 9 までの数字のほか、穴の部分の文字として 0 と . (ピリオド) が有効になっています。テーブルの上から下に向かって、行単位では左から右に、順序良く 81 字並んだ文字列が正しい入力となります。


例:

"4.....8.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......"

有効な文字以外は無視されるので、テーブルの形のままでも構いません:

"+-------+-------+-------+
 | . . 3 | 2 . . | . . . |
 | . 4 . | . 9 . | . . . |
 | 6 . . | . . 8 | . 1 . |
 +-------+-------+-------+
 | 2 . . | . . . | . . 3 |
 | . 1 . | . . 6 | . 4 . |
 | . . 7 | . . . | 5 . . |
 +-------+-------+-------+
 | . . . | . . 1 | . . 2 |
 | . 9 . | . 4 . | . 6 . |
 | . . . | 5 . . | 7 . . |
 +-------+-------+-------+"

emacssudoku モードを実行中ならば、scratch バッファで

(mapconcat #'identity
           (mapcar (lambda (row)
                     (mapconcat (lambda (x)
                                  (format "%s" x))
                                row
                                ""))
                   current-board)
           "")

を評価すると問題の文字列が得られます。


次に、この数独ソルバーの肝である制約システムの実装です。

(define (assign vals s d)
  (and (every?-ec (:list d2 (vals s))
                  (not (= d2 d))
                  (eliminate vals s d2))
       vals))

(define (eliminate vals s d)
  (let/ec return
    (when (memv d (vals s))
      (vals s (delete d (vals s)))
      (case (length (vals s))
        ((0) (return #f))
        ((1) (let ((d2 (car (vals s))))
               (unless (every?-ec (:list s2 (peers s))
                                  (eliminate vals s2 d2))
                 (return #f)))))
      (do-ec (:list u (units s))
             (:let dplaces
                   (list-ec (:list s u)
                            (if (memv d (vals s)))
                            s))
             (case (length dplaces)
               ((0) (return #f))
               ((1) (unless (assign vals (car dplaces) d)
                      (return #f))))))
    vals))

簡単に説明します。

入力をパースした時点で、全 square に対して 1 から 9 までのリストが、そこに入る数字の候補として割り当てられています。そして、最初から数字が与えられている square に対しては候補をその数字1つに決定していきます (assign)。

(実際には、assign は与えられた数字を直接 square に割り当てるのではなく、square からそれ以外の数字を取り除くことによって候補を1つに絞り込む、というやり方をしています。)

一方で、候補が決定した square に対応する peer にはその数字が入らないことは明らかなので、全 peer に対してその数字を取り除く (eliminate) 操作を行います。

といったことを相互再帰的に行う関数のペアです。


新しい comprehension 構文として every?-ec が登場しています。最後の式が全ての要素に対して真 (#f でない値) となる時のみ真を返す、というものです。Python の all という関数に対応します。同種の構文に any?-ec があります。


次に、パースした結果を表示する関数です。

(define (repeat x n)
  (list-ec (:range i n) x))

(define (center str width)
  (define (pad-size)
    (let ((size (- width (string-length str))))
      (if (> size 0)
          (quotient/remainder size 2)
          (values 0 0))))
  (define (padding n)
    (string-ec (:range i n) #\ ))
  (receive (q r) (pad-size)
    (string-append (padding q)
                   str
                   (padding (+ q r)))))

(define (print-board vals)
  (let* ((width (add1 (max-ec (:list s squares)
                              (length (vals s)))))
         (line (string-append
                "\n"
                (join (repeat
                       (string-ec (:range i (* width 3))
                                  #\-)
                       3)
                      "+")))
         (c-sep '(3 6))
         (r-sep (list-ec (:list i c-sep)
                         (list-ref rows (sub1 i)))))
    (do-ec (:list r rows)
           (begin
             (newline)
             (do-ec (:list c cols)
                    (begin
                      (display (center (join (vals (square r c)))
                                       width))
                      (display (if (memv c c-sep)
                                   "|" "")))))
           (display (if (memq r r-sep)
                        line "")))
    (newline)))

新しい comprehension 構文が2つ出てきました。max-ec は最大値を返す構文、それと文字を連結する string-ec です。文字列を連結する string-append-ec という構文もあります。

print-board の1つ目の do-ec の中で使われている begin は Scheme の begin ではなく、comprehension 構文の中に属するキーワードです。意味は Scheme と同じです。


初級レベルの問題を表示してみましょう:

(print-board
 (parse-grid
  "..3.2.6..9..3.5..1..18.64....81.29..7.......8..67.82....26.95..8..2.3..9..5.1.3.."))
4 8 3 |9 2 1 |6 5 7
9 6 7 |3 4 5 |8 2 1
2 5 1 |8 7 6 |4 9 3
------+------+------
5 4 8 |1 3 2 |9 7 6
7 2 9 |5 6 4 |1 3 8
1 3 6 |7 9 8 |2 4 5
------+------+------
3 7 2 |6 8 9 |5 1 4
8 1 4 |2 5 3 |7 6 9
6 9 5 |4 1 7 |3 8 2

パースだけで全ての穴が埋まってしまいました。

超難問で試してみましょう:

(print-board
 (parse-grid
  (string-append-ec
   (:list r '((0 0 3 2 0 0 0 0 0)
              (0 4 0 0 9 0 0 0 0)
              (6 0 0 0 0 8 0 1 0)
              (2 0 0 0 0 0 0 0 3)
              (0 1 0 0 0 6 0 4 0)
              (0 0 7 0 0 0 5 0 0)
              (0 0 0 0 0 1 0 0 2)
              (0 9 0 0 4 0 0 6 0)
              (0 0 0 5 0 0 7 0 0)))
   (string-append-ec (:list s r) (number->string s)))))
 15789   578     3   |   2    1567    457  | 4689   5789  456789
 1578     4    1258  | 1367     9     357  | 2368   23578  5678
   6     257    259  |  347    357     8   | 2349     1    4579
---------------------+---------------------+---------------------
   2     568   45689 | 14789  1578   4579  | 1689    789     3
 3589     1     589  | 3789   23578    6   |  289     4     789
 3489    368     7   | 13489  1238   2349  |   5     289   1689
---------------------+---------------------+---------------------
 34578  35678  4568  | 36789  3678     1   | 3489   3589     2
 13578    9    1258  |  378     4     237  |  138     6     158
 1348   2368   12468 |   5    2368    239  |   7     389   1489

数字が与えられていない所では、最も少ないところで候補が3つまでしか絞り込めていません。人間が解こうと思っても殆ど無理な状況ではないでしょうか。試しに emacs で解いてみてください。sudoku モード実行中に以下を評価し、Ctrl-C Ctrl-R でプレイできます。

(setq start-board
      '((0 0 3 2 0 0 0 0 0)
        (0 4 0 0 9 0 0 0 0)
        (6 0 0 0 0 8 0 1 0)
        (2 0 0 0 0 0 0 0 3)
        (0 1 0 0 0 6 0 4 0)
        (0 0 7 0 0 0 5 0 0)
        (0 0 0 0 0 1 0 0 2)
        (0 9 0 0 4 0 0 6 0)
        (0 0 0 5 0 0 7 0 0)))

このような難問にアタックする関数が、以下の search です:

(define-syntax some-ec
  (syntax-rules ()
    ((some-ec e1 ... e2)
     (let/ec return
       (do-ec e1 ...
              (:let v e2)
              (if v)
              (return v))
       #f))))

(define (square/least-digits vals)
  (cdr
   (fold-ec '(10 . dummy)
            (:list s squares)
            (:let len (length (vals s)))
            (if (> len 1))
            s
            (lambda (s res)
              (if (< len (car res))
                  (cons len s)
                  res)))))

(define (search vals)
  (cond ((not vals) #f)
        ((every?-ec (:list s squares)
                    (= (length (vals s)) 1))
         vals)
        (else
         (let ((s (square/least-digits vals)))
           (some-ec (:list d (vals s))
                    (search (assign (hash-copy vals) s d)))))))

まず some-ec という構文を定義しています。ブール値を返しそうな名前ですが、最初に見つかった偽でない値を返すものです。

search の定義を見ると、この関数は全 square の数字 ((vals s) で square `s' に入っている数字の候補が得られます) の個数が1つになった時のみ vals (ハッシュ表) を返し、それ以外は #f を返すことが分かります。

some-ec 内で search が再帰的に呼ばれることによって解の探索が行われるわけですが、その際最も残りの候補数の少ない square から試行を始めることで、劇的に計算量を減らすことができるそうです。

また、ハッシュ表のコピーを作って再帰することにより、失敗を気にすることなく自然にバックトラックすることが可能になっています。

これで、どんな難しい問題でもほぼ一瞬で答えを得られるようになりました。

以上のまとめとして、http://norvig.com/top95.txt の形式のファイルを受け取って一気に解く関数を作りました:

(define (solve-sudoku grid)
  (cond ((parse-grid grid) => search) (else #f)))

(define/kw (solve-file path #:optional (action void))
  (call-with-input-file path
    (lambda (in)
      (let ((results
             (list-ec (:port grid (index i) in read-line)
                      (:let res (solve-sudoku grid))
                      (begin (printf "~%;; Problem ~a ~aed"
                                     (add1 i)
                                     (if res "solv" "fail"))
                             (if res (action res)))
                      res)))
        (printf ";; Got ~a out of ~a~%"
                (sum-ec (:list r results) (if r) 1)
                (length results))
        results))))

ここでは :port ジェネレータを使ってみました。(index i) はオプションで、:list とか :range その他ジェネレータ一般で使えます。変数名の後に付けることで現在の反復の回数が得られます。ここではファイルの行数を数えているわけです。

:port の最後の引数にはポートを読む関数を渡します。read-char で文字単位、read で S 式単位での読み込みとなります。

sum-ec は数値シークエンスの総和を返す構文です。


結果:
scheme-sudoku


最後に、ここで使っているハッシュ表ライブラリを貼っておきます。

;; Adapted from:
;; http://schemewiki.org/view/Cookbook/ApplicableHashTables
(module hash mzscheme
  (require (lib "kw.ss"))

  (define-values (s:hash s:make-hash s:hash? s:hash-ref s:hash-set!)
    (make-struct-type 'hash #f 2 0 #f '() #f
                      (case-lambda
                        ((hsh key)
                         (hash-table-get (fst hsh) key (snd hsh)))
                        ((hsh key val)
                         (hash-table-put! (fst hsh) key val)
                         val))))

  (define (fst hsh) (s:hash-ref hsh 0))
  (define (snd hsh) (s:hash-ref hsh 1))
  (provide (rename fst hash->hash-table))

  (define/kw (hash #:key default #:body args)
    (s:make-hash (apply make-hash-table args)
                 default))
  (provide hash
           (rename s:hash? hash?))

  (define (hash-copy hsh)
    (s:make-hash (hash-table-copy (fst hsh))
                 (snd hsh)))
  (provide hash-copy)

  (define-syntax (provide-prims x)
    (define ((prepend str) x)
      (datum->syntax-object
       x
       (string->symbol (format "~a-~a" str (syntax-e x)))))
    (syntax-case x ()
      [(provide-prims prims ...)
       (let* ([prims (syntax->list #'(prims ...))]
              [defs (map (lambda (my mz)
                           #`(begin
                               (define (#,my hsh . args)
                                 (apply #,mz (fst hsh) args))
                               (provide #,my)))
                         (map (prepend "hash") prims)
                         (map (prepend "hash-table") prims))])
         #`(begin #,@defs))]))

  (provide-prims map for-each get put! remove! count
                 iterate-first iterate-key iterate-next iterate-value))