test_pycosat.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import sys
  2. import copy
  3. import random
  4. from os.path import basename
  5. import unittest
  6. import pycosat
  7. from pycosat import solve, itersolve
  8. # -------------------------- utility functions ---------------------------
  9. def read_cnf(path):
  10. """
  11. read a DIMACS cnf formatted file from `path`, and return the clauses
  12. and number of variables
  13. """
  14. clauses = []
  15. for line in open(path):
  16. parts = line.split()
  17. if not parts or parts[0] == 'c':
  18. continue
  19. if parts[0] == 'p':
  20. assert len(parts) == 4
  21. assert parts[1] == 'cnf'
  22. n_vars, n_clauses = [int(n) for n in parts[2:4]]
  23. continue
  24. if parts[0] == '%':
  25. break
  26. assert parts[-1] == '0'
  27. clauses.append([int(lit) for lit in parts[:-1]])
  28. assert len(clauses) == n_clauses
  29. return clauses, n_vars
  30. def evaluate(clauses, sol):
  31. """
  32. evaluate the clauses with the solution
  33. """
  34. sol_vars = {} # variable number -> bool
  35. for i in sol:
  36. sol_vars[abs(i)] = bool(i > 0)
  37. return all(any(sol_vars[abs(i)] ^ bool(i < 0) for i in clause)
  38. for clause in clauses)
  39. def py_itersolve(clauses):
  40. while True:
  41. sol = pycosat.solve(clauses)
  42. if isinstance(sol, list):
  43. yield sol
  44. clauses.append([-x for x in sol])
  45. else: # no more solutions -- stop iteration
  46. return
  47. def process_cnf_file(path):
  48. sys.stdout.write('%30s: ' % basename(path))
  49. sys.stdout.flush()
  50. clauses, n_vars = read_cnf(path)
  51. sys.stdout.write('vars: %6d cls: %6d ' % (n_vars, len(clauses)))
  52. sys.stdout.flush()
  53. n_sol = 0
  54. for sol in itersolve(clauses, n_vars):
  55. sys.stdout.write('.')
  56. sys.stdout.flush()
  57. assert evaluate(clauses, sol)
  58. n_sol += 1
  59. sys.stdout.write("%d\n" % n_sol)
  60. sys.stdout.flush()
  61. return n_sol
  62. # -------------------------- test clauses --------------------------------
  63. # p cnf 5 3
  64. # 1 -5 4 0
  65. # -1 5 3 4 0
  66. # -3 -4 0
  67. nvars1, clauses1 = 5, [[1, -5, 4], [-1, 5, 3, 4], [-3, -4]]
  68. # p cnf 2 2
  69. # -1 0
  70. # 1 0
  71. nvars2, clauses2 = 2, [[-1], [1]]
  72. # p cnf 2 3
  73. # -1 2 0
  74. # -1 -2 0
  75. # 1 -2 0
  76. nvars3, clauses3 = 2, [[-1, 2], [-1, -2], [1, -2]]
  77. # -------------------------- actual unit tests ---------------------------
  78. tests = []
  79. class TestSolve(unittest.TestCase):
  80. def test_wrong_args(self):
  81. self.assertRaises(TypeError, solve, [[1, 2], [-3]], 'A')
  82. self.assertRaises(TypeError, solve, 1)
  83. self.assertRaises(TypeError, solve, 1.0)
  84. self.assertRaises(TypeError, solve, object())
  85. self.assertRaises(TypeError, solve, ['a'])
  86. self.assertRaises(TypeError, solve, [[1, 2], [3, None]], 5)
  87. self.assertRaises(ValueError, solve, [[1, 2], [3, 0]])
  88. def test_no_clauses(self):
  89. for n in range(7):
  90. self.assertEqual(solve([], n), [-i for i in range(1, n + 1)])
  91. def test_cnf1(self):
  92. self.assertEqual(solve(clauses1), [1, -2, -3, -4, 5])
  93. if sys.version_info[0] == 2:
  94. cls = [[long(lit) for lit in clause] for clause in clauses1]
  95. self.assertEqual(solve(cls), [1, -2, -3, -4, 5])
  96. def test_iter_clauses(self):
  97. self.assertEqual(solve(iter(clauses1)), [1, -2, -3, -4, 5])
  98. def test_each_clause_iter(self):
  99. self.assertEqual(solve([iter(clause) for clause in clauses1]),
  100. [1, -2, -3, -4, 5])
  101. def test_tuple_caluses(self):
  102. self.assertEqual(solve(tuple(clauses1)), [1, -2, -3, -4, 5])
  103. def test_each_clause_tuples(self):
  104. self.assertEqual(solve([tuple(clause) for clause in clauses1]),
  105. [1, -2, -3, -4, 5])
  106. def test_gen_clauses(self):
  107. def gen_clauses():
  108. for clause in clauses1:
  109. yield clause
  110. self.assertEqual(solve(gen_clauses()), [1, -2, -3, -4, 5])
  111. def test_each_clause_gen(self):
  112. self.assertEqual(solve([(x for x in clause) for clause in clauses1]),
  113. [1, -2, -3, -4, 5])
  114. def test_bad_iter(self):
  115. class Liar:
  116. def __iter__(self): return None
  117. self.assertRaises(TypeError, solve, Liar())
  118. def test_cnf2(self):
  119. self.assertEqual(solve(clauses2), "UNSAT")
  120. def test_cnf3(self):
  121. self.assertEqual(solve(clauses3), [-1, -2])
  122. def test_cnf3_3vars(self):
  123. self.assertEqual(solve(clauses3, vars=3), [-1, -2, -3])
  124. def test_cnf1_prop_limit(self):
  125. for lim in range(1, 20):
  126. self.assertEqual(solve(clauses1, prop_limit=lim),
  127. "UNKNOWN" if lim < 8 else [1, -2, -3, -4, 5])
  128. def test_cnf1_vars(self):
  129. self.assertEqual(solve(clauses1, vars=7),
  130. [1, -2, -3, -4, 5, -6, -7])
  131. tests.append(TestSolve)
  132. # -----
  133. class TestIterSolve(unittest.TestCase):
  134. def test_wrong_args(self):
  135. self.assertRaises(TypeError, itersolve, [[1, 2], [-3]], 'A')
  136. self.assertRaises(TypeError, itersolve, 1)
  137. self.assertRaises(TypeError, itersolve, 1.0)
  138. self.assertRaises(TypeError, itersolve, object())
  139. self.assertRaises(TypeError, itersolve, ['a'])
  140. self.assertRaises(TypeError, itersolve, [[1, 2], [3, None]], 5)
  141. self.assertRaises(ValueError, itersolve, [[1, 2], [3, 0]])
  142. def test_no_clauses(self):
  143. for n in range(7):
  144. self.assertEqual(len(list(itersolve([], vars=n))), 2 ** n)
  145. def test_iter_clauses(self):
  146. self.assertTrue(all(evaluate(clauses1, sol) for sol in
  147. itersolve(iter(clauses1))))
  148. def test_each_clause_iter(self):
  149. self.assertTrue(all(evaluate(clauses1, sol) for sol in
  150. itersolve([iter(clause) for clause in clauses1])))
  151. def test_tuple_caluses(self):
  152. self.assertTrue(all(evaluate(clauses1, sol) for sol in
  153. itersolve(tuple(clauses1))))
  154. def test_each_clause_tuples(self):
  155. self.assertTrue(all(evaluate(clauses1, sol) for sol in
  156. itersolve([tuple(clause) for clause in clauses1])))
  157. def test_gen_clauses(self):
  158. def gen_clauses():
  159. for clause in clauses1:
  160. yield clause
  161. self.assertTrue(all(evaluate(clauses1, sol) for sol in
  162. itersolve(gen_clauses())))
  163. def test_each_clause_gen(self):
  164. self.assertTrue(all(evaluate(clauses1, sol) for sol in
  165. itersolve([(x for x in clause) for clause in
  166. clauses1])))
  167. def test_bad_iter(self):
  168. class Liar:
  169. def __iter__(self): return None
  170. self.assertRaises(TypeError, itersolve, Liar())
  171. def test_cnf1(self):
  172. for sol in itersolve(clauses1, nvars1):
  173. #sys.stderr.write('%r\n' % repr(sol))
  174. self.assertTrue(evaluate(clauses1, sol))
  175. sols = list(itersolve(clauses1, vars=nvars1))
  176. self.assertEqual(len(sols), 18)
  177. # ensure solutions are unique
  178. self.assertEqual(len(set(tuple(sol) for sol in sols)), 18)
  179. def test_shuffle_clauses(self):
  180. ref_sols = set(tuple(sol) for sol in itersolve(clauses1))
  181. for _ in range(10):
  182. cnf = copy.deepcopy(clauses1)
  183. # shuffling the clauses does not change the solutions
  184. random.shuffle(cnf)
  185. self.assertEqual(set(tuple(sol) for sol in itersolve(cnf)),
  186. ref_sols)
  187. def test_many_clauses(self):
  188. ref_sols = set(tuple(sol) for sol in itersolve(clauses1))
  189. # repeating the clauses many times does not change the solutions
  190. cnf = 100 * copy.deepcopy(clauses1)
  191. self.assertEqual(set(tuple(sol) for sol in itersolve(cnf)),
  192. ref_sols)
  193. def test_cnf2(self):
  194. self.assertEqual(list(itersolve(clauses2, nvars2)), [])
  195. def test_cnf3_3vars(self):
  196. self.assertEqual(list(itersolve(clauses3, 3)),
  197. [[-1, -2, -3], [-1, -2, 3]])
  198. def test_cnf1_prop_limit(self):
  199. self.assertEqual(list(itersolve(clauses1, prop_limit=2)), [])
  200. tests.append(TestIterSolve)
  201. # ------------------------------------------------------------------------
  202. def run(verbosity=1, repeat=1):
  203. print("sys.prefix: %s" % sys.prefix)
  204. print("sys.version: %s" % sys.version)
  205. print("pycosat version: %r" % pycosat.__version__)
  206. suite = unittest.TestSuite()
  207. for cls in tests:
  208. for _ in range(repeat):
  209. suite.addTest(unittest.makeSuite(cls))
  210. runner = unittest.TextTestRunner(verbosity=verbosity)
  211. return runner.run(suite)
  212. if __name__ == '__main__':
  213. if len(sys.argv) == 1:
  214. run()
  215. else:
  216. for path in sys.argv[1:]:
  217. process_cnf_file(path)