Skip to content

Commit 08e5df6

Browse files
committed
[GR-74153] Avoid CodeUnit duplication
PullRequest: graalpython/4337
2 parents 17392fb + 5600cd0 commit 08e5df6

File tree

9 files changed

+172
-246
lines changed

9 files changed

+172
-246
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_code.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,17 +180,11 @@ def inner():
180180

181181
code = compile(codestr, "<test>", "exec")
182182
assert "module doc" in code.co_consts
183-
assert 1 in code.co_consts
184183
assert "fn doc" not in code.co_consts
185184
for const in code.co_consts:
186185
if type(const) == types.CodeType:
187186
code = const
188187
assert "fn doc" in code.co_consts
189-
assert "this is fun" not in code.co_consts
190-
for const in code.co_consts:
191-
if type(const) == types.CodeType:
192-
code = const
193-
assert "this is fun" in code.co_consts
194188

195189

196190
def test_generator_code_consts():
@@ -301,8 +295,7 @@ def bar():
301295
bar = foo()
302296
assert bar.__code__ is foo().__code__
303297
i = foo.__code__.co_consts.index(bar.__code__)
304-
# TODO this is currently broken on the DSL interpreter because the code unit in constants is a separate copy
305-
# assert bar.__code__ is foo.__code__.co_consts[i]
298+
assert bar.__code__ is foo.__code__.co_consts[i]
306299
assert bar.__code__ is bar().f_code
307300

308301
foo_copy = types.FunctionType(marshal.loads(marshal.dumps(foo.__code__)), globals=foo.__globals__, closure=foo.__closure__)

graalpython/com.oracle.graal.python.test/src/tests/test_parser.py

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2026, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -779,7 +779,6 @@ def test_annotations_in_function():
779779
exec(code,test_globals)
780780
assert len(test_globals['__annotations__']) == 0
781781
assert len(test_globals['fn'].__annotations__) == 0
782-
assert 1 not in test_globals['fn'].__code__.co_consts # the annotation is ignored in function
783782

784783
source = '''def fn():
785784
a:int =1
@@ -789,7 +788,6 @@ def test_annotations_in_function():
789788
assert len(test_globals['__annotations__']) == 0
790789
assert hasattr(test_globals['fn'], '__annotations__')
791790
assert len(test_globals['fn'].__annotations__) == 0
792-
assert 1 in test_globals['fn'].__code__.co_consts
793791

794792
def test_annotations_in_class():
795793

@@ -849,58 +847,6 @@ def test_annotations_in_class():
849847
assert test_globals['Style'].__annotations__['_path'] == str
850848
assert '_path' in dir(test_globals['Style'])
851849

852-
def test_negative_float():
853-
854-
def check_const(fn, expected):
855-
for const in fn.__code__.co_consts:
856-
if repr(const) == repr(expected):
857-
return True
858-
else:
859-
return False
860-
861-
def fn1():
862-
return -0.0
863-
864-
assert check_const(fn1, -0.0)
865-
866-
867-
def find_count_in(collection, what):
868-
count = 0;
869-
for item in collection:
870-
if item == what:
871-
count +=1
872-
return count
873-
874-
def test_same_consts():
875-
def fn1(): a = 1; b = 1; return a + b
876-
assert find_count_in(fn1.__code__.co_consts, 1) == 1
877-
878-
def fn2(): a = 'a'; b = 'a'; return a + b
879-
assert find_count_in(fn2.__code__.co_consts, 'a') == 1
880-
881-
def test_tuple_in_const():
882-
def fn1() : return (0,)
883-
assert (0,) in fn1.__code__.co_consts
884-
assert 0 not in fn1.__code__.co_consts
885-
886-
def fn2() : return (1, 2, 3, 1, 2, 3)
887-
assert (1, 2, 3, 1, 2, 3) in fn2.__code__.co_consts
888-
assert 1 not in fn2.__code__.co_consts
889-
assert 2 not in fn2.__code__.co_consts
890-
assert 3 not in fn2.__code__.co_consts
891-
assert find_count_in(fn2.__code__.co_consts, (1, 2, 3, 1, 2, 3)) == 1
892-
893-
def fn3() : a = 1; return (1, 2, 1)
894-
assert (1, 2, 1) in fn3.__code__.co_consts
895-
assert find_count_in(fn3.__code__.co_consts, 1) == 1
896-
assert 2 not in fn3.__code__.co_consts
897-
898-
def fn4() : a = 1; b = (1,2,3); c = 4; return (1, 2, 3, 1, 2, 3)
899-
assert (1, 2, 3) in fn4.__code__.co_consts
900-
assert (1, 2, 3, 1, 2, 3) in fn4.__code__.co_consts
901-
assert 2 not in fn4.__code__.co_consts
902-
assert find_count_in(fn4.__code__.co_consts, 1) == 1
903-
assert find_count_in(fn4.__code__.co_consts, 4) == 1
904850

905851
def test_ComprehensionGeneratorExpr():
906852
def create_list(gen):

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MarshalModuleBuiltins.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,11 @@ private void writeBytecodeCodeUnit(BytecodeCodeUnit code) throws IOException {
15121512
}
15131513

15141514
private void writeBytecodeDSLCodeUnit(BytecodeDSLCodeUnit code) throws IOException {
1515+
/*
1516+
* Nested code units referenced by MakeFunction are stored in co_consts; the
1517+
* MakeFunction instruction itself carries only the integer index into this constants
1518+
* array.
1519+
*/
15151520
byte[] serialized = code.getSerialized(context);
15161521
writeBytes(serialized);
15171522
writeString(code.name);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/code/PCode.java

Lines changed: 59 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -41,30 +41,20 @@
4141
package com.oracle.graal.python.builtins.objects.code;
4242

4343
import static com.oracle.graal.python.nodes.StringLiterals.J_EMPTY_STRING;
44-
import static com.oracle.graal.python.util.PythonUtils.EMPTY_OBJECT_ARRAY;
4544
import static com.oracle.graal.python.util.PythonUtils.EMPTY_TRUFFLESTRING_ARRAY;
4645
import static com.oracle.graal.python.util.PythonUtils.toInternedTruffleStringUncached;
4746

4847
import java.math.BigInteger;
49-
import java.util.ArrayList;
5048
import java.util.Arrays;
51-
import java.util.HashMap;
52-
import java.util.HashSet;
53-
import java.util.List;
54-
import java.util.Map;
55-
import java.util.Set;
5649

5750
import com.oracle.graal.python.PythonLanguage;
58-
import com.oracle.graal.python.builtins.objects.PNone;
5951
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
60-
import com.oracle.graal.python.builtins.objects.ellipsis.PEllipsis;
6152
import com.oracle.graal.python.builtins.objects.function.Signature;
6253
import com.oracle.graal.python.builtins.objects.generator.PGenerator;
6354
import com.oracle.graal.python.builtins.objects.object.PythonBuiltinObject;
6455
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
6556
import com.oracle.graal.python.compiler.BytecodeCodeUnit;
6657
import com.oracle.graal.python.compiler.CodeUnit;
67-
import com.oracle.graal.python.compiler.OpCodes;
6858
import com.oracle.graal.python.nodes.PRootNode;
6959
import com.oracle.graal.python.nodes.bytecode.PBytecodeGeneratorFunctionRootNode;
7060
import com.oracle.graal.python.nodes.bytecode.PBytecodeGeneratorRootNode;
@@ -136,8 +126,6 @@ public final class PCode extends PythonBuiltinObject {
136126
// qualified name with which this code object was defined
137127
private TruffleString qualname;
138128

139-
private Map<CodeUnit, PCode> childCode;
140-
141129
// number of first line in Python source code
142130
private int firstlineno = -1;
143131
// is a string encoding the mapping from bytecode offsets to line numbers
@@ -293,55 +281,23 @@ private static TruffleString[] extractVarnames(RootNode node) {
293281
return EMPTY_TRUFFLESTRING_ARRAY;
294282
}
295283

284+
private Object[] ensureConstants() {
285+
if (constants == null) {
286+
CodeUnit codeUnit = getCodeUnit(getRootNode());
287+
constants = codeUnit != null ? new Object[codeUnit.constants.length] : PythonUtils.EMPTY_OBJECT_ARRAY;
288+
}
289+
return constants;
290+
}
291+
296292
@TruffleBoundary
297-
private Object[] extractConstants(RootNode node) {
298-
RootNode rootNode = rootNodeForExtraction(node);
299-
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
300-
if (rootNode instanceof PBytecodeDSLRootNode bytecodeDSLRootNode) {
301-
BytecodeDSLCodeUnit co = bytecodeDSLRootNode.getCodeUnit();
302-
List<Object> constants = new ArrayList<>();
303-
for (int i = 0; i < co.constants.length; i++) {
304-
Object constant = convertConstantToPythonSpace(co.constants[i]);
305-
constants.add(constant);
306-
}
307-
return constants.toArray(new Object[0]);
308-
}
309-
} else if (rootNode instanceof PBytecodeRootNode bytecodeRootNode) {
310-
BytecodeCodeUnit co = bytecodeRootNode.getCodeUnit();
311-
Set<Object> bytecodeConstants = new HashSet<>();
312-
for (int bci = 0; bci < co.code.length;) {
313-
OpCodes op = OpCodes.fromOpCode(co.code[bci]);
314-
if (op.quickens != null) {
315-
op = op.quickens;
316-
}
317-
if (op == OpCodes.LOAD_BYTE) {
318-
bytecodeConstants.add(Byte.toUnsignedInt(co.code[bci + 1]));
319-
} else if (op == OpCodes.LOAD_NONE) {
320-
bytecodeConstants.add(PNone.NONE);
321-
} else if (op == OpCodes.LOAD_TRUE) {
322-
bytecodeConstants.add(true);
323-
} else if (op == OpCodes.LOAD_FALSE) {
324-
bytecodeConstants.add(false);
325-
} else if (op == OpCodes.LOAD_ELLIPSIS) {
326-
bytecodeConstants.add(PEllipsis.INSTANCE);
327-
} else if (op == OpCodes.LOAD_INT || op == OpCodes.LOAD_LONG) {
328-
bytecodeConstants.add(co.primitiveConstants[Byte.toUnsignedInt(co.code[bci + 1])]);
329-
} else if (op == OpCodes.LOAD_DOUBLE) {
330-
bytecodeConstants.add(Double.longBitsToDouble(co.primitiveConstants[Byte.toUnsignedInt(co.code[bci + 1])]));
331-
}
332-
bci += op.length();
333-
}
334-
List<Object> constants = new ArrayList<>();
335-
for (int i = 0; i < co.constants.length; i++) {
336-
Object constant = convertConstantToPythonSpace(co.constants[i]);
337-
if (constant != PNone.NONE || !bytecodeConstants.contains(PNone.NONE)) {
338-
constants.add(constant);
339-
}
340-
}
341-
constants.addAll(bytecodeConstants);
342-
return constants.toArray(new Object[0]);
293+
private Object getOrCreateConstant(int index) {
294+
Object[] cachedConstants = ensureConstants();
295+
Object constant = cachedConstants[index];
296+
if (constant == null) {
297+
constant = convertConstantToPythonSpace(index);
298+
cachedConstants[index] = constant;
343299
}
344-
return EMPTY_OBJECT_ARRAY;
300+
return constant;
345301
}
346302

347303
@TruffleBoundary
@@ -419,9 +375,11 @@ public void fixCoFilename(TruffleString filename) {
419375
* New code objects inherit the filename from parent, so no need to eagerly construct them
420376
* here
421377
*/
422-
if (childCode != null) {
423-
for (PCode code : childCode.values()) {
424-
code.filename = filename;
378+
if (constants != null) {
379+
for (Object constant : constants) {
380+
if (constant instanceof PCode code) {
381+
code.filename = filename;
382+
}
425383
}
426384
}
427385
}
@@ -525,65 +483,66 @@ public CodeUnit getCodeUnit() {
525483
}
526484

527485
public Object[] getConstants() {
528-
if (constants == null) {
529-
constants = extractConstants(getRootNode());
486+
Object[] cachedConstants = ensureConstants();
487+
for (int i = 0; i < cachedConstants.length; i++) {
488+
getOrCreateConstant(i);
530489
}
531-
return constants;
490+
return cachedConstants;
532491
}
533492

534-
@TruffleBoundary
535-
public PCode getOrCreateChildCode(BytecodeDSLCodeUnit codeUnit) {
536-
PCode code = null;
537-
if (childCode == null) {
538-
childCode = new HashMap<>();
539-
} else {
540-
code = childCode.get(codeUnit);
541-
}
493+
public PCode getOrCreateChildCode(int index, BytecodeDSLCodeUnit codeUnit) {
494+
Object[] cachedConstants = ensureConstants();
495+
PCode code = (PCode) cachedConstants[index];
542496
if (code == null) {
543-
PBytecodeDSLRootNode outerRootNode = (PBytecodeDSLRootNode) getRootNode();
544-
PythonLanguage language = outerRootNode.getLanguage();
545-
RootCallTarget callTarget = language.createCachedCallTarget(l -> codeUnit.createRootNode(PythonContext.get(null), outerRootNode.getSource()), codeUnit);
546-
PBytecodeDSLRootNode rootNode = (PBytecodeDSLRootNode) callTarget.getRootNode();
547-
code = PFactory.createCode(language, callTarget, rootNode.getSignature(), codeUnit, getFilename());
548-
childCode.put(codeUnit, code);
497+
code = createCode(codeUnit);
498+
cachedConstants[index] = code;
549499
}
550500
return code;
551501
}
552502

553503
@TruffleBoundary
554-
public PCode getOrCreateChildCode(BytecodeCodeUnit codeUnit) {
555-
PCode code = null;
556-
if (childCode == null) {
557-
childCode = new HashMap<>();
558-
} else {
559-
code = childCode.get(codeUnit);
560-
}
504+
private PCode createCode(BytecodeDSLCodeUnit codeUnit) {
505+
PBytecodeDSLRootNode outerRootNode = (PBytecodeDSLRootNode) getRootNode();
506+
PythonLanguage language = outerRootNode.getLanguage();
507+
RootCallTarget callTarget = language.createCachedCallTarget(l -> codeUnit.createRootNode(PythonContext.get(null), outerRootNode.getSource()), codeUnit);
508+
PBytecodeDSLRootNode rootNode = (PBytecodeDSLRootNode) callTarget.getRootNode();
509+
return PFactory.createCode(language, callTarget, rootNode.getSignature(), codeUnit, getFilename());
510+
}
511+
512+
public PCode getOrCreateChildCode(int index, BytecodeCodeUnit codeUnit) {
513+
Object[] cachedConstants = ensureConstants();
514+
PCode code = (PCode) cachedConstants[index];
561515
if (code == null) {
562-
PBytecodeRootNode outerRootNode = (PBytecodeRootNode) getRootNodeForExtraction();
563-
PythonLanguage language = outerRootNode.getLanguage();
564-
RootCallTarget callTarget = language.createCachedCallTarget(
565-
l -> PBytecodeRootNode.createMaybeGenerator(language, codeUnit, outerRootNode.getLazySource(), outerRootNode.isInternal()),
566-
codeUnit);
567-
RootNode rootNode = callTarget.getRootNode();
568-
if (rootNode instanceof PBytecodeGeneratorFunctionRootNode generatorRoot) {
569-
rootNode = generatorRoot.getBytecodeRootNode();
570-
}
571-
code = PFactory.createCode(language, callTarget, ((PBytecodeRootNode) rootNode).getSignature(), codeUnit, getFilename());
572-
childCode.put(codeUnit, code);
516+
code = createCode(codeUnit);
517+
cachedConstants[index] = code;
573518
}
574519
return code;
575520
}
576521

577522
@TruffleBoundary
578-
private Object convertConstantToPythonSpace(Object o) {
523+
private PCode createCode(BytecodeCodeUnit codeUnit) {
524+
PBytecodeRootNode outerRootNode = (PBytecodeRootNode) getRootNodeForExtraction();
525+
PythonLanguage language = outerRootNode.getLanguage();
526+
RootCallTarget callTarget = language.createCachedCallTarget(
527+
l -> PBytecodeRootNode.createMaybeGenerator(language, codeUnit, outerRootNode.getLazySource(), outerRootNode.isInternal()), codeUnit);
528+
RootNode rootNode = callTarget.getRootNode();
529+
if (rootNode instanceof PBytecodeGeneratorFunctionRootNode generatorRoot) {
530+
rootNode = generatorRoot.getBytecodeRootNode();
531+
}
532+
return PFactory.createCode(language, callTarget, ((PBytecodeRootNode) rootNode).getSignature(), codeUnit, getFilename());
533+
}
534+
535+
@TruffleBoundary
536+
private Object convertConstantToPythonSpace(int index) {
537+
Object o = getCodeUnit().constants[index];
579538
PythonLanguage language = PythonLanguage.get(null);
580539
if (o instanceof CodeUnit) {
581540
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
582541
BytecodeDSLCodeUnit code = (BytecodeDSLCodeUnit) o;
583-
return getOrCreateChildCode(code);
542+
return getOrCreateChildCode(index, code);
584543
} else {
585544
BytecodeCodeUnit code = (BytecodeCodeUnit) o;
586-
return getOrCreateChildCode(code);
545+
return getOrCreateChildCode(index, code);
587546
}
588547
} else if (o instanceof BigInteger) {
589548
return PFactory.createInt(language, (BigInteger) o);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/compiler/bytecode_dsl/BytecodeDSLCompiler.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2025, 2026, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -41,6 +41,8 @@
4141
package com.oracle.graal.python.compiler.bytecode_dsl;
4242

4343
import java.util.EnumSet;
44+
import java.util.HashMap;
45+
import java.util.Map;
4446

4547
import com.oracle.graal.python.PythonLanguage;
4648
import com.oracle.graal.python.compiler.Compiler;
@@ -94,6 +96,8 @@ public static class BytecodeDSLCompilerContext {
9496
public final int futureLineNumber;
9597
public final ParserCallbacksImpl errorCallback;
9698
public final ScopeEnvironment scopeEnvironment;
99+
// Store code units for possible reparses
100+
public final Map<Object, BytecodeDSLCodeUnit> codeUnits = new HashMap<>();
97101

98102
public BytecodeDSLCompilerContext(PythonLanguage language, ModTy mod, Source source, int optimizationLevel,
99103
EnumSet<FutureFeature> futureFeatures, int futureLineNumber, ParserCallbacksImpl errorCallback, ScopeEnvironment scopeEnvironment) {

0 commit comments

Comments
 (0)