1- import time
21import re
3- from pysat .formula import CNF
4- from pysat .solvers import Solver
5- from pysat .card import *
6- import clingo
7- import operator
2+ from typing import Optional , Sequence , Set
3+
84import numbers
5+ import operator
6+ import re
7+ from collections import defaultdict
8+ from typing import Optional , Sequence , Set
9+
10+ import clingo
911import clingo .script
1012import pkg_resources
11- from collections import defaultdict
12- from . util import rule_is_recursive , Constraint , Literal , format_rule , remap_variables
13+
14+ from .util import rule_is_recursive , Constraint , Literal , remap_variables
15+
1316clingo .script .enable_python ()
14- from clingo import Function , Number , Tuple_
17+ from clingo import Function , Number , Tuple_ , Model , Symbol
1518from itertools import permutations
1619import dataclasses
1720from . abs_generate import Generator as AbstractGenerator
21+ from . abs_generate import Rule , RuleBase
1822
1923@dataclasses .dataclass (frozen = True )
2024class Var :
@@ -78,6 +82,7 @@ def build_rule_literals(rule, rule_var, pi=False):
7882
7983class Generator (AbstractGenerator ):
8084
85+ model : Optional [Model ]
8186 def __init__ (self , settings , bkcons = []):
8287 self .savings = 0
8388 self .settings = settings
@@ -245,13 +250,13 @@ def __init__(self, settings, bkcons=[]):
245250 self .solver = solver
246251
247252 # @profile
248- def get_prog (self ):
253+ def get_prog (self ) -> Optional [ RuleBase ] :
249254 if self .handle is None :
250255 self .handle = iter (self .solver .solve (yield_ = True ))
251256 self .model = next (self .handle , None )
252257 if self .model is None :
253258 return None
254- atoms = self .model .symbols (shown = True )
259+ atoms : Sequence [ Symbol ] = self .model .symbols (shown = True )
255260
256261 if self .settings .single_solve :
257262 return self .parse_model_single_rule (atoms )
@@ -271,7 +276,7 @@ def gen_symbol(self, literal, backend):
271276 self .seen_symbols [k ] = symbol
272277 return symbol
273278
274- def parse_model_recursion (self , model ):
279+ def parse_model_recursion (self , model ) -> RuleBase :
275280 settings = self .settings
276281 rule_index_to_body = defaultdict (set )
277282 head = settings .head_literal
@@ -293,21 +298,21 @@ def parse_model_recursion(self, model):
293298
294299 return frozenset (prog )
295300
296- def parse_model_single_rule (self , model ) :
301+ def parse_model_single_rule (self , model : Sequence [ Symbol ]) -> RuleBase :
297302 settings = self .settings
298- head = settings .head_literal
299- body = set ()
303+ head : Literal = settings .head_literal
304+ body : Set [ Literal ] = set ()
300305 cached_literals = settings .cached_literals
301306 for atom in model :
302307 args = atom .arguments
303308 predicate = args [1 ].name
304309 atom_args = tuple (args [3 ].arguments )
305310 literal = cached_literals [(predicate , atom_args )]
306311 body .add (literal )
307- rule = head , frozenset (body )
312+ rule : Rule = head , frozenset (body )
308313 return frozenset ([rule ])
309314
310- def parse_model_pi (self , model ):
315+ def parse_model_pi (self , model ) -> RuleBase :
311316 settings = self .settings
312317 # directions = defaultdict(lambda: defaultdict(lambda: '?'))
313318 rule_index_to_body = defaultdict (set )
0 commit comments