Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import ast 

2import inspect 

3import linecache 

4import sys 

5import textwrap 

6import tokenize 

7import warnings 

8from bisect import bisect_right 

9from types import CodeType 

10from types import FrameType 

11from typing import Iterator 

12from typing import List 

13from typing import Optional 

14from typing import Sequence 

15from typing import Tuple 

16from typing import Union 

17 

18import py 

19 

20from _pytest.compat import overload 

21from _pytest.compat import TYPE_CHECKING 

22 

23if TYPE_CHECKING: 

24 from typing_extensions import Literal 

25 

26 

class Source:
    """An object holding a source code fragment, possibly deindenting it.

    The fragment is stored in ``self.lines`` as a list of strings without
    trailing newlines. Instances are treated as immutable: the transforming
    methods (``strip``, ``indent``, ``putaround``, ``deindent`` and slicing)
    return new ``Source`` objects instead of modifying this one.
    """

    # Counter used by ``compile()`` to make every invented filename unique;
    # it is incremented on the class, so it is shared by all instances.
    _compilecounter = 0

    def __init__(self, *parts, **kwargs) -> None:
        """Concatenate the lines of all *parts* into a new Source.

        Each part may be a ``Source``, a list/tuple of lines (trailing
        newlines stripped), a string (split on ``"\\n"``) or any other object,
        whose source is fetched with ``getsource()``. Falsy parts add nothing.
        Every part is deindented separately unless ``deindent=False`` is passed.
        """
        self.lines = lines = []  # type: List[str]
        de = kwargs.get("deindent", True)
        for part in parts:
            if not part:
                partlines = []  # type: List[str]
            elif isinstance(part, Source):
                partlines = part.lines
            elif isinstance(part, (tuple, list)):
                partlines = [x.rstrip("\n") for x in part]
            elif isinstance(part, str):
                partlines = part.split("\n")
            else:
                partlines = getsource(part, deindent=de).lines
            if de:
                partlines = deindent(partlines)
            lines.extend(partlines)

    def __eq__(self, other):
        """Compare lines with any object that has ``.lines``, text with a str."""
        try:
            return self.lines == other.lines
        except AttributeError:
            if isinstance(other, str):
                return str(self) == other
            return False

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str:
        raise NotImplementedError()

    @overload  # noqa: F811
    def __getitem__(self, key: slice) -> "Source":  # noqa: F811
        raise NotImplementedError()

    def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:  # noqa: F811
        """Return a single line for an int key, or a new Source for a slice.

        :raises IndexError: if the slice has a step other than None or 1.
        """
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            newsource = Source()
            newsource.lines = self.lines[key.start : key.stop]
            return newsource

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> "Source":
        """Return a new Source with leading and trailing blank lines removed."""
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def putaround(
        self, before: str = "", after: str = "", indent: str = " " * 4
    ) -> "Source":
        """Return a copy of this source indented by *indent* and wrapped
        between the lines of *before* and *after*.
        """
        beforesource = Source(before)
        aftersource = Source(after)
        newsource = Source()
        lines = [(indent + line) for line in self.lines]
        newsource.lines = beforesource.lines + lines + aftersource.lines
        return newsource

    def indent(self, indent: str = " " * 4) -> "Source":
        """Return a copy of this source with *indent* prepended to every line."""
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> "Source":
        """Return the statement containing the 0-based *lineno* as a Source."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
        """Return ``(start, end)`` spanning the minimal statement region that
        contains the 0-based *lineno*; *end* is exclusive.

        :raises IndexError: if *lineno* is outside the source.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The first item is the parsed AST; bind it to ``_`` so the ``ast``
        # module is not shadowed.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> "Source":
        """Return a new Source with common leading whitespace removed."""
        newsource = Source()
        newsource.lines[:] = deindent(self.lines)
        return newsource

    def isparseable(self, deindent: bool = True) -> bool:
        """Return True if the source parses as Python.

        The source is deindented first unless *deindent* is False. This
        method never raises; any parse failure yields False.
        """
        if deindent:
            source = str(self.deindent())
        else:
            source = str(self)
        try:
            # ``ast.parse`` only parses, without generating code. It replaces
            # ``parser.suite``; the ``parser`` module was removed in Python 3.10.
            ast.parse(source + "\n")
        except Exception:
            # Best effort: SyntaxError, ValueError (null bytes), RecursionError
            # (deep nesting) and similar all mean "not parseable".
            return False
        return True

    def __str__(self) -> str:
        return "\n".join(self.lines)

    @overload
    def compile(
        self,
        filename: Optional[str] = ...,
        mode: str = ...,
        flag: "Literal[0]" = ...,
        dont_inherit: int = ...,
        _genframe: Optional[FrameType] = ...,
    ) -> CodeType:
        raise NotImplementedError()

    @overload  # noqa: F811
    def compile(  # noqa: F811
        self,
        filename: Optional[str] = ...,
        mode: str = ...,
        flag: int = ...,
        dont_inherit: int = ...,
        _genframe: Optional[FrameType] = ...,
    ) -> Union[CodeType, ast.AST]:
        raise NotImplementedError()

    def compile(  # noqa: F811
        self,
        filename: Optional[str] = None,
        mode: str = "exec",
        flag: int = 0,
        dont_inherit: int = 0,
        _genframe: Optional[FrameType] = None,
    ) -> Union[CodeType, ast.AST]:
        """Compile the source.

        The result is a code object, or an AST when *flag* includes
        ``ast.PyCF_ONLY_AST``. *mode*, *flag* and *dont_inherit* are passed
        to the builtin ``compile()``.

        If *filename* is empty or does not name an existing file, an
        artificial filename is invented. It embeds a unique counter and the
        file/line of *_genframe*, which defaults to the direct caller. When a
        code object is produced, the source lines are stored in ``linecache``
        under the filename used.

        :raises SyntaxError: a new SyntaxError whose message shows the source
            up to the offending line, a caret under the error column and the
            filename. ``lineno``, ``offset`` and ``text`` are copied from the
            original error.
        """
        if not filename or py.path.local(filename).check(file=0):
            if _genframe is None:
                _genframe = sys._getframe(1)  # the caller
            fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
            base = "<%d-codegen " % self._compilecounter
            self.__class__._compilecounter += 1
            if not filename:
                filename = base + "%s:%d>" % (fn, lineno)
            else:
                filename = base + "%r %s:%d>" % (filename, fn, lineno)
        source = "\n".join(self.lines) + "\n"
        try:
            co = compile(source, filename, mode, flag, dont_inherit)
        except SyntaxError as ex:
            # re-represent syntax errors from parsing python strings
            msglines = self.lines[: ex.lineno]
            if ex.offset:
                # SyntaxError.offset is 1-based, so the caret sits offset-1
                # columns in, directly under the offending character.
                msglines.append(" " * (ex.offset - 1) + "^")
            msglines.append("(code was compiled probably from here: %s)" % filename)
            newex = SyntaxError("\n".join(msglines))
            newex.offset = ex.offset
            newex.lineno = ex.lineno
            newex.text = ex.text
            raise newex
        else:
            if flag & ast.PyCF_ONLY_AST:
                assert isinstance(co, ast.AST)
                return co
            assert isinstance(co, CodeType)
            lines = [(x + "\n") for x in self.lines]
            # Type ignored because linecache.cache is private.
            linecache.cache[filename] = (1, None, lines, filename)  # type: ignore
            return co

233 

234 

235# 

236# public API shortcut functions 

237# 

238 

239 

@overload
def compile_(
    source: Union[str, bytes, ast.mod, ast.AST],
    filename: Optional[str] = ...,
    mode: str = ...,
    flags: "Literal[0]" = ...,
    dont_inherit: int = ...,
) -> CodeType:
    raise NotImplementedError()


@overload  # noqa: F811
def compile_(  # noqa: F811
    source: Union[str, bytes, ast.mod, ast.AST],
    filename: Optional[str] = ...,
    mode: str = ...,
    flags: int = ...,
    dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
    raise NotImplementedError()


def compile_(  # noqa: F811
    source: Union[str, bytes, ast.mod, ast.AST],
    filename: Optional[str] = None,
    mode: str = "exec",
    flags: int = 0,
    dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
    """Compile *source* to a code object (or an AST when *flags* includes
    ``ast.PyCF_ONLY_AST``).

    An AST *source* is compiled directly with the builtin ``compile()``; an
    explicit *filename* is required in that case. Any other *source* is
    wrapped in ``Source`` and compiled via ``Source.compile()``. That method
    invents a filename pointing at the caller when none is given and
    registers the lines in ``linecache``. *mode*, *flags* and *dont_inherit*
    are forwarded on both paths.
    """
    if isinstance(source, ast.AST):
        # XXX should Source support having AST?
        assert filename is not None
        co = compile(source, filename, mode, flags, dont_inherit)
        assert isinstance(co, (CodeType, ast.AST))
        return co
    _genframe = sys._getframe(1)  # the caller
    s = Source(source)
    # Forward dont_inherit too; it was previously dropped on this path.
    return s.compile(
        filename, mode, flags, dont_inherit=dont_inherit, _genframe=_genframe
    )

283 

284 

def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
    """Return the ``(path, lineno)`` location where *obj* is defined.

    The line number is 0-based. Objects accepted by ``Code`` use its path
    and first line number. Otherwise the file is looked up with ``inspect``.
    If no file can be found, ``("", -1)`` is returned. If a file is found but
    the line cannot be located, the line number is -1.
    """
    from .code import Code

    try:
        code = Code(obj)
    except TypeError:
        pass
    else:
        path = code.path
        first = code.firstlineno
        assert isinstance(first, int)
        return path, first

    # Code() rejected the object: fall back to the inspect module.
    try:
        filename = inspect.getsourcefile(obj) or inspect.getfile(obj)
    except TypeError:
        return "", -1

    path = py.path.local(filename) if filename else None
    first = -1
    if path:
        try:
            _, first = findsource(obj)
        except IOError:
            pass
    assert isinstance(first, int)
    return path, first

313 

314 

315# 

316# helper functions 

317# 

318 

319 

def findsource(obj) -> Tuple[Optional[Source], int]:
    """Return ``(source, lineno)`` for *obj* using ``inspect.findsource``.

    *source* holds the lines returned by ``inspect.findsource``, with
    trailing whitespace stripped from each. *lineno* is the 0-based index of
    the line where *obj* starts. Any failure yields ``(None, -1)``.
    """
    try:
        all_lines, first = inspect.findsource(obj)
    except Exception:
        return None, -1
    result = Source()
    result.lines = list(map(str.rstrip, all_lines))
    return result, first

328 

329 

def getsource(obj, **kwargs) -> Source:
    """Return the source code of *obj* as a ``Source``.

    *obj* is first passed through ``getrawcode``. Keyword arguments (such as
    ``deindent``) are forwarded to the ``Source`` constructor. If
    ``inspect.getsource`` raises IndentationError, a placeholder string
    literal is used as the source instead.
    """
    from .code import getrawcode

    raw = getrawcode(obj)
    try:
        text = inspect.getsource(raw)
    except IndentationError:
        text = '"Buggy python version consider upgrading, cannot get source"'
    assert isinstance(text, str)
    return Source(text, **kwargs)

340 

341 

def deindent(lines: Sequence[str]) -> List[str]:
    """Remove the whitespace prefix common to all *lines*.

    The lines are joined with newlines, dedented with ``textwrap.dedent``
    and split again with ``str.splitlines``. As a result, a trailing empty
    line is dropped and whitespace-only lines become empty.
    """
    joined = "\n".join(lines)
    dedented = textwrap.dedent(joined)
    return dedented.splitlines()

344 

345 

def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
    """Return the ``(start, end)`` line range of the statement containing *lineno*.

    All line numbers are 0-based. *start* is the largest recorded statement
    start that is <= *lineno*. *end* is the next recorded start, or None if
    there is none. Recorded starts are statements, except handlers, decorator
    lines, and the line before the first statement of a ``finally``/``else``
    body. NOTE(review): *lineno* is expected to lie at or after the first
    recorded start; earlier values wrap around to the last statement.
    """
    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    values = []  # type: List[int]
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # Since Python 3.8 the lineno of a decorated class/function points
            # at the "def"/"class" line, not the first decorator. Record the
            # decorator lines too; otherwise a decorator line would be
            # attributed to the preceding statement.
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                for d in x.decorator_list:
                    values.append(d.lineno - 1)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)  # type: Optional[List[ast.stmt]]
                if val:
                    # Treat the finally/orelse part as its own statement; its
                    # header is assumed to be the line before its first statement.
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end

368 

369 

def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
    """Return ``(astnode, start, end)`` for the statement containing *lineno*.

    *lineno*, *start* and *end* are 0-based indexes into ``source.lines``,
    and *end* is exclusive. If *astnode* is None, ``str(source)`` is parsed
    here with warnings suppressed; this may raise SyntaxError. The
    *assertion* argument is accepted but not used.
    """
    if astnode is None:
        content = str(source)
        # See #4260:
        # don't produce duplicate warnings when compiling source to find ast
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # we need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # make sure we don't span differently indented code blocks
        # by using the BlockFinder helper which inspect.getsource() uses itself
        block_finder = inspect.BlockFinder()
        # if we start with an indented line, put blockfinder to "started" mode;
        # [:1] (instead of [0]) avoids IndexError on an empty line
        block_finder.started = source.lines[start][:1].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            # readline follows the io.IOBase.readline contract: "" at EOF
            for tok in tokenize.generate_tokens(lambda: next(it, "")):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # Best effort (e.g. tokenize.TokenError on an unterminated
            # construct): keep the end derived from the AST.
            pass

    # the end might still point to a comment or empty line, correct it
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end