Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Rewrite assertion AST to produce nice error messages""" 

2import ast 

3import errno 

4import functools 

5import importlib.abc 

6import importlib.machinery 

7import importlib.util 

8import io 

9import itertools 

10import marshal 

11import os 

12import struct 

13import sys 

14import tokenize 

15import types 

16from typing import Dict 

17from typing import List 

18from typing import Optional 

19from typing import Set 

20from typing import Tuple 

21 

22from _pytest._io.saferepr import saferepr 

23from _pytest._version import version 

24from _pytest.assertion import util 

25from _pytest.assertion.util import ( # noqa: F401 

26 format_explanation as _format_explanation, 

27) 

28from _pytest.compat import fspath 

29from _pytest.pathlib import fnmatch_ex 

30from _pytest.pathlib import Path 

31from _pytest.pathlib import PurePath 

32 

33# pytest caches rewritten pycs in pycache dirs 

PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
# Follow CPython's historical naming: ".pyc" normally and ".pyo" when running
# under ``python -O`` (``__debug__`` is False). Optimized and non-optimized
# rewrites of the same module therefore use different cache file names.
PYC_EXT = ".py" + ("c" if __debug__ else "o")
# Full suffix appended to a module's base name to form its cached pyc name,
# e.g. "test_foo" + ".cpython-37-pytest-5.1.0.pyc".
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

37 

38 

class AssertionRewritingHook(importlib.abc.MetaPathFinder):
    """PEP302/PEP451 import hook which rewrites asserts.

    Installed as a meta path finder. ``find_spec`` decides whether a module
    should be rewritten (conftest files, files given on the command line,
    files matching ``python_files``, or names registered via
    ``mark_rewrite``). If so, this object becomes the module's loader and
    ``exec_module`` executes the assertion-rewritten code, caching it in a
    pytest-specific pyc.
    """

    def __init__(self, config):
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # Fall back to pytest's default test-file patterns when the ini
            # option is not available.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session = None
        # Names of modules this hook has actually executed (rewritten).
        self._rewritten_names = set()  # type: Set[str]
        # Module/package names registered through mark_rewrite().
        self._must_rewrite = set()  # type: Set[str]
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        # Last dotted-name components that must never be skipped early by
        # _early_rewrite_bailout; extended with the session's initial paths.
        self._basenames_to_check_rewrite = {"conftest"}
        # Memo for _is_marked_for_rewrite; cleared whenever mark_rewrite() runs.
        self._marked_for_rewrite_cache = {}  # type: Dict[str, bool]
        # Whether the session's initial paths were already folded into
        # _basenames_to_check_rewrite.
        self._session_paths_checked = False

    def set_session(self, session):
        """Attach *session* and force its initial paths to be rescanned."""
        self.session = session
        self._session_paths_checked = False

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(self, name, path=None, target=None):
        """Return a spec that routes loading through this hook, or None.

        None is returned (letting the normal import machinery proceed) while a
        pyc is being written, when the cheap name-based bailout says the module
        cannot need rewriting, when no existing source file backs the module,
        or when ``_should_rewrite`` rejects it.
        """
        if self._writing_pyc:
            return None
        state = self.config._assertstate
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        spec = self._find_spec(name, path)
        if (
            # the import machinery could not find a file to import
            spec is None
            # this is a namespace package (without `__init__.py`)
            # there's nothing to rewrite there
            # python3.5 - python3.6: `namespace`
            # python3.7+: `None`
            or spec.origin == "namespace"
            or spec.origin is None
            # we can only rewrite source files
            or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
            # if the file doesn't exist, we can't rewrite it
            or not os.path.exists(spec.origin)
        ):
            return None
        else:
            fn = spec.origin

        if not self._should_rewrite(name, fn, state):
            return None

        # Same file and package search locations, but with this hook as loader.
        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(self, spec):
        """Use the import system's default module creation."""
        return None  # default behaviour is fine

    def exec_module(self, module):
        """Execute the assertion-rewritten code of *module* in its namespace.

        A cached rewritten pyc is used when ``_read_pyc`` accepts it. Otherwise
        the source is rewritten and, when bytecode writing is allowed and the
        cache directory could be created, the result is cached via ``_write_pyc``.
        ``get_cache_dir`` and ``try_makedirs`` are module helpers not shown in
        this excerpt.
        """
        fn = Path(module.__spec__.origin)
        state = self.config._assertstate

        self._rewritten_names.add(module.__name__)

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write:
            ok = try_makedirs(cache_dir)
            if not ok:
                write = False
                state.trace("read only directory: {}".format(cache_dir))

        # Strip the ".py" suffix and append the pytest-specific pyc tail.
        cache_name = fn.name[:-3] + PYC_TAIL
        pyc = cache_dir / cache_name
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is None:
            state.trace("rewriting {!r}".format(fn))
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                # Guard against find_spec re-entering while the pyc is written
                # (see the _writing_pyc note in __init__).
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        else:
            state.trace("found cached rewritten pyc for {}".format(fn))
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name, state):
        """This is a fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.

        Returns True when *name* can safely be skipped, False when the full
        check in find_spec is still required.
        """
        if self.session is not None and not self._session_paths_checked:
            # Done once per session: remember the base names of the paths
            # given on the command line so they are never skipped here.
            self._session_paths_checked = True
            for path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                #     ['c:', 'projects', 'my_project', 'path.py']
                parts = str(path).split(os.path.sep)
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(os.path.sep.join(parts) + ".py")

        for pat in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
            # on the name alone because we need to match against the full path
            if os.path.dirname(pat):
                return False
            if fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace("early skip of rewriting module: {}".format(name))
        return True

    def _should_rewrite(self, name, fn, state):
        """Decide whether module *name* located at file *fn* gets rewritten.

        True for conftest.py files, for files that are session init paths,
        for files matching ``python_files``, or for names registered via
        ``mark_rewrite``.
        """
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace("rewriting conftest file: {!r}".format(fn))
            return True

        if self.session is not None:
            if self.session.isinitpath(fn):
                state.trace(
                    "matched test file (was specified on cmdline): {!r}".format(fn)
                )
                return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        fn_path = PurePath(fn)
        for pat in self.fnpats:
            if fnmatch_ex(pat, fn_path):
                state.trace("matched test file {!r}".format(fn))
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state) -> bool:
        """Return True if *name* or one of its parent packages was marked.

        Results are memoized in ``_marked_for_rewrite_cache``.
        """
        try:
            return self._marked_for_rewrite_cache[name]
        except KeyError:
            for marked in self._must_rewrite:
                # Exact match, or a submodule of a marked package.
                if name == marked or name.startswith(marked + "."):
                    state.trace(
                        "matched marked file {!r} (from {!r})".format(name, marked)
                    )
                    self._marked_for_rewrite_cache[name] = True
                    return True

            self._marked_for_rewrite_cache[name] = False
            return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.

        Emits a warning for each name that is already in ``sys.modules``
        without having been rewritten by this hook, unless its docstring
        disables rewriting or it was loaded by an instance of this class.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            if not AssertionRewriter.is_rewrite_disabled(
                mod.__doc__ or ""
            ) and not isinstance(mod.__loader__, type(self)):
                self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # The set of marked names changed, so memoized answers are stale.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name):
        """Issue a PytestAssertRewriteWarning that *name* cannot be rewritten."""
        from _pytest.warning_types import PytestAssertRewriteWarning
        from _pytest.warnings import _issue_warning_captured

        _issue_warning_captured(
            PytestAssertRewriteWarning(
                "Module already imported so cannot be rewritten: %s" % name
            ),
            self.config.hook,
            stacklevel=5,
        )

    def get_data(self, pathname):
        """Optional PEP302 get_data API.

        Returns the raw bytes of the file at *pathname*.
        """
        with open(pathname, "rb") as f:
            return f.read()

256 

257 

def _write_pyc_fp(fp, source_stat, co):
    """Serialize code object *co* as a pyc into the binary file object *fp*.

    The layout is the interpreter's magic number, then two little-endian
    unsigned 32-bit integers (source mtime, then source size, both taken
    from *source_stat* and truncated to 32 bits), then the marshalled code.
    """
    # Technically we don't have to use the same pyc format as (C)Python,
    # since these "pycs" should never be seen by the builtin import. However,
    # there's little reason to deviate from it.
    #
    # The bytecode header only has room for 32-bit mtime and size values
    # (#4903), so both are masked down to 32 bits.
    truncated_mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
    truncated_size = source_stat.st_size & 0xFFFFFFFF
    header_fields = struct.pack("<LL", truncated_mtime, truncated_size)
    for chunk in (importlib.util.MAGIC_NUMBER, header_fields, marshal.dumps(co)):
        fp.write(chunk)

269 

270 

if sys.platform == "win32":
    from atomicwrites import atomic_write

    def _write_pyc(state, co, source_stat, pyc):
        """Write rewritten code *co* to the cache file *pyc* (Windows variant).

        Uses ``atomic_write`` so readers never observe a partial file.
        Returns True on success and False if the cache file could not be
        written. Such failures are only traced, never raised.
        """
        try:
            with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
                _write_pyc_fp(fp, source_stat, co)
        except OSError as e:
            state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            return False
        return True


else:

    def _write_pyc(state, co, source_stat, pyc):
        """Write rewritten code *co* to the cache file *pyc* (POSIX variant).

        The pyc is first written in full to a per-process temporary file
        ("<pyc>.<pid>"), closed, and only then renamed over *pyc*. Concurrent
        pytest processes therefore only ever see a complete pyc.

        Returns True on success and False on any OSError (the cache is
        best-effort). Non-OSError exceptions such as KeyboardInterrupt
        propagate. The temporary file is removed on failure.
        """
        proc_pyc = "{}.{}".format(pyc, os.getpid())
        try:
            fp = open(proc_pyc, "wb")
        except OSError as e:
            state.trace(
                "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
            )
            return False

        try:
            # Close (and thereby flush) the temp file *before* the rename.
            # Renaming first would expose a possibly incomplete pyc to other
            # processes until the buffered data is flushed.
            with fp:
                _write_pyc_fp(fp, source_stat, co)
            os.rename(proc_pyc, fspath(pyc))
        except BaseException as e:
            # Never leave the per-process temporary file behind.
            try:
                os.unlink(proc_pyc)
            except OSError:
                pass
            if not isinstance(e, OSError):
                # Not an I/O problem (e.g. KeyboardInterrupt): propagate it
                # instead of masking it behind a missing ``errno`` attribute.
                raise
            state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            return False
        return True

311 

312 

def _rewrite_test(fn, config):
    """Read the module source at *fn*, rewrite its asserts and compile it.

    Returns a ``(stat_result, code)`` pair. ``stat_result`` is the
    ``os.stat`` of the source file, taken before it is read. ``code`` is the
    compiled, assertion-rewritten module code object.
    """
    path = fspath(fn)
    source_stat = os.stat(path)
    with open(path, "rb") as source_file:
        source_bytes = source_file.read()
    module_ast = ast.parse(source_bytes, filename=path)
    rewrite_asserts(module_ast, source_bytes, path, config)
    code = compile(module_ast, path, "exec", dont_inherit=True)
    return source_stat, code

323 

324 

def _read_pyc(source, pyc, trace=lambda x: None):
    """Possibly read a pytest pyc containing rewritten code.

    The cached file is accepted only if its 12-byte header carries the
    current interpreter's magic number and matches the (32-bit truncated)
    mtime and size of *source*, and its payload unmarshals to a code object.
    Returns that code object, or None in every other case. Each rejection is
    reported through *trace*.
    """
    try:
        pyc_file = open(fspath(pyc), "rb")
    except IOError:
        return None

    with pyc_file:
        try:
            source_stat = os.stat(fspath(source))
            expected_mtime = int(source_stat.st_mtime)
            expected_size = source_stat.st_size
            header = pyc_file.read(12)
        except EnvironmentError as e:
            trace("_read_pyc({}): EnvironmentError {}".format(source, e))
            return None

        # Check for invalid or out of date pyc file.
        header_is_current = (
            len(header) == 12
            and header[:4] == importlib.util.MAGIC_NUMBER
            and struct.unpack("<LL", header[4:])
            == (expected_mtime & 0xFFFFFFFF, expected_size & 0xFFFFFFFF)
        )
        if not header_is_current:
            trace("_read_pyc(%s): invalid or out of date pyc" % source)
            return None

        try:
            code = marshal.load(pyc_file)
        except Exception as e:
            trace("_read_pyc({}): marshal.load error {}".format(source, e))
            return None

        if isinstance(code, types.CodeType):
            return code
        trace("_read_pyc(%s): not a code object" % source)
        return None

360 

361 

def rewrite_asserts(mod, source, module_path=None, config=None):
    """Rewrite the assert statements of the ``ast.Module`` *mod* in place.

    *source* is the module's raw source bytes, *module_path* its file path
    and *config* the pytest config (both optional). Returns None.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)

365 

366 

def _saferepr(obj):
    """Return a safe repr of *obj* suitable for assertion error messages.

    util.format_explanation() treats newlines specially, so every newline
    character in the repr is replaced by a backslash followed by "n".
    Normally format_explanation() handles escaping itself, but a custom
    __repr__ may contain its special sequences (a newline followed by "{"
    or "}" is common in JSON-like reprs), so they are neutralized here.
    """
    representation = saferepr(obj)
    return representation.replace("\n", "\\n")

379 

380 

def _format_assertmsg(obj):
    r"""Format the custom assertion message given.

    For strings, each newline is replaced by "\n~" so that
    util.format_explanation() preserves it instead of escaping it, and "%"
    is doubled for later %-formatting. Other objects are first converted
    with saferepr(), and an escaped backslash-n in that repr is turned into
    "\n~" as well.
    """
    # reprlib appears to have a bug: if a string contains a newline it gets
    # escaped, but if an object's __repr__() contains newlines they are not
    # escaped. In either case we want to preserve the newline.
    if isinstance(obj, str):
        text = obj
        substitutions = (("\n", "\n~"), ("%", "%%"))
    else:
        text = saferepr(obj)
        substitutions = (("\n", "\n~"), ("%", "%%"), ("\\n", "\n~"))

    for old, new in substitutions:
        text = text.replace(old, new)
    return text

402 

403 

def _should_repr_global_name(obj):
    """Return True if a non-local name's value should be shown via its repr.

    Callables are never repr'd. Other objects are repr'd unless they have a
    ``__name__`` attribute. If looking that attribute up raises anything
    other than AttributeError, the object is repr'd.
    """
    if callable(obj):
        return False
    try:
        has_name = hasattr(obj, "__name__")
    except Exception:
        return True
    return not has_name

412 

413 

def _format_boolop(explanations, is_or):
    """Join the operand explanations of a boolean operation.

    *explanations* is a sequence of explanation strings, one per evaluated
    operand. *is_or* is truthy for ``or`` and falsy for ``and``. The result
    is parenthesized, and every ``%`` is doubled because the returned string
    is later used as a %-format template.
    """
    # The result is always a str ("(" + ...), so no bytes handling is needed.
    joiner = " or " if is_or else " and "
    explanation = "(" + joiner.join(explanations) + ")"
    return explanation.replace("%", "%%")

420 

421 

def _call_reprcompare(ops, results, expls, each_obj):
    # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
    """Choose the explanation for a (possibly chained) comparison.

    Walks the comparisons in order and stops at the first one whose result
    is falsy, or whose truth value cannot be determined. If a custom
    ``util._reprcompare`` hook is set and returns something other than None
    for that comparison's operator and operands, that is returned. Otherwise
    the default explanation of that comparison is returned.
    """
    for i, (_op, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            failed = True
        if failed:
            break

    custom_hook = util._reprcompare
    if custom_hook is not None:
        custom = custom_hook(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl

436 

437 

def _call_assertion_pass(lineno, orig, expl):
    # type: (int, str, str) -> None
    """Forward a passing assertion to ``util._assertion_pass`` if it is set.

    Does nothing when no pass hook is registered (the attribute is None).
    """
    pass_hook = util._assertion_pass
    if pass_hook is None:
        return
    pass_hook(lineno, orig, expl)

442 

443 

def _check_if_assertion_pass_impl():
    # type: () -> bool
    """Check whether any plugin implements the pytest_assertion_pass hook.

    Used to avoid generating explanations unnecessarily (they might be
    expensive). Returns the truthiness of ``util._assertion_pass``.
    """
    return bool(util._assertion_pass)

449 

450 

# Explanation templates for unary operators, keyed by AST operator class.
# The single "%s" receives the operand's explanation.
UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}

# Source-level spelling of binary operators (and of comparison/identity/
# membership operators), keyed by AST operator class. These are embedded in
# explanation strings, which are later %-formatted, hence the escaped "%%".
BINOP_MAP = {
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}

478 

479 

def set_location(node, lineno, col_offset):
    """Set node location information recursively.

    Every node in the subtree rooted at *node* (including *node* itself)
    that declares ``lineno`` / ``col_offset`` attributes gets them
    overwritten with the given values. Returns *node*.
    """
    for current in ast.walk(node):
        declared = current._attributes
        if "lineno" in declared:
            current.lineno = lineno
        if "col_offset" in declared:
            current.col_offset = col_offset
    return node

493 

494 

def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Returns a mapping from {lineno: "assertion test expression"}

    Tokenizes *src* and, for each ``assert`` statement, collects the source
    text of its test expression. The text excludes the optional ``, msg``
    part and trailing whitespace/backslashes, and may span several lines.
    The key is the line number of the ``assert`` keyword.
    """
    ret = {}  # type: Dict[int, str]

    # Parser state for the assert currently being collected.
    depth = 0  # bracket nesting level since the ``assert`` keyword
    lines = []  # type: List[str]
    assert_lineno = None  # type: Optional[int]
    seen_lines = set()  # type: Set[int]

    def _write_and_reset() -> None:
        # Store the collected expression text and reset the parser state.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            if not lines:
                # First token after ``assert``: keep the rest of its line.
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                # one line assert with message
                if lineno in seen_lines and len(lines) == 1:
                    # lines[-1] was trimmed to start at the first expression
                    # token, so translate the comma's column into that slice.
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                # multi-line assert with message
                elif lineno in seen_lines:
                    lines[-1] = lines[-1][:offset]
                # multi line assert with escaped newline before message
                else:
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                # End of the assert statement without a message.
                _write_and_reset()
            elif lines and lineno not in seen_lines:
                # Continuation line of a multi-line expression.
                lines.append(line)
                seen_lines.add(lineno)

    return ret

547 

548 

549class AssertionRewriter(ast.NodeVisitor): 

550 """Assertion rewriting implementation. 

551 

552 The main entrypoint is to call .run() with an ast.Module instance, 

553 this will then find all the assert statements and rewrite them to 

554 provide intermediate values and a detailed assertion error. See 

555 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html 

556 for an overview of how this works. 

557 

558 The entry point here is .run() which will iterate over all the 

559 statements in an ast.Module and for each ast.Assert statement it 

560 finds call .visit() with it. Then .visit_Assert() takes over and 

561 is responsible for creating new ast statements to replace the 

562 original assert statement: it rewrites the test of an assertion 

563 to provide intermediate values and replace it with an if statement 

564 which raises an assertion error with a detailed explanation in 

565 case the expression is false and calls pytest_assertion_pass hook 

566 if expression is true. 

567 

568 For this .visit_Assert() uses the visitor pattern to visit all the 

569 AST nodes of the ast.Assert.test field, each visit call returning 

570 an AST node and the corresponding explanation string. During this 

571 state is kept in several instance attributes: 

572 

573 :statements: All the AST statements which will replace the assert 

574 statement. 

575 

576 :variables: This is populated by .variable() with each variable 

577 used by the statements so that they can all be set to None at 

578 the end of the statements. 

579 

580 :variable_counter: Counter to create new unique variables needed 

581 by statements. Variables are created using .variable() and 

582 have the form of "@py_assert0". 

583 

584 :expl_stmts: The AST statements which will be executed to get 

585 data from the assertion. This is the code which will construct 

586 the detailed assertion message that is used in the AssertionError 

587 or for the pytest_assertion_pass hook. 

588 

589 :explanation_specifiers: A dict filled by .explanation_param() 

590 with %-formatting placeholders and their corresponding 

591 expressions to use in the building of an assertion message. 

592 This is used by .pop_format_context() to build a message. 

593 

594 :stack: A stack of the explanation_specifiers dicts maintained by 

595 .push_format_context() and .pop_format_context() which allows 

596 to build another %-formatted string while already building one. 

597 

598 This state is reset on every new assert statement visited and used 

599 by the other visitors. 

600 

601 """ 

602 

603 def __init__(self, module_path, config, source): 

604 super().__init__() 

605 self.module_path = module_path 

606 self.config = config 

607 if config is not None: 

608 self.enable_assertion_pass_hook = config.getini( 

609 "enable_assertion_pass_hook" 

610 ) 

611 else: 

612 self.enable_assertion_pass_hook = False 

613 self.source = source 

614 

615 @functools.lru_cache(maxsize=1) 

616 def _assert_expr_to_lineno(self): 

617 return _get_assertion_exprs(self.source) 

618 

    def run(self, mod: ast.Module) -> None:
        """Find all assert statements in *mod* and rewrite them.

        Mutates *mod* in place. It first inserts ``import builtins as
        @py_builtins`` and ``import _pytest.assertion.rewrite as @pytest_ar``
        after the module docstring and any ``from __future__`` imports. It
        then replaces every ``ast.Assert`` in statement lists with the
        statements produced by ``visit_Assert``. Nothing is done for an empty
        module or one whose docstring contains PYTEST_DONT_REWRITE.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [
            ast.alias("builtins", "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # NOTE(review): presumably for ASTs that expose a ``docstring``
        # attribute on Module (some Python 3.7 pre-releases); otherwise the
        # docstring is detected as the first string expression below.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        pos = 0
        lineno = 1
        # Find the insertion index: skip a leading docstring and any
        # ``from __future__ import ...`` statements.
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body consisted only of docstring/__future__ imports.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        # Iterative walk over statement-bearing nodes, rebuilding each list
        # field with asserts expanded into their replacement statements.
        nodes = [mod]  # type: List[ast.AST]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []  # type: List
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)

683 

684 @staticmethod 

685 def is_rewrite_disabled(docstring): 

686 return "PYTEST_DONT_REWRITE" in docstring 

687 

688 def variable(self): 

689 """Get a new variable.""" 

690 # Use a character invalid in python identifiers to avoid clashing. 

691 name = "@py_assert" + str(next(self.variable_counter)) 

692 self.variables.append(name) 

693 return name 

694 

695 def assign(self, expr): 

696 """Give *expr* a name.""" 

697 name = self.variable() 

698 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) 

699 return ast.Name(name, ast.Load()) 

700 

701 def display(self, expr): 

702 """Call saferepr on the expression.""" 

703 return self.helper("_saferepr", expr) 

704 

705 def helper(self, name, *args): 

706 """Call a helper in this module.""" 

707 py_name = ast.Name("@pytest_ar", ast.Load()) 

708 attr = ast.Attribute(py_name, name, ast.Load()) 

709 return ast.Call(attr, list(args), []) 

710 

711 def builtin(self, name): 

712 """Return the builtin called *name*.""" 

713 builtin_name = ast.Name("@py_builtins", ast.Load()) 

714 return ast.Attribute(builtin_name, name, ast.Load()) 

715 

716 def explanation_param(self, expr): 

717 """Return a new named %-formatting placeholder for expr. 

718 

719 This creates a %-formatting placeholder for expr in the 

720 current formatting context, e.g. ``%(py0)s``. The placeholder 

721 and expr are placed in the current format context so that it 

722 can be used on the next call to .pop_format_context(). 

723 

724 """ 

725 specifier = "py" + str(next(self.variable_counter)) 

726 self.explanation_specifiers[specifier] = expr 

727 return "%(" + specifier + ")s" 

728 

729 def push_format_context(self): 

730 """Create a new formatting context. 

731 

732 The format context is used for when an explanation wants to 

733 have a variable value formatted in the assertion message. In 

734 this case the value required can be added using 

735 .explanation_param(). Finally .pop_format_context() is used 

736 to format a string of %-formatted values as added by 

737 .explanation_param(). 

738 

739 """ 

740 self.explanation_specifiers = {} # type: Dict[str, ast.expr] 

741 self.stack.append(self.explanation_specifiers) 

742 

743 def pop_format_context(self, expl_expr): 

744 """Format the %-formatted string with current format context. 

745 

746 The expl_expr should be an ast.Str instance constructed from 

747 the %-placeholders created by .explanation_param(). This will 

748 add the required code to format said string to .expl_stmts and 

749 return the ast.Name instance of the formatted string. 

750 

751 """ 

752 current = self.stack.pop() 

753 if self.stack: 

754 self.explanation_specifiers = self.stack[-1] 

755 keys = [ast.Str(key) for key in current.keys()] 

756 format_dict = ast.Dict(keys, list(current.values())) 

757 form = ast.BinOp(expl_expr, ast.Mod(), format_dict) 

758 name = "@py_format" + str(next(self.variable_counter)) 

759 if self.enable_assertion_pass_hook: 

760 self.format_variables.append(name) 

761 self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) 

762 return ast.Name(name, ast.Load()) 

763 

764 def generic_visit(self, node): 

765 """Handle expressions we don't have custom code for.""" 

766 assert isinstance(node, ast.expr) 

767 res = self.assign(node) 

768 return res, self.explanation_param(self.display(res)) 

769 

    def visit_Assert(self, assert_):
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        When ``enable_assertion_pass_hook`` is set, the generated if statement
        also has an else branch. If ``_check_if_assertion_pass_impl()`` is
        true at runtime, that branch builds the explanation and calls
        ``_call_assertion_pass``.

        All per-assert state (statements, variables, counters, format stack)
        is reset here. Every produced statement gets the assert's line number
        and column offset.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            # NOTE(review): module_path defaults to None in rewrite_asserts();
            # fspath(None) would presumably raise TypeError here -- confirm
            # callers always pass a path.
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=fspath(self.module_path),
                lineno=assert_.lineno,
            )

        # Reset per-assert state (see the class docstring).
        self.statements = []  # type: List[ast.stmt]
        self.variables = []  # type: List[str]
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            self.format_variables = []  # type: List[str]

        self.stack = []  # type: List[Dict[str, ast.expr]]
        self.expl_stmts = []  # type: List[ast.stmt]
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            msg = self.pop_format_context(ast.Str(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Str("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            # Original source text of the assert's test expression.
            orig = self._assert_expr_to_lineno()[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Num(assert_.lineno),
                    ast.Str(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                # Reset the @py_format temporaries to None afterwards.
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.NameConstant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Str("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            # body is the same list object as the If's body appended above,
            # so the raise runs after the explanation statements.
            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements

884 

885 def visit_Name(self, name): 

886 # Display the repr of the name if it's a local variable or 

887 # _should_repr_global_name() thinks it's acceptable. 

888 locs = ast.Call(self.builtin("locals"), [], []) 

889 inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs]) 

890 dorepr = self.helper("_should_repr_global_name", name) 

891 test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) 

892 expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) 

893 return name, self.explanation_param(expr) 

894 

    def visit_BoolOp(self, boolop):
        """Rewrite ``and``/``or`` while keeping Python's short-circuiting.

        Each operand after the first is evaluated inside an ``if`` nested in
        the previous one. For ``and`` the guard is the previous result; for
        ``or`` it is the negated previous result. So later operands only run
        when Python would evaluate them. The explanation statements are nested
        the same way, and the evaluated operands' explanations are collected
        into a runtime list that ``_format_boolop`` joins. Returns the load
        of the result variable and its explanation placeholder.
        """
        res_var = self.variable()
        # Runtime list collecting one explanation per evaluated operand.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        body = save = self.statements
        fail_save = self.expl_stmts
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                fail_inner = []  # type: List[ast.stmt]
                # cond is set in a prior loop iteration below
                self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
                self.expl_stmts = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast.Call(app, [expl_format], [])
            self.expl_stmts.append(ast.Expr(call))
            if i < levels:
                cond = res  # type: ast.expr
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []  # type: List[ast.stmt]
                self.statements.append(ast.If(cond, inner, []))
                # Subsequent operands are emitted inside this guard.
                self.statements = body = inner
        # Restore the outer statement lists.
        self.statements = save
        self.expl_stmts = fail_save
        expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

929 

930 def visit_UnaryOp(self, unary): 

931 pattern = UNARY_MAP[unary.op.__class__] 

932 operand_res, operand_expl = self.visit(unary.operand) 

933 res = self.assign(ast.UnaryOp(unary.op, operand_res)) 

934 return res, pattern % (operand_expl,) 

935 

936 def visit_BinOp(self, binop): 

937 symbol = BINOP_MAP[binop.op.__class__] 

938 left_expr, left_expl = self.visit(binop.left) 

939 right_expr, right_expl = self.visit(binop.right) 

940 explanation = "({} {} {})".format(left_expl, symbol, right_expl) 

941 res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) 

942 return res, explanation 

943 

944 def visit_Call(self, call): 

945 """ 

946 visit `ast.Call` nodes 

947 """ 

948 new_func, func_expl = self.visit(call.func) 

949 arg_expls = [] 

950 new_args = [] 

951 new_kwargs = [] 

952 for arg in call.args: 

953 res, expl = self.visit(arg) 

954 arg_expls.append(expl) 

955 new_args.append(res) 

956 for keyword in call.keywords: 

957 res, expl = self.visit(keyword.value) 

958 new_kwargs.append(ast.keyword(keyword.arg, res)) 

959 if keyword.arg: 

960 arg_expls.append(keyword.arg + "=" + expl) 

961 else: # **args have `arg` keywords with an .arg of None 

962 arg_expls.append("**" + expl) 

963 

964 expl = "{}({})".format(func_expl, ", ".join(arg_expls)) 

965 new_call = ast.Call(new_func, new_args, new_kwargs) 

966 res = self.assign(new_call) 

967 res_expl = self.explanation_param(self.display(res)) 

968 outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl) 

969 return res, outer_expl 

970 

971 def visit_Starred(self, starred): 

972 # From Python 3.5, a Starred node can appear in a function call 

973 res, expl = self.visit(starred.value) 

974 new_starred = ast.Starred(res, starred.ctx) 

975 return new_starred, "*" + expl 

976 

977 def visit_Attribute(self, attr): 

978 if not isinstance(attr.ctx, ast.Load): 

979 return self.generic_visit(attr) 

980 value, value_expl = self.visit(attr.value) 

981 res = self.assign(ast.Attribute(value, attr.attr, ast.Load())) 

982 res_expl = self.explanation_param(self.display(res)) 

983 pat = "%s\n{%s = %s.%s\n}" 

984 expl = pat % (res_expl, res_expl, value_expl, attr.attr) 

985 return res, expl 

986 

987 def visit_Compare(self, comp: ast.Compare): 

988 self.push_format_context() 

989 left_res, left_expl = self.visit(comp.left) 

990 if isinstance(comp.left, (ast.Compare, ast.BoolOp)): 

991 left_expl = "({})".format(left_expl) 

992 res_variables = [self.variable() for i in range(len(comp.ops))] 

993 load_names = [ast.Name(v, ast.Load()) for v in res_variables] 

994 store_names = [ast.Name(v, ast.Store()) for v in res_variables] 

995 it = zip(range(len(comp.ops)), comp.ops, comp.comparators) 

996 expls = [] 

997 syms = [] 

998 results = [left_res] 

999 for i, op, next_operand in it: 

1000 next_res, next_expl = self.visit(next_operand) 

1001 if isinstance(next_operand, (ast.Compare, ast.BoolOp)): 

1002 next_expl = "({})".format(next_expl) 

1003 results.append(next_res) 

1004 sym = BINOP_MAP[op.__class__] 

1005 syms.append(ast.Str(sym)) 

1006 expl = "{} {} {}".format(left_expl, sym, next_expl) 

1007 expls.append(ast.Str(expl)) 

1008 res_expr = ast.Compare(left_res, [op], [next_res]) 

1009 self.statements.append(ast.Assign([store_names[i]], res_expr)) 

1010 left_res, left_expl = next_res, next_expl 

1011 # Use pytest.assertion.util._reprcompare if that's available. 

1012 expl_call = self.helper( 

1013 "_call_reprcompare", 

1014 ast.Tuple(syms, ast.Load()), 

1015 ast.Tuple(load_names, ast.Load()), 

1016 ast.Tuple(expls, ast.Load()), 

1017 ast.Tuple(results, ast.Load()), 

1018 ) 

1019 if len(comp.ops) > 1: 

1020 res = ast.BoolOp(ast.And(), load_names) # type: ast.expr 

1021 else: 

1022 res = load_names[0] 

1023 return res, self.explanation_param(self.pop_format_context(expl_call)) 

1024 

1025 

def try_makedirs(cache_dir) -> bool:
    """Create ``cache_dir`` together with any missing parent directories.

    Returns True when the directory exists afterwards, whether it was just
    created or was already there.  Returns False instead of raising in these
    cases:
    - a path component is not a directory (e.g. the path lies inside a zip
      file, or points at a regular file);
    - permission is denied;
    - the filesystem is read-only (EROFS).

    Any other OSError propagates.
    """
    try:
        os.makedirs(fspath(cache_dir), exist_ok=True)
    except (
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        PermissionError,
    ):
        return False
    except OSError as exc:
        # EROFS has no dedicated OSError subclass, so inspect errno directly.
        if exc.errno != errno.EROFS:
            raise
        return False
    else:
        return True

1044 

1045 

def get_cache_dir(file_path: Path) -> Path:
    """Return the cache directory to write .pyc files for the given .py file path.

    When ``sys.pycache_prefix`` is set (Python 3.8+), the directory chain of
    the source file is mirrored underneath the prefix:

        prefix    = '/tmp/pycs'
        file_path = '/home/user/proj/test_app.py'
        result    = '/tmp/pycs/home/user/proj'

    Otherwise the classic ``__pycache__`` directory next to the source file
    is returned.
    """
    if sys.version_info >= (3, 8) and sys.pycache_prefix:
        if not file_path.is_absolute():
            # Anchor relative paths at the current directory first (as
            # CPython's own cache_from_source() does).  Without this,
            # ``parts[1:-1]`` below would drop the first real directory
            # component instead of the filesystem root/anchor.
            file_path = Path.cwd() / file_path
        # parts[0] is the root/anchor (e.g. '/' or 'C:\\') and parts[-1] is
        # the file name; everything in between is the chain to mirror.
        return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
    # classic pycache directory
    return file_path.parent / "__pycache__"