Coverage for datesat / symbolic_bitvector / naive_bv.py: 84.5%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-10 23:47 +0000

1""" 

2Naive DateSAT implementation using component-based representation. 

3 

4This module implements the naive approach where dates are represented 

5as separate year, month, and day variables, and period arithmetic is done 

6component-wise with proper normalization. 

7""" 

8 

9from typing import List, Tuple, Union 

10 

11from z3 import ( 

12 UGE, 

13 And, 

14 BitVec, 

15 BitVecRef, 

16 BitVecVal, 

17 BoolRef, 

18 CheckSatResult, 

19 If, 

20 ModelRef, 

21 Not, 

22 Optimize, 

23 Or, 

24 Solver, 

25 sat, 

26 unsat, 

27) 

28from ..core import Date, Period 

29from .bitwidths import LEGACY_BITS 

30 

31_NONLEAP_PREFIX = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334] 

32_LEAP_PREFIX = [0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335] 

33 

34 

def is_leap(year) -> BoolRef:
    """Return a Z3 predicate that holds when ``year`` is a Gregorian leap year.

    A year is a leap year when it is divisible by 4 but not by 100, or when
    it is divisible by 400. All constants are ``LEGACY_BITS``-wide
    bit-vectors, and ``%`` is Z3's bit-vector modulus operator.
    """
    zero = BitVecVal(0, LEGACY_BITS)

    def divisible_by(divisor: int) -> BoolRef:
        return year % BitVecVal(divisor, LEGACY_BITS) == zero

    quadrennial_non_century = And(
        divisible_by(4),
        year % BitVecVal(100, LEGACY_BITS) != zero,
    )
    return Or(quadrennial_non_century, divisible_by(400))

44 

45 

def days_in_month(year, month) -> BitVecRef:
    """Return a bit-vector term for the number of days in ``month`` of ``year``.

    February yields 29 when :func:`is_leap` holds for ``year`` and 28
    otherwise. Months 4, 6, 9 and 11 yield 30. Every other month value yields
    31; values outside 1..12 are not rejected here.
    """
    def bv(value: int) -> BitVecRef:
        return BitVecVal(value, LEGACY_BITS)

    february_length = If(is_leap(year), bv(29), bv(28))
    is_thirty_day_month = Or(*[month == bv(value) for value in (4, 6, 9, 11)])
    non_february_length = If(is_thirty_day_month, bv(30), bv(31))
    return If(month == bv(2), february_length, non_february_length)

62 

63 

def normalize_month(y, m) -> Tuple[BitVecRef, BitVecRef]:
    """
    NormMonth(y, m) = (y + floor((m - 1) / 12), ((m - 1) mod 12) + 1)

    Carry an out-of-range month number into the year so the returned month
    lies in 1..12. ``m`` is treated as a two's-complement signed
    ``LEGACY_BITS``-wide bit-vector, so non-positive months borrow from the
    year (e.g. m = 0 -> (y - 1, 12), m = 13 -> (y + 1, 1)).
    Works for concrete and symbolic inputs.
    """
    one = BitVecVal(1, LEGACY_BITS)
    twelve = BitVecVal(12, LEGACY_BITS)

    # Bit-vectors are already two's complement, so no explicit signed
    # conversion is needed. Z3's `/` on bit-vectors is signed truncating
    # division (bvsdiv). Its `%` is signed modulus (bvsmod), whose result
    # takes the divisor's sign, so r is always in 0..11 here.
    t = m - one
    q_trunc = t / twelve
    r = t % twelve

    # Turn truncating division into floor division. Truncation rounds toward
    # zero, so a negative t with a non-zero remainder needs the quotient
    # stepped down by one.
    sign_bit = BitVecVal(2 ** (LEGACY_BITS - 1), LEGACY_BITS)
    t_is_negative = UGE(t, sign_bit)  # sign bit set <=> t < 0 (signed)
    needs_floor_fix = And(t_is_negative, r != BitVecVal(0, LEGACY_BITS))
    q = If(needs_floor_fix, q_trunc - one, q_trunc)

    return y + q, r + one

89 

def eom_clamp(year, month, day) -> BitVecRef:
    """
    Clamp ``day`` into 1..days_in_month(year, month).

    Days below 1 become 1, and days past the end of the month become the
    month's last day. The comparisons are Z3's signed bit-vector comparisons.
    """
    one = BitVecVal(1, LEGACY_BITS)
    last_day = days_in_month(year, month)
    capped_at_month_end = If(day > last_day, last_day, day)
    return If(day < one, one, capped_at_month_end)

100 

def add_days_componentwise(
    y, m, d, delta_days: int
) -> Tuple[BitVecRef, BitVecRef, BitVecRef]:
    """
    Shift the date (y, m, d) by a concrete number of days.

    The shift is unrolled one day at a time. Each step either stays inside
    the current month or carries into the neighbouring month, wrapping
    December <-> January with a year carry. The size of the resulting terms
    therefore grows linearly with ``abs(delta_days)``.

    Args:
        y, m, d: year/month/day terms. Presumably already a valid date, since
            the forward step only checks ``d + 1`` against the month length.
            TODO: confirm that callers guarantee this.
        delta_days: Python int; positive moves forward, negative backward,
            and zero returns the inputs unchanged.

    Returns:
        (year, month, day) terms of the shifted date.
    """
    if delta_days == 0:
        return y, m, d

    one = BitVecVal(1, LEGACY_BITS)
    twelve = BitVecVal(12, LEGACY_BITS)

    def step_forward(cy, cm, cd):
        # Tomorrow, or the 1st of the next month (January of the next year
        # after December).
        month_length = days_in_month(cy, cm)
        tomorrow = cd + one
        leaves_month = tomorrow > month_length
        following_month = cm + one
        leaves_year = following_month > twelve
        day_out = If(leaves_month, one, tomorrow)
        month_out = If(leaves_month, If(leaves_year, one, following_month), cm)
        year_out = If(leaves_month, If(leaves_year, cy + one, cy), cy)
        return year_out, month_out, day_out

    def step_backward(cy, cm, cd):
        # Yesterday, or the last day of the previous month (December of the
        # previous year before January).
        yesterday = cd - one
        leaves_month = yesterday < one
        preceding_month = cm - one
        leaves_year = preceding_month < one
        earlier_year = If(leaves_year, cy - one, cy)
        earlier_month = If(leaves_year, twelve, preceding_month)
        earlier_month_length = days_in_month(earlier_year, earlier_month)
        day_out = If(leaves_month, earlier_month_length, yesterday)
        month_out = If(leaves_month, earlier_month, cm)
        year_out = If(leaves_month, earlier_year, cy)
        return year_out, month_out, day_out

    step = step_forward if delta_days > 0 else step_backward
    state = (y, m, d)
    for _ in range(abs(delta_days)):
        state = step(*state)
    return state

154 

class DateVar:
    """Symbolic date variable for the naive implementation.

    A date is represented by three separate ``LEGACY_BITS``-wide Z3
    bit-vectors: ``year``, ``month`` and ``day``. Comparison operators return
    Z3 ``BoolRef`` formulas that order dates lexicographically by (year,
    month, day) using signed bit-vector comparisons; they do not return
    Python booleans. Adding or subtracting a ``Period`` returns a new
    ``DateVar`` whose components are Z3 expressions over this one.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    unhashable.
    """

    def __init__(self, name: str):
        """Create a symbolic date variable named ``name``."""
        self.name = name
        # One Z3 bit-vector per component. All use LEGACY_BITS so they can be
        # mixed freely in arithmetic with the module's constants.
        self.year = BitVec(f"{name}_year", LEGACY_BITS)
        self.month = BitVec(f"{name}_month", LEGACY_BITS)
        self.day = BitVec(f"{name}_day", LEGACY_BITS)
        # Solver that receives validity bounds for this date and for dates
        # derived from it via +/-. Set after construction when needed.
        self._solver = None

    def __str__(self) -> str:
        return f"DateVar({self.name})"

    def to_concrete_date(self, model: ModelRef) -> Date:
        """Evaluate the components in ``model`` and build a concrete Date.

        ``model_completion=True`` assigns a value to any component the model
        leaves unconstrained. Values are read back with ``as_long()``, which
        gives the unsigned interpretation of the bit-vector.
        """
        year = model.evaluate(self.year, model_completion=True).as_long()
        month = model.evaluate(self.month, model_completion=True).as_long()
        day = model.evaluate(self.day, model_completion=True).as_long()
        return Date(year, month, day)

    @staticmethod
    def _components(other):
        """Return ``other`` as a (year, month, day) triple of bit-vector terms.

        A concrete ``Date`` is lifted with ``BitVecVal``. A ``DateVar``
        contributes its own terms. Any other type raises ``TypeError``.
        """
        if isinstance(other, Date):
            return (
                BitVecVal(other.year, LEGACY_BITS),
                BitVecVal(other.month, LEGACY_BITS),
                BitVecVal(other.day, LEGACY_BITS),
            )
        if isinstance(other, DateVar):
            return other.year, other.month, other.day
        raise TypeError(f"Cannot compare DateVar with {type(other)}")

    def __ge__(self, other) -> BoolRef:
        """Support x >= date comparison (lexicographic on year, month, day)."""
        other_year, other_month, other_day = self._components(other)
        return Or(
            self.year > other_year,
            And(
                self.year == other_year,
                Or(
                    self.month > other_month,
                    And(self.month == other_month, self.day >= other_day),
                ),
            ),
        )

    def __le__(self, other) -> BoolRef:
        """Support x <= date comparison (lexicographic on year, month, day)."""
        other_year, other_month, other_day = self._components(other)
        return Or(
            self.year < other_year,
            And(
                self.year == other_year,
                Or(
                    self.month < other_month,
                    And(self.month == other_month, self.day <= other_day),
                ),
            ),
        )

    def __lt__(self, other) -> BoolRef:
        """Support x < date comparison, as the negation of x >= date."""
        # __ge__ raises TypeError for unsupported operand types.
        return Not(self.__ge__(other))

    def __gt__(self, other) -> BoolRef:
        """Support x > date comparison, as the negation of x <= date."""
        return Not(self.__le__(other))

    def __eq__(self, other) -> BoolRef:
        """Support x == date comparison: all three components are equal."""
        other_year, other_month, other_day = self._components(other)
        return And(
            self.year == other_year,
            self.month == other_month,
            self.day == other_day,
        )

    def __ne__(self, other) -> BoolRef:
        """Support x != date comparison, as the negation of component-wise equality."""
        return Not(self.__eq__(other))

    def _add_bounds(self) -> None:
        """Constrain this date to the supported range, if a solver is attached.

        The supported range is 1900-03-01 .. 2100-02-28 inclusive. In each
        clause the month lies in the window for that year and the day lies in
        1..days_in_month(year, month). Does nothing when ``_solver`` is None.
        """
        if self._solver is None:
            return

        def bv(value: int) -> BitVecRef:
            return BitVecVal(value, LEGACY_BITS)

        def month_and_day_within(first_month: int, last_month: int):
            # Month window plus day validity; shared by all three clauses.
            return (
                self.month >= bv(first_month),
                self.month <= bv(last_month),
                self.day >= bv(1),
                self.day <= days_in_month(self.year, self.month),
            )

        self._solver.add(
            Or(
                # 1900-03-01 to 1900-12-31
                And(self.year == bv(1900), *month_and_day_within(3, 12)),
                # 1901-01-01 to 2099-12-31
                And(
                    self.year >= bv(1901),
                    self.year <= bv(2099),
                    *month_and_day_within(1, 12),
                ),
                # 2100-01-01 to 2100-02-28 (2100 is not a leap year)
                And(self.year == bv(2100), *month_and_day_within(1, 2)),
            )
        )

    def __add__(self, other) -> 'DateVar':
        """
        DateVar + Period following naive semantics:
        1) Combine Y and M (normalize months into 1..12 with year carry)
        2) Apply EOM clamp: day := min(original_day, days_in_month(new_year,new_month))
           (days below 1 are also raised to 1)
        3) Add D days via iterative day carry (month/year rollover as required)

        Optimizations:
        - Fast path: If period is days-only (years=0, months=0), skip month normalization

        If a solver is attached, the result is also constrained to the
        supported date range (see ``_add_bounds``). An out-of-range sum
        therefore makes the problem unsatisfiable.

        Raises:
            TypeError: if ``other`` is not a ``Period``.
        """
        if not isinstance(other, Period):
            raise TypeError(f"Cannot add {type(other)} to DateVar")

        result = DateVar(f"{self.name}_plus")

        # Period components are concrete Python ints, so the fast-path
        # decision happens at Python level.
        period_years = other.years
        period_months = other.months
        period_days = other.days

        if period_years == 0 and period_months == 0:
            # Days-only: no month normalization or EOM clamp needed.
            y2, m2, d2 = add_days_componentwise(
                self.year, self.month, self.day, period_days
            )
        else:
            # Step 1: fold years into months, add to the current month and
            # normalize back into 1..12 with a year carry.
            period_total_months = BitVecVal(period_years, LEGACY_BITS) * BitVecVal(
                12, LEGACY_BITS
            ) + BitVecVal(period_months, LEGACY_BITS)
            total_months = self.month + period_total_months
            year_carry, m1 = normalize_month(BitVecVal(0, LEGACY_BITS), total_months)
            y1 = self.year + year_carry

            # Step 2: clamp the day into the new month.
            d1 = eom_clamp(y1, m1, self.day)

            # Step 3: add the day offset with month/year carries.
            y2, m2, d2 = add_days_componentwise(y1, m1, d1, period_days)

        # Replace the placeholder variables with the computed expressions.
        result.year, result.month, result.day = y2, m2, d2

        result._solver = self._solver
        result._add_bounds()
        return result

    def __sub__(self, other) -> "DateVar":
        """DateVar - Period implemented as DateVar + (-Period).

        Raises:
            TypeError: if ``other`` is not a ``Period``.
        """
        if not isinstance(other, Period):
            raise TypeError(f"Cannot subtract {type(other)} from DateVar")
        negated = Period(-other.years, -other.months, -other.days)
        return self.__add__(negated)

372 

373 

class NaiveSolver:
    """Naive date constraint solver using component-based representation.

    Wraps a Z3 ``Solver``, or an ``Optimize`` instance when MaxSAT mode is
    enabled. Tracks the ``DateVar`` objects created via ``add_date_var``.
    """

    def __init__(self, timeout_ms=600000, use_maxsat=False):
        """Initialize the solver with a timeout and optional MaxSAT mode.

        Args:
            timeout_ms: Z3 timeout in milliseconds (default: 600000, i.e.
                600 seconds).
            use_maxsat: If True, use a Z3 ``Optimize`` instance; ``solve``
                then adds soft constraints preferring years near today.
        """
        self.use_maxsat = use_maxsat
        self.solver = Optimize() if use_maxsat else Solver()
        self.solver.set("timeout", timeout_ms)
        self.date_vars = {}
        self.constraints = []
        self.timeout_ms = timeout_ms
        # Names of date variables that already carry MaxSAT soft constraints.
        # Used so repeated solve() calls do not stack duplicate soft
        # constraints, which would over-weight earlier-registered variables.
        self._soft_constrained = set()

    def add_date_var(self, name: str) -> DateVar:
        """Create, register and bound a symbolic date variable named ``name``.

        The variable is attached to this solver and constrained to the
        supported date range via ``DateVar._add_bounds``.
        """
        date_var = DateVar(name)
        date_var._solver = self.solver
        self.date_vars[name] = date_var
        date_var._add_bounds()
        return date_var

    def add_constraint(self, constraint: BoolRef) -> None:
        """Record ``constraint`` and add it to the underlying Z3 solver."""
        self.constraints.append(constraint)
        self.solver.add(constraint)

    def check(self) -> CheckSatResult:
        """Check if constraints are satisfiable."""
        return self.solver.check()

    def model(self) -> ModelRef:
        """Get the model if satisfiable."""
        return self.solver.model()

    def get_concrete_dates(self, model: ModelRef) -> dict:
        """Map each registered variable name to its concrete Date in ``model``."""
        return {
            name: var.to_concrete_date(model) for name, var in self.date_vars.items()
        }

    def _add_soft_year_preferences(self) -> None:
        """Add MaxSAT soft constraints preferring years close to today.

        Each date variable gets two soft constraints: weight 100 for
        today's year +/- 50 and weight 10 for today's year +/- 10. Variables
        that received them in an earlier call are skipped.
        """
        from datetime import date

        today_year = date.today().year

        for name, date_var in self.date_vars.items():
            if name in self._soft_constrained:
                continue
            self._soft_constrained.add(name)

            # High weight: today +/- 50 years
            within_50_years = And(
                date_var.year >= BitVecVal(today_year - 50, LEGACY_BITS),
                date_var.year <= BitVecVal(today_year + 50, LEGACY_BITS),
            )
            self.solver.add_soft(within_50_years, weight=100)

            # Low weight: today +/- 10 years
            within_10_years = And(
                date_var.year >= BitVecVal(today_year - 10, LEGACY_BITS),
                date_var.year <= BitVecVal(today_year + 10, LEGACY_BITS),
            )
            self.solver.add_soft(within_10_years, weight=10)

    def solve(self) -> dict:
        """Solve the constraints.

        Returns:
            A dict with key ``"status"`` (``"sat"``, ``"unsat"`` or
            ``"timeout"``) and key ``"dates"``. ``"dates"`` maps each
            registered variable name to a concrete Date when sat, and is
            empty otherwise. ``"timeout"`` is reported for any ``unknown``
            result, which may also come from other resource limits.
        """
        if self.use_maxsat:
            self._add_soft_year_preferences()

        result = self.check()
        if result == sat:
            model = self.model()
            return {
                "status": "sat",
                "dates": self.get_concrete_dates(model),
            }
        if result == unsat:
            return {"status": "unsat", "dates": {}}
        # result == unknown (timeout or resource limit)
        return {"status": "timeout", "dates": {}}

    def to_smt2(self) -> str:
        """Return the current problem in SMT-LIB v2 format.

        ``Solver`` provides ``to_smt2``. For an ``Optimize`` instance (MaxSAT
        mode), this falls back to its ``sexpr()`` rendering.
        """
        to_smt2 = getattr(self.solver, "to_smt2", None)
        if to_smt2 is not None:
            return to_smt2()
        return self.solver.sexpr()

    def get_assertions(self) -> List[BoolRef]:
        """Return the list of current Z3 assertions (BoolRef)."""
        return list(self.solver.assertions())