Coverage for datesat / symbolic_bitvector / alpha_beta_table_bv.py: 83.4%
251 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-10 23:47 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-10 23:47 +0000
1"""
2Alpha-beta-table DateSAT using a 4-year (48-month) table.
4Representation:
5- alpha (months_var): months since epoch month 2000-03 (March 2000 = 0)
6- beta (beta_var): 0-based day index within that month (DOM = beta + 1)
8We avoid full ordinal decode by using a 48-month DIM/DBM table.
9"""
from typing import List, Tuple, Union

from z3 import (
    UGE,
    And,
    ArrayRef,
    BitVec,
    BitVecRef,
    BitVecSort,
    BitVecVal,
    BoolRef,
    CheckSatResult,
    If,
    K,
    ModelRef,
    Not,
    Optimize,
    Or,
    Select,
    Solver,
    Store,
    sat,
    unsat,
)

from ..symbolic_int.alpha_beta_table_int import build_dim_dbm_48_from_epoch
from ..core import Date, Period
from .bitwidths import LEGACY_BITS
# Epoch: March 2000 is alpha == 0.  Kept as plain Python ints so they serve both
# table construction and concrete (model-side) decoding.
_EPOCH_YEAR = 2000
_EPOCH_MONTH = 3
# Absolute month number (year * 12 + month) of the epoch, as a Z3 bit-vector.
_EPOCH_LINEAR = BitVecVal(12 * _EPOCH_YEAR + _EPOCH_MONTH, LEGACY_BITS)
# The lookup tables repeat every 4 years: 48 months, 4 * 365 + 1 == 1461 days.
_FOUR_YEAR_MONTHS = 4 * 12
_FOUR_YEAR_DAYS = 4 * 365 + 1
# Inclusive alpha range: 1900-03 (-1200) through 2100-02 (1199).
_ALPHA_MIN = 12 * (1900 - _EPOCH_YEAR) + (3 - _EPOCH_MONTH)
_ALPHA_MAX = 12 * (2100 - _EPOCH_YEAR) + (2 - _EPOCH_MONTH)
# Python-list forms of the 48-entry days-in-month / days-before-month tables.
_DIM48_LIST_PY, _DBM48_LIST_PY = build_dim_dbm_48_from_epoch()
def const_array(values: list[int]) -> ArrayRef:
    """Build a Z3 array constant that holds ``values`` at indices 0..len-1.

    Both indices and elements are LEGACY_BITS-wide bit-vectors.  Any index not
    covered by ``values`` reads as 0 (the ``K`` default).

    Returns:
        A Z3 array term (``ArrayRef``); the previous ``BitVecRef`` annotation
        did not match what ``K``/``Store`` produce.
    """
    array = K(BitVecSort(LEGACY_BITS), BitVecVal(0, LEGACY_BITS))
    for index, value in enumerate(values):
        array = Store(
            array, BitVecVal(index, LEGACY_BITS), BitVecVal(value, LEGACY_BITS)
        )
    return array
# Z3 array forms of the 48-month tables, indexed by mod48(alpha):
# _DIM48_LIST is each month's length in days, _DBM48_LIST the day offset of the
# month's first day from the start of the 4-year cycle.
_DIM48_LIST, _DBM48_LIST = (
    const_array(_DIM48_LIST_PY),
    const_array(_DBM48_LIST_PY),
)
def mod48(x) -> BitVecRef:
    """Map an alpha value to its slot in the 48-month tables.

    z3's ``%`` on bit-vectors is signed modulo (bvsmod), whose result takes the
    sign of the divisor, so the slot lies in [0, 48) even for negative ``x``.
    """
    cycle_length = BitVecVal(_FOUR_YEAR_MONTHS, LEGACY_BITS)
    return x % cycle_length
def _floor_div_12(x) -> BitVecRef:
    """Floor-divide the bit-vector ``x`` by 12, matching Python's ``//``.

    z3's ``/`` on bit-vectors is signed division truncating toward zero, so for
    a negative ``x`` that 12 does not divide, the truncated quotient is one too
    large and is decremented.  ``x % 12`` is z3's signed modulo (bvsmod), which
    is non-zero exactly when 12 does not divide ``x``.

    The previous version "re-signed" ``x`` by subtracting
    ``BitVecVal(2**LEGACY_BITS, LEGACY_BITS)``; that constant wraps to 0, so the
    step was a no-op and has been removed.
    """
    zero = BitVecVal(0, LEGACY_BITS)
    divisor = BitVecVal(12, LEGACY_BITS)
    truncated = x / divisor
    # Signed `x < zero` is the same test as the old unsigned UGE(x, sign_bit).
    needs_adjust = And(x < zero, x % divisor != zero)
    return If(needs_adjust, truncated - BitVecVal(1, LEGACY_BITS), truncated)
def _floor_div_four_year_days(x) -> BitVecRef:
    """Floor-divide the bit-vector ``x`` by _FOUR_YEAR_DAYS (1461), like ``//``.

    z3's ``/`` on bit-vectors truncates toward zero, so a negative ``x`` with a
    non-zero remainder needs its quotient decremented by one.  ``x % 1461`` is
    z3's signed modulo (bvsmod), which is non-zero exactly when 1461 does not
    divide ``x``.

    The previous version subtracted ``BitVecVal(2**LEGACY_BITS, LEGACY_BITS)``
    (which wraps to 0) to "re-sign" ``x``; that no-op has been removed.
    """
    zero = BitVecVal(0, LEGACY_BITS)
    divisor = BitVecVal(_FOUR_YEAR_DAYS, LEGACY_BITS)
    truncated = x / divisor
    # Signed `x < zero` is the same test as the old unsigned UGE(x, sign_bit).
    needs_adjust = And(x < zero, x % divisor != zero)
    return If(needs_adjust, truncated - BitVecVal(1, LEGACY_BITS), truncated)
def months_since_epoch_from_ym(y, m) -> BitVecRef:
    """Return alpha (months since 2000-03) for year ``y`` and 1-based month ``m``."""
    absolute_month = y * BitVecVal(12, LEGACY_BITS) + m
    return absolute_month - _EPOCH_LINEAR
def alpha_to_abs_month(alpha) -> BitVecRef:
    """Convert alpha back to the absolute month number ``year * 12 + month``.

    This is the inverse of ``months_since_epoch_from_ym``.
    """
    epoch_offset = _EPOCH_LINEAR
    return alpha + epoch_offset
def eom_clamp(dim, beta) -> BitVecRef:
    """Clamp the 0-based day index ``beta`` into ``[0, dim - 1]`` (signed compare).

    Used after a month shift so that a day past the end of a shorter target
    month lands on that month's last day.
    """
    zero = BitVecVal(0, LEGACY_BITS)
    last_day = dim - BitVecVal(1, LEGACY_BITS)
    capped_above = If(beta > last_day, last_day, beta)
    return If(beta < zero, zero, capped_above)
class DateVar:
    """Symbolic date variable using the alpha-beta representation.

    alpha (``months_var``): months since the epoch month 2000-03 (March 2000 = 0)
    beta (``beta_var``): 0-based day index within that month, so DOM = beta + 1

    Month lengths come from the 48-month DIM table indexed by ``alpha mod 48``.
    A 4-year (1461-day) cycle anchored at 2000-03 makes every fourth February a
    leap month.  That is wrong for February 2100, which lies inside the supported
    alpha range (1900-03 .. 2100-02), so ``_dim_at`` overrides that single month
    to 28 days.
    """

    # alpha of 2100-02 (same value as _ALPHA_MAX); 2100 is not a leap year.
    _FEB_2100_ALPHA = (2100 - _EPOCH_YEAR) * 12 + (2 - _EPOCH_MONTH)

    def __init__(self, name: str):
        """Create a fresh symbolic date named ``name`` (no bounds attached yet)."""
        self.name = name
        # Alpha: Z3 bit-vector for months since the epoch month.
        self.months_var = BitVec(f"{name}_months", LEGACY_BITS)
        # Beta: Z3 bit-vector for the 0-based day within the month.
        self.beta_var = BitVec(f"{name}_beta", LEGACY_BITS)
        # Solver that _add_bounds() asserts into; set by the owner
        # (e.g. AlphaBetaTableSolver.add_date_var) or propagated by __add__.
        self._solver = None

    def __str__(self) -> str:
        return f"DateVar({self.name})"

    @property
    def year(self) -> BitVecRef:
        """Symbolic year decoded from ``months_var``.

        Uses z3's truncating signed division; this equals floor division because
        the absolute month k is positive for every alpha inside the solver bounds.
        """
        k = self.months_var + _EPOCH_LINEAR
        return (k - BitVecVal(1, LEGACY_BITS)) / BitVecVal(12, LEGACY_BITS)

    @property
    def month(self) -> BitVecRef:
        """Symbolic 1-based month decoded from ``months_var``."""
        k = self.months_var + _EPOCH_LINEAR
        y = (k - BitVecVal(1, LEGACY_BITS)) / BitVecVal(12, LEGACY_BITS)
        return k - y * BitVecVal(12, LEGACY_BITS)

    @property
    def day(self) -> BitVecRef:
        """Symbolic day of month (``beta_var + 1``)."""
        return self.beta_var + BitVecVal(1, LEGACY_BITS)

    def to_concrete_date(self, model: ModelRef) -> Date:
        """Decode (alpha, beta) from ``model`` into a concrete Date.

        Falls back to ``Date(..., bounded=False)`` when the decoded values do not
        form a valid date.
        """
        alpha_val = model.evaluate(
            self.months_var, model_completion=True
        ).as_signed_long()
        beta_val = model.evaluate(self.beta_var, model_completion=True).as_signed_long()
        k = alpha_val + (_EPOCH_YEAR * 12 + _EPOCH_MONTH)
        year = (k - 1) // 12
        month = k - year * 12
        day = beta_val + 1
        try:
            return Date(year, month, day)
        except ValueError:
            return Date(year, month, day, bounded=False)

    @staticmethod
    def _dim_at(alpha) -> BitVecRef:
        """Symbolic number of days in month ``alpha``.

        This is the 48-month table entry, corrected to 28 days for February 2100.
        """
        table_dim = Select(_DIM48_LIST, mod48(alpha))
        return If(
            alpha == BitVecVal(DateVar._FEB_2100_ALPHA, LEGACY_BITS),
            BitVecVal(28, LEGACY_BITS),
            table_dim,
        )

    def _add_bounds(self) -> None:
        """Constrain this date to a valid (alpha, beta) if a solver is attached."""
        if self._solver is None:
            return

        # Alpha bounds: 1900-03 => -1200, 2100-02 => 1199.
        self._solver.add(self.months_var >= BitVecVal(_ALPHA_MIN, LEGACY_BITS))
        self._solver.add(self.months_var <= BitVecVal(_ALPHA_MAX, LEGACY_BITS))

        # Beta bounds: 0 <= beta < days in this month.
        dim = self._dim_at(self.months_var)
        self._solver.add(
            And(self.beta_var >= BitVecVal(0, LEGACY_BITS), self.beta_var < dim)
        )

    @staticmethod
    def _alpha_beta_of(other) -> Tuple[BitVecRef, BitVecRef]:
        """Return (alpha, beta) bit-vectors for a concrete Date or a DateVar.

        Raises:
            TypeError: if ``other`` is neither a Date nor a DateVar.
        """
        if isinstance(other, Date):
            alpha = months_since_epoch_from_ym(
                BitVecVal(other.year, LEGACY_BITS), BitVecVal(other.month, LEGACY_BITS)
            )
            return alpha, BitVecVal(other.day - 1, LEGACY_BITS)
        if isinstance(other, DateVar):
            return other.months_var, other.beta_var
        raise TypeError(f"Cannot compare DateVar with {type(other)}")

    def __ge__(self, other) -> BoolRef:
        """Constraint ``self >= other`` (lexicographic on (alpha, beta))."""
        alpha_o, beta_o = self._alpha_beta_of(other)
        return Or(
            self.months_var > alpha_o,
            And(self.months_var == alpha_o, self.beta_var >= beta_o),
        )

    def __le__(self, other) -> BoolRef:
        """Constraint ``self <= other`` (lexicographic on (alpha, beta))."""
        alpha_o, beta_o = self._alpha_beta_of(other)
        return Or(
            self.months_var < alpha_o,
            And(self.months_var == alpha_o, self.beta_var <= beta_o),
        )

    def __lt__(self, other) -> BoolRef:
        """Constraint ``self < other``, expressed as ``Not(self >= other)``."""
        return Not(self.__ge__(other))

    def __gt__(self, other) -> BoolRef:
        """Constraint ``self > other``, expressed as ``Not(self <= other)``."""
        return Not(self.__le__(other))

    def __eq__(self, other) -> BoolRef:
        """Constraint ``self == other`` (equal alpha and equal beta)."""
        alpha_o, beta_o = self._alpha_beta_of(other)
        return And(self.months_var == alpha_o, self.beta_var == beta_o)

    def __ne__(self, other) -> BoolRef:
        """Constraint ``self != other``, expressed as ``Not(self == other)``."""
        return Not(self.__eq__(other))

    @staticmethod
    def _shift_days(alpha1, beta1, dim1, days: int) -> Tuple[BitVecRef, BitVecRef]:
        """Add ``days`` (may be negative) to the in-month date (alpha1, beta1).

        Args:
            alpha1: symbolic month of the starting date.
            beta1: symbolic 0-based day, already clamped into [0, dim1).
            dim1: symbolic length of month alpha1.
            days: concrete day delta.

        Returns:
            (alpha, beta) of the shifted date.
        """
        zero = BitVecVal(0, LEGACY_BITS)
        one = BitVecVal(1, LEGACY_BITS)
        days_delta = BitVecVal(days, LEGACY_BITS)
        idx1 = mod48(alpha1)

        # Fast case: the result stays inside the starting month.
        beta_within = beta1 + days_delta
        stays_in_month = And(beta_within >= zero, beta_within < dim1)

        # General case: express the date as a day offset within alpha1's 4-year
        # cycle, add the delta, then split into whole cycles plus a remainder.
        total = Select(_DBM48_LIST, idx1) + beta1 + days_delta
        q0 = _floor_div_four_year_days(total)
        r0 = total % BitVecVal(_FOUR_YEAR_DAYS, LEGACY_BITS)

        # idx2 = last cycle month whose start offset (DBM) is <= r0.
        idx2 = zero
        for i in range(1, _FOUR_YEAR_MONTHS):
            i_bv = BitVecVal(i, LEGACY_BITS)
            idx2 = If(r0 >= Select(_DBM48_LIST, i_bv), i_bv, idx2)
        beta2 = r0 - Select(_DBM48_LIST, idx2)

        # End-of-month overflow carry: if beta2 reaches the month length, move
        # to the next month.  Presumably never fires when the DIM and DBM tables
        # are consistent; kept as a safeguard.
        dim2 = Select(_DIM48_LIST, idx2)
        carry = If(beta2 >= dim2, one, zero)

        cycle_months = BitVecVal(_FOUR_YEAR_MONTHS, LEGACY_BITS)
        alpha_ordinal = alpha1 + q0 * cycle_months + (idx2 - idx1) + carry
        beta_ordinal = If(carry == one, beta2 - dim2, beta2)

        return (
            If(stays_in_month, alpha1, alpha_ordinal),
            If(stays_in_month, beta_within, beta_ordinal),
        )

    def __add__(self, other) -> "DateVar":
        """Return a new DateVar for ``self + other``, where ``other`` is a Period.

        Years and months are applied first; the day is clamped to the last day of
        the target month.  Days are then added.  The result inherits this
        variable's solver and gets its bounds asserted.

        Raises:
            TypeError: if ``other`` is not a Period.
        """
        if not isinstance(other, Period):
            raise TypeError(f"Cannot add {type(other)} to DateVar")
        result = DateVar(f"{self.name}_plus")

        months = other.years * 12 + other.months
        alpha1 = self.months_var
        if months != 0:
            alpha1 = alpha1 + BitVecVal(months, LEGACY_BITS)
        dim1 = self._dim_at(alpha1)
        beta1 = eom_clamp(dim1, self.beta_var)

        if other.days == 0:
            result.months_var, result.beta_var = alpha1, beta1
        else:
            result.months_var, result.beta_var = self._shift_days(
                alpha1, beta1, dim1, other.days
            )

        result._solver = self._solver
        result._add_bounds()
        return result

    def __sub__(self, other) -> "DateVar":
        """DateVar - Period, implemented as DateVar + (-Period)."""
        if isinstance(other, Period):
            neg = Period(-other.years, -other.months, -other.days)
            return self.__add__(neg)
        raise TypeError(f"Cannot subtract {type(other)} from DateVar")
class AlphaBetaTableSolver:
    """Alpha-beta date constraint solver wrapping a Z3 ``Solver`` or ``Optimize``."""

    def __init__(self, timeout_ms=600000, use_maxsat=False):
        """Initialize the underlying Z3 solver with a timeout.

        Args:
            timeout_ms: Z3 timeout in milliseconds (default: 600000 ms, i.e.
                10 minutes).
            use_maxsat: If True, back the solver with ``Optimize`` so that
                :meth:`solve` adds MaxSAT soft constraints; otherwise use a
                plain ``Solver``.
        """
        self.use_maxsat = use_maxsat
        self.solver = Optimize() if use_maxsat else Solver()
        self.solver.set("timeout", timeout_ms)
        self.date_vars = {}
        self.constraints = []
        self.timeout_ms = timeout_ms
        # Names of date variables that already carry MaxSAT soft constraints,
        # so repeated solve() calls do not add duplicates.
        self._soft_constrained = set()

    def add_date_var(self, name: str) -> DateVar:
        """Create a DateVar named ``name``, register it, and assert its bounds."""
        date_var = DateVar(name)
        date_var._solver = self.solver
        self.date_vars[name] = date_var
        date_var._add_bounds()
        return date_var

    def add_constraint(self, constraint: BoolRef) -> None:
        """Record ``constraint`` and assert it in the Z3 solver."""
        self.constraints.append(constraint)
        self.solver.add(constraint)

    def check(self) -> CheckSatResult:
        """Check whether the constraints are satisfiable."""
        return self.solver.check()

    def model(self) -> ModelRef:
        """Return the model from the last satisfiable check."""
        return self.solver.model()

    def get_concrete_dates(self, model: ModelRef) -> dict:
        """Map each registered date name to its concrete Date in ``model``."""
        return {
            name: var.to_concrete_date(model) for name, var in self.date_vars.items()
        }

    def _add_maxsat_preferences(self) -> None:
        """Add soft constraints that prefer dates near today (Optimize only).

        Each registered date variable gets two soft constraints: within +/-50
        years of the current month (weight 100) and within +/-10 years
        (weight 10).  Variables handled by an earlier call are skipped.
        """
        from datetime import date

        today = date.today()
        # Months since the epoch month for the current month.
        today_months = (today.year - _EPOCH_YEAR) * 12 + (today.month - _EPOCH_MONTH)
        for name, date_var in self.date_vars.items():
            if name in self._soft_constrained:
                continue
            for years, weight in ((50, 100), (10, 10)):
                span = years * 12
                window = And(
                    date_var.months_var
                    >= BitVecVal(today_months - span, LEGACY_BITS),
                    date_var.months_var
                    <= BitVecVal(today_months + span, LEGACY_BITS),
                )
                self.solver.add_soft(window, weight=weight)
            self._soft_constrained.add(name)

    def solve(self) -> dict:
        """Solve the constraints.

        Returns:
            ``{"status": "sat", "dates": {name: Date}}`` when satisfiable,
            ``{"status": "unsat", "dates": {}}`` when unsatisfiable, and
            ``{"status": "timeout", "dates": {}}`` when Z3 returns unknown
            (timeout or resource limit).
        """
        if self.use_maxsat:
            self._add_maxsat_preferences()

        result = self.check()
        if result == sat:
            return {
                "status": "sat",
                "dates": self.get_concrete_dates(self.model()),
            }
        if result == unsat:
            return {"status": "unsat", "dates": {}}
        # result == unknown (timeout or resource limit)
        return {"status": "timeout", "dates": {}}

    def to_smt2(self) -> str:
        """Return the current problem in SMT-LIB v2 format."""
        return self.solver.to_smt2()

    def get_assertions(self) -> List[BoolRef]:
        """Return the list of current Z3 assertions (BoolRef)."""
        return list(self.solver.assertions())