sudoku.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. """
  2. The implementation of this Sudoku solver is based on the paper:
  3. "A SAT-based Sudoku solver" by Tjark Weber
  4. https://www.lri.fr/~conchon/mpri/weber.pdf
  5. If you want to understand the code below, in particular the function valid(),
  6. which calculates the 324 clauses corresponding to 9 cells, you are strongly
  7. encouraged to read the paper first. The paper is very short, but contains
  8. all necessary information.
  9. """
  10. import pycosat
  11. def v(i, j, d):
  12. """
  13. Return the number of the variable of cell i, j and digit d,
  14. which is an integer in the range of 1 to 729 (including).
  15. """
  16. return 81 * (i - 1) + 9 * (j - 1) + d
  17. def sudoku_clauses():
  18. """
  19. Create the (11745) Sudoku clauses, and return them as a list.
  20. Note that these clauses are *independent* of the particular
  21. Sudoku puzzle at hand.
  22. """
  23. res = []
  24. # for all cells, ensure that the each cell:
  25. for i in range(1, 10):
  26. for j in range(1, 10):
  27. # denotes (at least) one of the 9 digits (1 clause)
  28. res.append([v(i, j, d) for d in range(1, 10)])
  29. # does not denote two different digits at once (36 clauses)
  30. for d in range(1, 10):
  31. for dp in range(d + 1, 10):
  32. res.append([-v(i, j, d), -v(i, j, dp)])
  33. def valid(cells):
  34. # Append 324 clauses, corresponding to 9 cells, to the result.
  35. # The 9 cells are represented by a list tuples. The new clauses
  36. # ensure that the cells contain distinct values.
  37. for i, xi in enumerate(cells):
  38. for j, xj in enumerate(cells):
  39. if i < j:
  40. for d in range(1, 10):
  41. res.append([-v(xi[0], xi[1], d), -v(xj[0], xj[1], d)])
  42. # ensure rows and columns have distinct values
  43. for i in range(1, 10):
  44. valid([(i, j) for j in range(1, 10)])
  45. valid([(j, i) for j in range(1, 10)])
  46. # ensure 3x3 sub-grids "regions" have distinct values
  47. for i in 1, 4, 7:
  48. for j in 1, 4 ,7:
  49. valid([(i + k % 3, j + k // 3) for k in range(9)])
  50. assert len(res) == 81 * (1 + 36) + 27 * 324
  51. return res
  52. def solve(grid):
  53. """
  54. solve a Sudoku grid inplace
  55. """
  56. clauses = sudoku_clauses()
  57. for i in range(1, 10):
  58. for j in range(1, 10):
  59. d = grid[i - 1][j - 1]
  60. # For each digit already known, a clause (with one literal).
  61. # Note:
  62. # We could also remove all variables for the known cells
  63. # altogether (which would be more efficient). However, for
  64. # the sake of simplicity, we decided not to do that.
  65. if d:
  66. clauses.append([v(i, j, d)])
  67. # solve the SAT problem
  68. sol = set(pycosat.solve(clauses))
  69. def read_cell(i, j):
  70. # return the digit of cell i, j according to the solution
  71. for d in range(1, 10):
  72. if v(i, j, d) in sol:
  73. return d
  74. for i in range(1, 10):
  75. for j in range(1, 10):
  76. grid[i - 1][j - 1] = read_cell(i, j)
  77. def test():
  78. from pprint import pprint
  79. # hard Sudoku problem, see Fig. 3 in paper by Weber
  80. hard = [[0, 2, 0, 0, 0, 0, 0, 0, 0],
  81. [0, 0, 0, 6, 0, 0, 0, 0, 3],
  82. [0, 7, 4, 0, 8, 0, 0, 0, 0],
  83. [0, 0, 0, 0, 0, 3, 0, 0, 2],
  84. [0, 8, 0, 0, 4, 0, 0, 1, 0],
  85. [6, 0, 0, 5, 0, 0, 0, 0, 0],
  86. [0, 0, 0, 0, 1, 0, 7, 8, 0],
  87. [5, 0, 0, 0, 0, 9, 0, 0, 0],
  88. [0, 0, 0, 0, 0, 0, 0, 4, 0]]
  89. solve(hard)
  90. pprint(hard)
  91. assert [[1, 2, 6, 4, 3, 7, 9, 5, 8],
  92. [8, 9, 5, 6, 2, 1, 4, 7, 3],
  93. [3, 7, 4, 9, 8, 5, 1, 2, 6],
  94. [4, 5, 7, 1, 9, 3, 8, 6, 2],
  95. [9, 8, 3, 2, 4, 6, 5, 1, 7],
  96. [6, 1, 2, 5, 7, 8, 3, 9, 4],
  97. [2, 6, 9, 3, 1, 4, 7, 8, 5],
  98. [5, 4, 8, 7, 6, 9, 2, 3, 1],
  99. [7, 3, 1, 8, 5, 2, 6, 4, 9]] == hard