Skip to content

Commit 03040db

Browse files
DinoVfacebook-github-bot
authored andcommitted
Inline comprehensions
Summary: This diff adds support for inlining list/dict/set comprehensions where it is considered safe - names introduced by inlined comprehension will not conflict with local names used in comprehensions or free/implicitly global names used in sibling scopes. It also only inlines comprehensions in functions - inlining for top level statements comes with additional set of challenges and I'm not sure whether adding extra complexity to handle something that is executed once would be worth it. After inlining comprehension we generate the code to delete locals added by comprehension to avoid adding extra references that are not controlled by user. This works fine for non-exceptional case however in case of exception being raised by the comprehension lifetime of object referenced by comprehension iteration variable will be extended until execution leaves current frame. Another related issue is - if original iterable being used in comprehension yields no values, comprehension iteration variable will stay unbound and `DELETE_FAST` would fail. To handle this we can either: - relax requirements to `DELETE_FAST` so deleting unbound name would be no-op - have a dedicated opcode that would behave as relaxed `DELETE_FAST` - keep `DELETE_FAST` relaxed (similar to (1)) but change generated code for `del x` to be `LOAD_FAST; POP_TOP; DELETE_FAST` so name binding would still be checked by `LOAD_FAST` (suggested by DinoV) This diff currently uses option 1 as the simplest one but this could be changed. Reviewed By: vladima Differential Revision: D28940584 fbshipit-source-id: b5b7512
1 parent 9e8e3a5 commit 03040db

22 files changed

+3732
-2765
lines changed

Include/pythonrun.h

+6
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ PyAPI_FUNC(struct symtable *) _Py_SymtableStringObjectFlags(
135135
PyObject *filename,
136136
int start,
137137
PyCompilerFlags *flags);
138+
PyAPI_FUNC(struct symtable *) _Py_SymtableStringObjectFlagsOptFlags(
139+
const char *str,
140+
PyObject *filename,
141+
int start,
142+
PyCompilerFlags *flags,
143+
int inline_comprehensions);
138144
#endif
139145

140146
PyAPI_FUNC(void) PyErr_Print(void);

Include/symtable.h

+7
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct symtable {
3333
the symbol table */
3434
int recursion_depth; /* current recursion depth */
3535
int recursion_limit; /* recursion limit */
36+
int st_inline_comprehensions;
3637
};
3738

3839
typedef struct _symtable_entry {
@@ -64,6 +65,7 @@ typedef struct _symtable_entry {
6465
int ste_col_offset; /* offset of first line of block */
6566
int ste_opt_lineno; /* lineno of last exec or import * */
6667
int ste_opt_col_offset; /* offset of last exec or import * */
68+
unsigned int ste_inlined_comprehension; /* true is comprehension is inlined and symbols were already merged in parent scope */
6769
struct symtable *ste_table;
6870
} PySTEntryObject;
6971

@@ -81,6 +83,11 @@ PyAPI_FUNC(struct symtable *) PySymtable_BuildObject(
8183
mod_ty mod,
8284
PyObject *filename,
8385
PyFutureFeatures *future);
86+
PyAPI_FUNC(struct symtable *) _PySymtable_BuildObjectOptFlags(
87+
mod_ty mod,
88+
PyObject *filename,
89+
PyFutureFeatures *future,
90+
int inline_comprehensions);
8491
PyAPI_FUNC(PySTEntryObject *) PySymtable_Lookup(struct symtable *, void *);
8592

8693
PyAPI_FUNC(void) PySymtable_Free(struct symtable *);

Lib/compiler/pycodegen.py

+87-12
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class CodeGenerator(ASTVisitor):
199199
class_name = None # provide default for instance variable
200200
future_flags = 0
201201
flow_graph = pyassem.PyFlowGraph
202+
_SymbolVisitor = symbols.SymbolVisitor
202203

203204
def __init__(
204205
self,
@@ -912,7 +913,7 @@ def get_qual_prefix(self, gen):
912913
while not isinstance(parent, symbols.ModuleScope):
913914
# Only real functions use "<locals>", nested scopes like
914915
# comprehensions don't.
915-
if type(parent) in (symbols.FunctionScope, symbols.LambdaScope):
916+
if parent.is_function_scope:
916917
prefix = parent.name + ".<locals>." + prefix
917918
else:
918919
prefix = parent.name + "." + prefix
@@ -961,7 +962,9 @@ def compile_comprehension(
961962
if opcode:
962963
gen.emit(opcode, oparg)
963964

964-
gen.compile_comprehension_generator(node.generators, 0, elt, val, type(node))
965+
gen.compile_comprehension_generator(
966+
node.generators, 0, elt, val, type(node), True
967+
)
965968

966969
if not isinstance(node, ast.GeneratorExp):
967970
gen.emit("RETURN_VALUE")
@@ -1001,19 +1004,27 @@ def visitDictComp(self, node):
10011004
node, sys.intern("<dictcomp>"), node.key, node.value, "BUILD_MAP"
10021005
)
10031006

1004-
def compile_comprehension_generator(self, generators, gen_index, elt, val, type):
1007+
def compile_comprehension_generator(
1008+
self, generators, gen_index, elt, val, type, outermost_gen_is_param
1009+
):
10051010
if generators[gen_index].is_async:
1006-
self.compile_async_comprehension(generators, gen_index, elt, val, type)
1011+
self.compile_async_comprehension(
1012+
generators, gen_index, elt, val, type, outermost_gen_is_param
1013+
)
10071014
else:
1008-
self.compile_sync_comprehension(generators, gen_index, elt, val, type)
1015+
self.compile_sync_comprehension(
1016+
generators, gen_index, elt, val, type, outermost_gen_is_param
1017+
)
10091018

1010-
def compile_async_comprehension(self, generators, gen_index, elt, val, type):
1019+
def compile_async_comprehension(
1020+
self, generators, gen_index, elt, val, type, outermost_gen_is_param
1021+
):
10111022
start = self.newBlock("start")
10121023
except_ = self.newBlock("except")
10131024
if_cleanup = self.newBlock("if_cleanup")
10141025

10151026
gen = generators[gen_index]
1016-
if gen_index == 0:
1027+
if gen_index == 0 and outermost_gen_is_param:
10171028
self.loadName(".0")
10181029
else:
10191030
self.visit(gen.iter)
@@ -1033,7 +1044,9 @@ def compile_async_comprehension(self, generators, gen_index, elt, val, type):
10331044

10341045
gen_index += 1
10351046
if gen_index < len(generators):
1036-
self.compile_comprehension_generator(generators, gen_index, elt, val, type)
1047+
self.compile_comprehension_generator(
1048+
generators, gen_index, elt, val, type, False
1049+
)
10371050
elif type is ast.GeneratorExp:
10381051
self.visit(elt)
10391052
self.emit("YIELD_VALUE")
@@ -1056,14 +1069,16 @@ def compile_async_comprehension(self, generators, gen_index, elt, val, type):
10561069
self.nextBlock(except_)
10571070
self.emit("END_ASYNC_FOR")
10581071

1059-
def compile_sync_comprehension(self, generators, gen_index, elt, val, type):
1072+
def compile_sync_comprehension(
1073+
self, generators, gen_index, elt, val, type, outermost_gen_is_param
1074+
):
10601075
start = self.newBlock("start")
10611076
skip = self.newBlock("skip")
10621077
if_cleanup = self.newBlock("if_cleanup")
10631078
anchor = self.newBlock("anchor")
10641079

10651080
gen = generators[gen_index]
1066-
if gen_index == 0:
1081+
if gen_index == 0 and outermost_gen_is_param:
10671082
self.loadName(".0")
10681083
else:
10691084
self.visit(gen.iter)
@@ -1080,7 +1095,9 @@ def compile_sync_comprehension(self, generators, gen_index, elt, val, type):
10801095

10811096
gen_index += 1
10821097
if gen_index < len(generators):
1083-
self.compile_comprehension_generator(generators, gen_index, elt, val, type)
1098+
self.compile_comprehension_generator(
1099+
generators, gen_index, elt, val, type, False
1100+
)
10841101
else:
10851102
if type is ast.GeneratorExp:
10861103
self.visit(elt)
@@ -2319,7 +2336,7 @@ def make_code_gen(
23192336
):
23202337
if ast_optimizer_enabled:
23212338
tree = cls.optimize_tree(optimize, tree)
2322-
s = symbols.SymbolVisitor()
2339+
s = cls._SymbolVisitor()
23232340
walk(tree, s)
23242341

23252342
graph = cls.flow_graph(
@@ -2361,6 +2378,7 @@ def __init__(self, kind, block, exit):
23612378

23622379
class CinderCodeGenerator(CodeGenerator):
23632380
flow_graph = pyassem.PyFlowGraphCinder
2381+
_SymbolVisitor = symbols.CinderSymbolVisitor
23642382

23652383
def set_qual_name(self, qualname):
23662384
self._qual_name = qualname
@@ -2463,6 +2481,63 @@ def findFutures(self, node):
24632481
future_flags |= consts.CO_FUTURE_LAZY_IMPORTS
24642482
return future_flags
24652483

2484+
def compile_comprehension(self, node, name, elt, val, opcode, oparg=0):
2485+
self.update_lineno(node)
2486+
# fetch the scope that correspond to comprehension
2487+
scope = self.scopes[node]
2488+
if scope.inlined:
2489+
# for inlined comprehension process with current generator
2490+
gen = self
2491+
else:
2492+
gen = self.make_func_codegen(
2493+
node, self.conjure_arguments([ast.arg(".0", None)]), name, node.lineno
2494+
)
2495+
2496+
if opcode:
2497+
gen.emit(opcode, oparg)
2498+
2499+
gen.compile_comprehension_generator(
2500+
node.generators, 0, elt, val, type(node), not scope.inlined
2501+
)
2502+
2503+
if scope.inlined:
2504+
# collect list of defs that were introduced by comprehension
2505+
# note that we need to exclude:
2506+
# - .0 parameter since it is used
2507+
# - non-local names (typically named expressions), they are
2508+
# defined in enclosing scope and thus should not be deleted
2509+
to_delete = [
2510+
v
2511+
for v in scope.defs
2512+
if v != ".0" and v not in scope.nonlocals and v not in scope.cells
2513+
]
2514+
# sort names to have deterministic deletion order
2515+
to_delete.sort()
2516+
for v in to_delete:
2517+
self.delName(v)
2518+
return
2519+
2520+
if not isinstance(node, ast.GeneratorExp):
2521+
gen.emit("RETURN_VALUE")
2522+
2523+
gen.finishFunction()
2524+
2525+
self._makeClosure(gen, 0)
2526+
2527+
# precomputation of outmost iterable
2528+
self.visit(node.generators[0].iter)
2529+
if node.generators[0].is_async:
2530+
self.emit("GET_AITER")
2531+
else:
2532+
self.emit("GET_ITER")
2533+
self.emit("CALL_FUNCTION", 1)
2534+
2535+
if gen.scope.coroutine and type(node) is not ast.GeneratorExp:
2536+
self.emit("GET_AWAITABLE")
2537+
self.emit("LOAD_CONST", None)
2538+
self.emit("YIELD_FROM")
2539+
2540+
24662541
def get_default_generator():
24672542

24682543
if "cinder" in sys.version:

Lib/compiler/static/compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def _bind(
339339
if name not in self.modules:
340340
tree = self.add_module(name, filename, tree, optimize)
341341
# Analyze variable scopes
342-
s = SymbolVisitor()
342+
s = self.code_generator._SymbolVisitor()
343343
s.visit(tree)
344344

345345
# Analyze the types of objects within local scopes

Lib/compiler/static/types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@
172172
from ..optimizer import AstOptimizer
173173
from ..pyassem import Block
174174
from ..pycodegen import FOR_LOOP, CodeGenerator
175-
from ..symbols import SymbolVisitor
175+
from ..symbols import SymbolVisitor, CinderSymbolVisitor
176176
from ..symbols import Scope, ModuleScope
177177
from ..unparse import to_expr
178178
from ..visitor import ASTVisitor

Lib/compiler/strict/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def make_code_gen(
148148
) -> StrictCodeGenerator:
149149
if ast_optimizer_enabled:
150150
tree = cls.optimize_tree(optimize, tree)
151-
s = symbols.SymbolVisitor()
151+
s = cls._SymbolVisitor()
152152
walk(tree, s)
153153

154154
graph = cls.flow_graph(

0 commit comments

Comments
 (0)
  翻译: