[mypy] Fixes typing errors in other/dpll (#5759)

+ As per usage examples, clause literals are a list of strings.
  + Note: symbols extracted from literals are expected to be exactly two characters.
+ self.literal boolean values are initialized to None, so must be optional
+ model values should be Booleans, but aren't guaranteed to be non-None
  in the code.
+ uses newer '... | None' annotation for Optional values
+ clauses are passed to the Formula initializer as both lists and sets, they
  are stored as lists.  Returned clauses will always be lists.
+ use explicit tuple annotation from __future__  rather than using (..., ...)
  in return signatures
+ mapping returned by dpll_algorithm is optional per the documentation.
This commit is contained in:
Andrew Grangaard 2021-11-03 13:32:49 -07:00 committed by GitHub
parent 765be4581e
commit 7954a3ae16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,6 +11,7 @@ For more information about the algorithm: https://en.wikipedia.org/wiki/DPLL_alg
from __future__ import annotations
import random
from typing import Iterable
class Clause:
@ -27,12 +28,12 @@ class Clause:
True
"""
def __init__(self, literals: list[int]) -> None:
def __init__(self, literals: list[str]) -> None:
"""
Represent the literals and an assignment in a clause."
"""
# Assign all literals to None initially
self.literals = {literal: None for literal in literals}
self.literals: dict[str, bool | None] = {literal: None for literal in literals}
def __str__(self) -> str:
"""
@ -52,7 +53,7 @@ class Clause:
"""
return len(self.literals)
def assign(self, model: dict[str, bool]) -> None:
def assign(self, model: dict[str, bool | None]) -> None:
"""
Assign values to literals of the clause as given by model.
"""
@ -68,7 +69,7 @@ class Clause:
value = not value
self.literals[literal] = value
def evaluate(self, model: dict[str, bool]) -> bool:
def evaluate(self, model: dict[str, bool | None]) -> bool | None:
"""
Evaluates the clause with the assignments in model.
This has the following steps:
@ -97,7 +98,7 @@ class Formula:
{{A1, A2, A3'}, {A5', A2', A1}} is ((A1 v A2 v A3') and (A5' v A2' v A1))
"""
def __init__(self, clauses: list[Clause]) -> None:
def __init__(self, clauses: Iterable[Clause]) -> None:
"""
Represent the number of clauses and the clauses themselves.
"""
@ -139,14 +140,14 @@ def generate_formula() -> Formula:
"""
Randomly generate a formula.
"""
clauses = set()
clauses: set[Clause] = set()
no_of_clauses = random.randint(1, 10)
while len(clauses) < no_of_clauses:
clauses.add(generate_clause())
return Formula(set(clauses))
return Formula(clauses)
def generate_parameters(formula: Formula) -> (list[Clause], list[str]):
def generate_parameters(formula: Formula) -> tuple[list[Clause], list[str]]:
"""
Return the clauses and symbols from a formula.
A symbol is the uncomplemented form of a literal.
@ -173,8 +174,8 @@ def generate_parameters(formula: Formula) -> (list[Clause], list[str]):
def find_pure_symbols(
clauses: list[Clause], symbols: list[str], model: dict[str, bool]
) -> (list[str], dict[str, bool]):
clauses: list[Clause], symbols: list[str], model: dict[str, bool | None]
) -> tuple[list[str], dict[str, bool | None]]:
"""
Return pure symbols and their values to satisfy clause.
Pure symbols are symbols in a formula that exist only
@ -198,11 +199,11 @@ def find_pure_symbols(
{'A1': True, 'A2': False, 'A3': True, 'A5': False}
"""
pure_symbols = []
assignment = dict()
assignment: dict[str, bool | None] = dict()
literals = []
for clause in clauses:
if clause.evaluate(model) is True:
if clause.evaluate(model):
continue
for literal in clause.literals:
literals.append(literal)
@ -225,8 +226,8 @@ def find_pure_symbols(
def find_unit_clauses(
clauses: list[Clause], model: dict[str, bool]
) -> (list[str], dict[str, bool]):
clauses: list[Clause], model: dict[str, bool | None]
) -> tuple[list[str], dict[str, bool | None]]:
"""
Returns the unit symbols and their values to satisfy clause.
Unit symbols are symbols in a formula that are:
@ -263,7 +264,7 @@ def find_unit_clauses(
Ncount += 1
if Fcount == len(clause) - 1 and Ncount == 1:
unit_symbols.append(sym)
assignment = dict()
assignment: dict[str, bool | None] = dict()
for i in unit_symbols:
symbol = i[:2]
assignment[symbol] = len(i) == 2
@ -273,8 +274,8 @@ def find_unit_clauses(
def dpll_algorithm(
clauses: list[Clause], symbols: list[str], model: dict[str, bool]
) -> (bool, dict[str, bool]):
clauses: list[Clause], symbols: list[str], model: dict[str, bool | None]
) -> tuple[bool | None, dict[str, bool | None] | None]:
"""
Returns the model if the formula is satisfiable, else None
This has the following steps: