Coverage for /usr/local/lib/python3.7/site-packages/_pytest/assertion/rewrite.py : 14%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Rewrite assertion AST to produce nice error messages"""
2import ast
3import errno
4import functools
5import importlib.abc
6import importlib.machinery
7import importlib.util
8import io
9import itertools
10import marshal
11import os
12import struct
13import sys
14import tokenize
15import types
16from typing import Dict
17from typing import List
18from typing import Optional
19from typing import Set
20from typing import Tuple
22from _pytest._io.saferepr import saferepr
23from _pytest._version import version
24from _pytest.assertion import util
25from _pytest.assertion.util import ( # noqa: F401
26 format_explanation as _format_explanation,
27)
28from _pytest.compat import fspath
29from _pytest.pathlib import fnmatch_ex
30from _pytest.pathlib import Path
31from _pytest.pathlib import PurePath
# pytest caches rewritten pycs in pycache dirs.  The file-name tail combines
# the interpreter cache tag with the pytest version, so the cache file name
# differs per interpreter and per pytest release.
PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
# ".pyc" for a normal run, ".pyo" when the interpreter runs with -O.
PYC_EXT = ".pyc" if __debug__ else ".pyo"
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder):
    """PEP302/PEP451 import hook which rewrites asserts.

    The same object serves as finder (``find_spec``) and as loader
    (``create_module`` / ``exec_module``) for the modules it decides to
    rewrite; for every other module ``find_spec`` returns ``None``.
    """

    def __init__(self, config):
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # getini() raised ValueError: use the built-in patterns instead.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session = None
        self._rewritten_names = set()  # type: Set[str]
        self._must_rewrite = set()  # type: Set[str]
        # flag to guard against trying to rewrite a pyc file while we are
        # already writing another pyc file, which might result in infinite
        # recursion (#3506)
        self._writing_pyc = False
        self._basenames_to_check_rewrite = {"conftest"}
        self._marked_for_rewrite_cache = {}  # type: Dict[str, bool]
        self._session_paths_checked = False

    def set_session(self, session):
        """Remember *session* and force its initial paths to be re-examined."""
        self.session = session
        self._session_paths_checked = False

    # Indirection so we can mock calls to find_spec originated from the hook
    # during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(self, name, path=None, target=None):
        """Return a spec whose loader is this hook, or None to decline.

        None is returned while a pyc is being written, when the early
        bailout applies, when no rewritable source file is found, or when
        ``_should_rewrite`` rejects the module.
        """
        if self._writing_pyc:
            return None
        state = self.config._assertstate
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        spec = self._find_spec(name, path)
        # the import machinery could not find a file to import
        if spec is None:
            return None
        # this is a namespace package (without `__init__.py`)
        # there's nothing to rewrite there
        # python3.5 - python3.6: `namespace`
        # python3.7+: `None`
        if spec.origin == "namespace" or spec.origin is None:
            return None
        # we can only rewrite source files
        if not isinstance(spec.loader, importlib.machinery.SourceFileLoader):
            return None
        # if the file doesn't exist, we can't rewrite it
        if not os.path.exists(spec.origin):
            return None

        fn = spec.origin
        if not self._should_rewrite(name, fn, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(self, spec):
        """Defer to the default module creation."""
        return None  # default behaviour is fine

    def exec_module(self, module):
        """Execute *module* from its rewritten code, using the pyc cache.

        The requested module looks like a test file, so rewrite it. This is
        the most magical part of the process: load the source, rewrite the
        asserts, and load the rewritten source. We also cache the rewritten
        module code in a special pyc. We must be aware of the possibility of
        concurrent pytest processes rewriting and loading pycs. To avoid
        tricky race conditions, we maintain the following invariant: The
        cached pyc is always a complete, valid pyc. Operations on it must be
        atomic. POSIX's atomic rename comes in handy.
        """
        fn = Path(module.__spec__.origin)
        state = self.config._assertstate

        self._rewritten_names.add(module.__name__)

        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write and not try_makedirs(cache_dir):
            write = False
            state.trace("read only directory: {}".format(cache_dir))

        # Strip the 3-character ".py" suffix before adding the pytest tail.
        pyc = cache_dir / (fn.name[:-3] + PYC_TAIL)
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is not None:
            state.trace("found cached rewritten pyc for {}".format(fn))
        else:
            state.trace("rewriting {!r}".format(fn))
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name, state):
        """This is a fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.

        Returns True when the module can be skipped without consulting the
        path finder, False when the regular checks must run.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                # ['c:', 'projects', 'my_project', 'path.py']
                last_part = str(initial_path).split(os.path.sep)[-1]
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(last_part)[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        if name.rpartition(".")[2] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(name.replace(".", os.path.sep) + ".py")

        for pat in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for
            # example) we can't bail out based on the name alone because we
            # need to match against the full path
            if os.path.dirname(pat) or fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace("early skip of rewriting module: {}".format(name))
        return True

    def _should_rewrite(self, name, fn, state):
        """Decide whether the module *name* located at *fn* gets rewritten."""
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace("rewriting conftest file: {!r}".format(fn))
            return True

        if self.session is not None and self.session.isinitpath(fn):
            state.trace("matched test file (was specified on cmdline): {!r}".format(fn))
            return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        fn_path = PurePath(fn)
        for pat in self.fnpats:
            if fnmatch_ex(pat, fn_path):
                state.trace("matched test file {!r}".format(fn))
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state):
        """Tell whether *name* equals or is nested in a name given to mark_rewrite().

        Answers are memoized per module name until mark_rewrite() clears
        the cache.
        """
        cached = self._marked_for_rewrite_cache.get(name)
        if cached is not None:
            return cached

        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace("matched marked file {!r} (from {!r})".format(name, marked))
                self._marked_for_rewrite_cache[name] = True
                return True

        self._marked_for_rewrite_cache[name] = False
        return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        imported = set(names).intersection(sys.modules)
        already_imported = imported.difference(self._rewritten_names)
        for name in already_imported:
            mod = sys.modules[name]
            if AssertionRewriter.is_rewrite_disabled(mod.__doc__ or ""):
                continue
            if isinstance(mod.__loader__, type(self)):
                continue
            self._warn_already_imported(name)
        self._must_rewrite.update(names)
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name):
        """Issue a PytestAssertRewriteWarning for the already imported *name*."""
        from _pytest.warning_types import PytestAssertRewriteWarning
        from _pytest.warnings import _issue_warning_captured

        warning = PytestAssertRewriteWarning(
            "Module already imported so cannot be rewritten: %s" % name
        )
        _issue_warning_captured(warning, self.config.hook, stacklevel=5)

    def get_data(self, pathname):
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as f:
            return f.read()
258def _write_pyc_fp(fp, source_stat, co):
259 # Technically, we don't have to have the same pyc format as
260 # (C)Python, since these "pycs" should never be seen by builtin
261 # import. However, there's little reason deviate.
262 fp.write(importlib.util.MAGIC_NUMBER)
263 # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
264 mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
265 size = source_stat.st_size & 0xFFFFFFFF
266 # "<LL" stands for 2 unsigned longs, little-ending
267 fp.write(struct.pack("<LL", mtime, size))
268 fp.write(marshal.dumps(co))
271if sys.platform == "win32":
272 from atomicwrites import atomic_write
274 def _write_pyc(state, co, source_stat, pyc):
275 try:
276 with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
277 _write_pyc_fp(fp, source_stat, co)
278 except EnvironmentError as e:
279 state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
280 # we ignore any failure to write the cache file
281 # there are many reasons, permission-denied, pycache dir being a
282 # file etc.
283 return False
284 return True
287else:
289 def _write_pyc(state, co, source_stat, pyc):
290 proc_pyc = "{}.{}".format(pyc, os.getpid())
291 try:
292 fp = open(proc_pyc, "wb")
293 except EnvironmentError as e:
294 state.trace(
295 "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
296 )
297 return False
299 try:
300 _write_pyc_fp(fp, source_stat, co)
301 os.rename(proc_pyc, fspath(pyc))
302 except BaseException as e:
303 state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
304 # we ignore any failure to write the cache file
305 # there are many reasons, permission-denied, pycache dir being a
306 # file etc.
307 return False
308 finally:
309 fp.close()
310 return True
def _rewrite_test(fn, config):
    """Read and rewrite *fn*.

    Returns a ``(stat_result, code_object)`` pair: the ``os.stat()`` result
    of the source file, taken before reading it, and the compiled module
    code with its asserts rewritten.
    """
    source_path = fspath(fn)
    source_stat = os.stat(source_path)
    with open(source_path, "rb") as source_file:
        source = source_file.read()
    tree = ast.parse(source, filename=source_path)
    rewrite_asserts(tree, source, source_path, config)
    return source_stat, compile(tree, source_path, "exec", dont_inherit=True)
def _read_pyc(source, pyc, trace=lambda x: None):
    """Possibly read a pytest pyc containing rewritten code.

    Return rewritten code if successful or None if not.  The pyc is
    rejected when it cannot be opened, when its 12-byte header does not
    match the magic number and the current mtime/size of *source*, or when
    its payload does not unmarshal to a code object.
    """
    try:
        fp = open(fspath(pyc), "rb")
    except IOError:
        return None
    with fp:
        try:
            stat_result = os.stat(fspath(source))
            mtime = int(stat_result.st_mtime)
            size = stat_result.st_size
            data = fp.read(12)
        except EnvironmentError as e:
            trace("_read_pyc({}): EnvironmentError {}".format(source, e))
            return None
        # Check for invalid or out of date pyc file.
        expected = (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
        header_ok = (
            len(data) == 12
            and data[:4] == importlib.util.MAGIC_NUMBER
            and struct.unpack("<LL", data[4:]) == expected
        )
        if not header_ok:
            trace("_read_pyc(%s): invalid or out of date pyc" % source)
            return None
        try:
            co = marshal.load(fp)
        except Exception as e:
            trace("_read_pyc({}): marshal.load error {}".format(source, e))
            return None
        if isinstance(co, types.CodeType):
            return co
        trace("_read_pyc(%s): not a code object" % source)
        return None
def rewrite_asserts(mod, source, module_path=None, config=None):
    """Rewrite the assert statements in mod.

    *mod* is modified in place by an AssertionRewriter built from
    *module_path*, *config* and *source*.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
def _saferepr(obj):
    """Get a safe repr of an object for assertion error messages.

    The assertion formatting (util.format_explanation()) requires
    newlines to be escaped since they are a special character for it.
    Normally assertion.util.format_explanation() does this but for a
    custom repr it is possible to contain one of the special escape
    sequences, especially a newline followed by ``{`` or ``}``, which
    are likely to be present in JSON reprs.
    """
    text = saferepr(obj)
    return text.replace("\n", "\\n")
381def _format_assertmsg(obj):
382 """Format the custom assertion message given.
384 For strings this simply replaces newlines with '\n~' so that
385 util.format_explanation() will preserve them instead of escaping
386 newlines. For other objects saferepr() is used first.
388 """
389 # reprlib appears to have a bug which means that if a string
390 # contains a newline it gets escaped, however if an object has a
391 # .__repr__() which contains newlines it does not get escaped.
392 # However in either case we want to preserve the newline.
393 replaces = [("\n", "\n~"), ("%", "%%")]
394 if not isinstance(obj, str):
395 obj = saferepr(obj)
396 replaces.append(("\\n", "\n~"))
398 for r1, r2 in replaces:
399 obj = obj.replace(r1, r2)
401 return obj
404def _should_repr_global_name(obj):
405 if callable(obj):
406 return False
408 try:
409 return not hasattr(obj, "__name__")
410 except Exception:
411 return True
414def _format_boolop(explanations, is_or):
415 explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
416 if isinstance(explanation, str):
417 return explanation.replace("%", "%%")
418 else:
419 return explanation.replace(b"%", b"%%")
def _call_reprcompare(ops, results, expls, each_obj):
    # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
    """Return the explanation for the comparison that decided the result.

    Walks the chained comparison up to the first falsy result (a result
    whose truth test raises counts as falsy), or to the last comparison if
    none is falsy.  If ``util._reprcompare`` is set and returns something
    other than None for that comparison, that value is returned; otherwise
    the plain explanation string is.
    """
    for i, (_op, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            failed = True
        if failed:
            break
    reprcompare = util._reprcompare
    if reprcompare is not None:
        custom = reprcompare(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl
def _call_assertion_pass(lineno, orig, expl):
    # type: (int, str, str) -> None
    """Forward a passing assertion to ``util._assertion_pass`` if it is set."""
    hook = util._assertion_pass
    if hook is not None:
        hook(lineno, orig, expl)
def _check_if_assertion_pass_impl():
    # type: () -> bool
    """Checks if any plugins implement the pytest_assertion_pass hook
    in order not to generate explanation unnecessarily (might be expensive)"""
    return bool(util._assertion_pass)
# %-format patterns used to explain unary operators; "%s" receives the
# explanation of the operand.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source symbol for every binary and comparison operator node.
BINOP_MAP = {
    # bitwise and shift operators
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    # arithmetic operators
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    # comparisons
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}
def set_location(node, lineno, col_offset):
    """Set node location information recursively.

    Every node in the tree rooted at *node* that supports the ``lineno``
    and/or ``col_offset`` attribute gets the given value.  Returns *node*.
    """
    for current in ast.walk(node):
        supported = current._attributes
        if "lineno" in supported:
            current.lineno = lineno
        if "col_offset" in supported:
            current.col_offset = col_offset
    return node
495def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
496 """Returns a mapping from {lineno: "assertion test expression"}"""
497 ret = {} # type: Dict[int, str]
499 depth = 0
500 lines = [] # type: List[str]
501 assert_lineno = None # type: Optional[int]
502 seen_lines = set() # type: Set[int]
504 def _write_and_reset() -> None:
505 nonlocal depth, lines, assert_lineno, seen_lines
506 assert assert_lineno is not None
507 ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
508 depth = 0
509 lines = []
510 assert_lineno = None
511 seen_lines = set()
513 tokens = tokenize.tokenize(io.BytesIO(src).readline)
514 for tp, source, (lineno, offset), _, line in tokens:
515 if tp == tokenize.NAME and source == "assert":
516 assert_lineno = lineno
517 elif assert_lineno is not None:
518 # keep track of depth for the assert-message `,` lookup
519 if tp == tokenize.OP and source in "([{":
520 depth += 1
521 elif tp == tokenize.OP and source in ")]}":
522 depth -= 1
524 if not lines:
525 lines.append(line[offset:])
526 seen_lines.add(lineno)
527 # a non-nested comma separates the expression from the message
528 elif depth == 0 and tp == tokenize.OP and source == ",":
529 # one line assert with message
530 if lineno in seen_lines and len(lines) == 1:
531 offset_in_trimmed = offset + len(lines[-1]) - len(line)
532 lines[-1] = lines[-1][:offset_in_trimmed]
533 # multi-line assert with message
534 elif lineno in seen_lines:
535 lines[-1] = lines[-1][:offset]
536 # multi line assert with escapd newline before message
537 else:
538 lines.append(line[:offset])
539 _write_and_reset()
540 elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
541 _write_and_reset()
542 elif lines and lineno not in seen_lines:
543 lines.append(line)
544 seen_lines.add(lineno)
546 return ret
class AssertionRewriter(ast.NodeVisitor):
    """Assertion rewriting implementation.

    The main entrypoint is to call .run() with an ast.Module instance,
    this will then find all the assert statements and rewrite them to
    provide intermediate values and a detailed assertion error.  See
    http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
    for an overview of how this works.

    The entry point here is .run() which will iterate over all the
    statements in an ast.Module and for each ast.Assert statement it
    finds call .visit() with it.  Then .visit_Assert() takes over and
    is responsible for creating new ast statements to replace the
    original assert statement: it rewrites the test of an assertion
    to provide intermediate values and replace it with an if statement
    which raises an assertion error with a detailed explanation in
    case the expression is false and calls pytest_assertion_pass hook
    if expression is true.

    For this .visit_Assert() uses the visitor pattern to visit all the
    AST nodes of the ast.Assert.test field, each visit call returning
    an AST node and the corresponding explanation string.  During this
    state is kept in several instance attributes:

    :statements: All the AST statements which will replace the assert
       statement.

    :variables: This is populated by .variable() with each variable
       used by the statements so that they can all be set to None at
       the end of the statements.

    :variable_counter: Counter to create new unique variables needed
       by statements.  Variables are created using .variable() and
       have the form of "@py_assert0".

    :expl_stmts: The AST statements which will be executed to get
       data from the assertion.  This is the code which will construct
       the detailed assertion message that is used in the AssertionError
       or for the pytest_assertion_pass hook.

    :explanation_specifiers: A dict filled by .explanation_param()
       with %-formatting placeholders and their corresponding
       expressions to use in the building of an assertion message.
       This is used by .pop_format_context() to build a message.

    :stack: A stack of the explanation_specifiers dicts maintained by
       .push_format_context() and .pop_format_context() which allows
       to build another %-formatted string while already building one.

    This state is reset on every new assert statement visited and used
    by the other visitors.

    """

    def __init__(self, module_path, config, source):
        super().__init__()
        self.module_path = module_path
        self.config = config
        if config is not None:
            self.enable_assertion_pass_hook = config.getini(
                "enable_assertion_pass_hook"
            )
        else:
            self.enable_assertion_pass_hook = False
        self.source = source
        # Lazily computed result of _assert_expr_to_lineno(); kept on the
        # instance so the cache lives and dies with this rewriter.
        self._assertion_exprs = None  # type: Optional[Dict[int, str]]

    def _assert_expr_to_lineno(self):
        """Return the {lineno: assert expression source} mapping for self.source.

        The mapping is computed on first use and reused afterwards.
        """
        if self._assertion_exprs is None:
            self._assertion_exprs = _get_assertion_exprs(self.source)
        return self._assertion_exprs

    def run(self, mod: ast.Module) -> None:
        """Find all assert statements in *mod* and rewrite them."""
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [
            ast.alias("builtins", "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Only docstring/__future__ statements: reuse the last line number.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]  # type: List[ast.AST]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []  # type: List
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)

    @staticmethod
    def is_rewrite_disabled(docstring):
        """Tell whether *docstring* contains the PYTEST_DONT_REWRITE marker."""
        return "PYTEST_DONT_REWRITE" in docstring

    def variable(self):
        """Get a new variable."""
        # Use a character invalid in python identifiers to avoid clashing.
        name = "@py_assert" + str(next(self.variable_counter))
        self.variables.append(name)
        return name

    def assign(self, expr):
        """Give *expr* a name."""
        name = self.variable()
        self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
        return ast.Name(name, ast.Load())

    def display(self, expr):
        """Call saferepr on the expression."""
        return self.helper("_saferepr", expr)

    def helper(self, name, *args):
        """Call a helper in this module."""
        py_name = ast.Name("@pytest_ar", ast.Load())
        attr = ast.Attribute(py_name, name, ast.Load())
        return ast.Call(attr, list(args), [])

    def builtin(self, name):
        """Return the builtin called *name*."""
        builtin_name = ast.Name("@py_builtins", ast.Load())
        return ast.Attribute(builtin_name, name, ast.Load())

    def explanation_param(self, expr):
        """Return a new named %-formatting placeholder for expr.

        This creates a %-formatting placeholder for expr in the
        current formatting context, e.g. ``%(py0)s``.  The placeholder
        and expr are placed in the current format context so that it
        can be used on the next call to .pop_format_context().

        """
        specifier = "py" + str(next(self.variable_counter))
        self.explanation_specifiers[specifier] = expr
        return "%(" + specifier + ")s"

    def push_format_context(self):
        """Create a new formatting context.

        The format context is used for when an explanation wants to
        have a variable value formatted in the assertion message.  In
        this case the value required can be added using
        .explanation_param().  Finally .pop_format_context() is used
        to format a string of %-formatted values as added by
        .explanation_param().

        """
        self.explanation_specifiers = {}  # type: Dict[str, ast.expr]
        self.stack.append(self.explanation_specifiers)

    def pop_format_context(self, expl_expr):
        """Format the %-formatted string with current format context.

        The expl_expr should be an ast.Str instance constructed from
        the %-placeholders created by .explanation_param().  This will
        add the required code to format said string to .expl_stmts and
        return the ast.Name instance of the formatted string.

        """
        current = self.stack.pop()
        if self.stack:
            self.explanation_specifiers = self.stack[-1]
        keys = [ast.Str(key) for key in current.keys()]
        format_dict = ast.Dict(keys, list(current.values()))
        form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
        name = "@py_format" + str(next(self.variable_counter))
        if self.enable_assertion_pass_hook:
            self.format_variables.append(name)
        self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
        return ast.Name(name, ast.Load())

    def generic_visit(self, node):
        """Handle expressions we don't have custom code for."""
        assert isinstance(node, ast.expr)
        res = self.assign(node)
        return res, self.explanation_param(self.display(res))

    def visit_Assert(self, assert_):
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=fspath(self.module_path),
                lineno=assert_.lineno,
            )

        # Reset the per-assert state described in the class docstring.
        self.statements = []  # type: List[ast.stmt]
        self.variables = []  # type: List[str]
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            self.format_variables = []  # type: List[str]

        self.stack = []  # type: List[Dict[str, ast.expr]]
        self.expl_stmts = []  # type: List[ast.stmt]
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            msg = self.pop_format_context(ast.Str(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Str("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            orig = self._assert_expr_to_lineno()[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Num(assert_.lineno),
                    ast.Str(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.NameConstant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Str("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements

    def visit_Name(self, name):
        """Return *name* unchanged plus a placeholder for its explanation."""
        # Display the repr of the name if it's a local variable or
        # _should_repr_global_name() thinks it's acceptable.
        locs = ast.Call(self.builtin("locals"), [], [])
        inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
        dorepr = self.helper("_should_repr_global_name", name)
        test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
        expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
        return name, self.explanation_param(expr)

    def visit_BoolOp(self, boolop):
        """Rewrite an ``and``/``or`` chain into nested ``if`` statements.

        Each operand is assigned to a shared result variable; the next
        operand is only evaluated inside an ``if`` on the previous result.
        """
        res_var = self.variable()
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        body = save = self.statements
        fail_save = self.expl_stmts
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                fail_inner = []  # type: List[ast.stmt]
                # cond is set in a prior loop iteration below
                self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
                self.expl_stmts = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast.Call(app, [expl_format], [])
            self.expl_stmts.append(ast.Expr(call))
            if i < levels:
                cond = res  # type: ast.expr
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []  # type: List[ast.stmt]
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.expl_stmts = fail_save
        expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

    def visit_UnaryOp(self, unary):
        """Rewrite a unary operation; the explanation comes from UNARY_MAP."""
        pattern = UNARY_MAP[unary.op.__class__]
        operand_res, operand_expl = self.visit(unary.operand)
        res = self.assign(ast.UnaryOp(unary.op, operand_res))
        return res, pattern % (operand_expl,)

    def visit_BinOp(self, binop):
        """Rewrite a binary operation; the symbol comes from BINOP_MAP."""
        symbol = BINOP_MAP[binop.op.__class__]
        left_expr, left_expl = self.visit(binop.left)
        right_expr, right_expl = self.visit(binop.right)
        explanation = "({} {} {})".format(left_expl, symbol, right_expl)
        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
        return res, explanation

    def visit_Call(self, call):
        """
        visit `ast.Call` nodes
        """
        new_func, func_expl = self.visit(call.func)
        arg_expls = []
        new_args = []
        new_kwargs = []
        for arg in call.args:
            res, expl = self.visit(arg)
            arg_expls.append(expl)
            new_args.append(res)
        for keyword in call.keywords:
            res, expl = self.visit(keyword.value)
            new_kwargs.append(ast.keyword(keyword.arg, res))
            if keyword.arg:
                arg_expls.append(keyword.arg + "=" + expl)
            else:  # **args have `arg` keywords with an .arg of None
                arg_expls.append("**" + expl)

        expl = "{}({})".format(func_expl, ", ".join(arg_expls))
        new_call = ast.Call(new_func, new_args, new_kwargs)
        res = self.assign(new_call)
        res_expl = self.explanation_param(self.display(res))
        outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
        return res, outer_expl

    def visit_Starred(self, starred):
        """Rewrite the value of a starred argument, keeping it starred."""
        # From Python 3.5, a Starred node can appear in a function call
        res, expl = self.visit(starred.value)
        new_starred = ast.Starred(res, starred.ctx)
        return new_starred, "*" + expl

    def visit_Attribute(self, attr):
        """Rewrite an attribute load; other contexts use generic_visit()."""
        if not isinstance(attr.ctx, ast.Load):
            return self.generic_visit(attr)
        value, value_expl = self.visit(attr.value)
        res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
        res_expl = self.explanation_param(self.display(res))
        pat = "%s\n{%s = %s.%s\n}"
        expl = pat % (res_expl, res_expl, value_expl, attr.attr)
        return res, expl

    def visit_Compare(self, comp: ast.Compare):
        """Rewrite a (possibly chained) comparison.

        Every individual comparison is assigned to its own result variable;
        the explanation is produced at run time by ``_call_reprcompare``.
        """
        self.push_format_context()
        left_res, left_expl = self.visit(comp.left)
        if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
            left_expl = "({})".format(left_expl)
        res_variables = [self.variable() for _ in comp.ops]
        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
        expls = []
        syms = []
        results = [left_res]
        for i, (op, next_operand) in enumerate(zip(comp.ops, comp.comparators)):
            next_res, next_expl = self.visit(next_operand)
            if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
                next_expl = "({})".format(next_expl)
            results.append(next_res)
            sym = BINOP_MAP[op.__class__]
            syms.append(ast.Str(sym))
            expl = "{} {} {}".format(left_expl, sym, next_expl)
            expls.append(ast.Str(expl))
            res_expr = ast.Compare(left_res, [op], [next_res])
            self.statements.append(ast.Assign([store_names[i]], res_expr))
            left_res, left_expl = next_res, next_expl
        # Use pytest.assertion.util._reprcompare if that's available.
        expl_call = self.helper(
            "_call_reprcompare",
            ast.Tuple(syms, ast.Load()),
            ast.Tuple(load_names, ast.Load()),
            ast.Tuple(expls, ast.Load()),
            ast.Tuple(results, ast.Load()),
        )
        if len(comp.ops) > 1:
            res = ast.BoolOp(ast.And(), load_names)  # type: ast.expr
        else:
            res = load_names[0]
        return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir) -> bool:
    """Attempts to create the given directory and any missing parent
    directories; returns True if successful or if it already exists,
    False if it cannot be created."""
    try:
        os.makedirs(fspath(cache_dir), exist_ok=True)
    except (
        # One of the path components was not a directory:
        #  - we're in a zip file
        #  - it is a file
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        # No permission to create the directory.
        PermissionError,
    ):
        return False
    except OSError as e:
        # as of now, EROFS doesn't have an equivalent OSError-subclass
        if e.errno != errno.EROFS:
            raise
        return False
    return True
def get_cache_dir(file_path: Path) -> Path:
    """Returns the cache directory to write .pyc files for the given .py file path"""
    if sys.version_info < (3, 8) or not sys.pycache_prefix:
        # classic pycache directory
        return file_path.parent / "__pycache__"
    # given:
    #   prefix = '/tmp/pycs'
    #   path = '/home/user/proj/test_app.py'
    # we want:
    #   '/tmp/pycs/home/user/proj'
    relative_parent = Path(*file_path.parts[1:-1])
    return Path(sys.pycache_prefix) / relative_parent
1056 # classic pycache directory
1057 return file_path.parent / "__pycache__"