Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 81 additions & 95 deletions hat/core/src/main/java/hat/BufferTagger.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
package hat;

import hat.phases.HATPhaseUtils;
import jdk.incubator.code.CodeItem;
import jdk.incubator.code.dialect.java.ArrayType;
import jdk.incubator.code.dialect.java.JavaOp;
import jdk.incubator.code.dialect.java.PrimitiveType;
import optkl.IfaceValue;
import optkl.OpHelper;
import optkl.ifacemapper.AccessType;
import optkl.ifacemapper.Buffer;
import optkl.ifacemapper.MappableIface;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.Block;
Expand All @@ -40,142 +41,127 @@
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
import static optkl.OpHelper.Invoke;
import java.util.stream.IntStream;
import static optkl.OpHelper.Invoke.invoke;

public class BufferTagger {
static HashMap<Value, AccessType> accessMap = new HashMap<>();
static HashMap<Value, Value> remappedVals = new HashMap<>(); // maps values to their "root" parameter/value
static HashMap<Block, List<Block.Parameter>> blockParams = new HashMap<>(); // holds block parameters for easy lookup
static Map<Value, AccessType> accessMap = new HashMap<>(); // mapping of parameters/buffers to access type
static Map<Value, Value> rootValues = new HashMap<>(); // maps values to their "root" parameter/value

// generates a list of AccessTypes matching the given FuncOp's parameter order
public static ArrayList<AccessType> getAccessList(MethodHandles.Lookup lookup, CoreOp.FuncOp inlinedEntryPoint) {
buildAccessMap(lookup, inlinedEntryPoint);
ArrayList<AccessType> accessList = new ArrayList<>();
for (Block.Parameter p : inlinedEntryPoint.body().entryBlock().parameters()) {
public static List<AccessType> getAccessList(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
buildAccessMap(lookup, funcOp);
List<AccessType> accessList = new ArrayList<>();
for (Block.Parameter p : funcOp.body().entryBlock().parameters()) {
if (accessMap.containsKey(p)) {
accessList.add(accessMap.get(p)); // is an accessed buffer
} else if (OpHelper.isAssignable(lookup, p.type(), MappableIface.class)) {
} else if (OpHelper.isAssignable(lookup, p.type(), IfaceValue.class)) {
accessList.add(AccessType.NA); // is a buffer but not accessed
} else {
accessList.add(AccessType.NOT_BUFFER); // is not a buffer
}
}
return accessList;
}
private static boolean isReference(Invoke ioh) {
return ioh.returns(IfaceValue.class)
&& ioh.opFromOnlyUseOrNull() instanceof JavaOp.InvokeOp nextInvoke
&& invoke(ioh.lookup(), nextInvoke) instanceof Invoke nextIoh
&& nextIoh.refIs(IfaceValue.class)
&& nextIoh.returnsVoid();
}

// creates the access map
private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
// build blockParams so that we can map params to "root" params later
funcOp.elements()
.filter(elem -> elem instanceof Block)
.map(elem->(Block)elem)
.forEach(block -> blockParams.put(block, block.parameters()));

private static void buildAccessMap(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) {
funcOp.elements().forEach(op -> {
switch (op) {
case CoreOp.BranchOp b -> mapBranch(lookup, b.branch());
case CoreOp.ConditionalBranchOp cb -> {
mapBranch(lookup, cb.trueBranch()); // handle true branch
mapBranch(lookup, cb.falseBranch()); // handle false branch
}
case JavaOp.InvokeOp invokeOp -> {
var ioh = invoke(lookup,invokeOp);
// we have to deal with array views too
// should .arrayview() calls be marked as reads?
if ( ioh.refIs(IfaceValue.class)) {
// updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access
// if the invokeOp retrieves an element that is only written to, don't update the access type
// (i.e. the only use is an invoke, the invoke is of MappableIface/HAType class, and is a write)
if (!isReference(ioh)) { // value retrieved and not just referenced?
updateAccessType(getRootValue(invokeOp), ioh.returnsVoid()? AccessType.WO : AccessType.RO); // update buffer access
}
if (ioh.refIs(IfaceValue.class) && (ioh.returns(IfaceValue.class) || ioh.returnsArray())) {
// if we access a struct/union from a buffer, we map the struct/union to the buffer root
remappedVals.put(invokeOp.result(), getRootValue(invokeOp));
}
case JavaOp.InvokeOp $ when invoke(lookup, $) instanceof OpHelper.Invoke ioh && !ioh.refIs(KernelContext.class) -> {
if (ioh.returns(IfaceValue.class) || ioh.returnsArray()) { // if we receive a buffer from this invoke, we save its root value
ioh.op().operands().stream()
.filter(operand -> !(operand.type() instanceof PrimitiveType) && rootValues.containsKey(operand))
.forEach(operand -> {
if (operand instanceof Block.Parameter) {
updateAccessType(operand, AccessType.RO);
} else {
updateAccessType(operand.result().op(), AccessType.RO);
}
});
rootValues.put(ioh.returnResult(), getRootValue(ioh.op()));
} else { // if we actually operate on a buffer instead of storing an element in a variable
updateAccessType(ioh.op(), ioh.returnsVoid() ? AccessType.WO : AccessType.RO); // update buffer access
}
}
case CoreOp.VarOp vop -> { // map the new VarOp to the "root" param
if (OpHelper.isAssignable(lookup, vop.resultType().valueType(), Buffer.class)) {
remappedVals.put(vop.initOperand(), getRootValue(vop));
}else{
// or else maybe CoreOp.VarOp vop when ??? ->
}
}
case JavaOp.FieldAccessOp.FieldLoadOp flop -> {
if (OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class)) {
updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access
}else{
// or else
}
}
case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> updateAccessType(getRootValue(alop), AccessType.RO);
case JavaOp.ArrayAccessOp.ArrayStoreOp asop -> updateAccessType(getRootValue(asop), AccessType.WO);
case CoreOp.VarOp vop when OpHelper.isAssignable(lookup, vop.varValueType(), IfaceValue.class) ->
rootValues.put(vop.initOperand(), getRootValue(vop)); // map the new VarOp to the "root" param
case JavaOp.FieldAccessOp.FieldLoadOp flop when OpHelper.isAssignable(lookup, flop.fieldReference().refType(), KernelContext.class) ->
updateAccessType(flop, AccessType.RO); // handle kc access
case JavaOp.ArrayAccessOp.ArrayLoadOp alop when !(alop.resultType() instanceof ArrayType) ->
updateAccessType(alop, AccessType.RO);
case JavaOp.ArrayAccessOp.ArrayStoreOp asop ->
updateAccessType(asop, AccessType.WO);
default -> {}
}
});
}

// maps the parameters of a block to the values passed to a branch
private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference blockReference) {
List<Value> args = blockReference.arguments();
for (int i = 0; i < args.size(); i++) {
Value key = blockParams.get(blockReference.targetBlock()).get(i);
Value value = args.get(i);
if (value instanceof Op.Result result) {
// either find root param or it doesn't exist (is a constant for example)
if (OpHelper.isAssignable(lookup, value.type(), MappableIface.class)) {
value = getRootValue(result.op());
if (value instanceof Block.Parameter) {
value = remappedVals.getOrDefault(value, value);
List<Value> inputArgs = blockReference.arguments();
List<Block.Parameter> targetArgs = blockReference.targetBlock().parameters();
IntStream.range(0, inputArgs.size()).filter(i ->
inputArgs.get(i) instanceof Op.Result && OpHelper.isAssignable(lookup, inputArgs.get(i).type(), IfaceValue.class))
.forEach(i -> {
Value input = inputArgs.get(i);
input = getRootValue(input.result().op());
rootValues.put(targetArgs.get(i), rootValues.getOrDefault(input, input));
});
}

// retrieves "root" value of an op, which is how we track accesses
private static Value getRootValue(Op op) {
// the op is a field load, an invoke, or something that reduces to one or the other
Op rootOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class, JavaOp.InvokeOp.class);
switch (rootOp) {
case JavaOp.FieldAccessOp.FieldLoadOp fieldOp when !fieldOp.operands().isEmpty() -> {
return fieldOp.operands().getFirst();
}
case JavaOp.InvokeOp invokeOp -> {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of this would be easier I think if you used the Invoke wrapper.

while (invokeOp != null && !invokeOp.operands().isEmpty()) { // we look for either the parameter or initialization for the buffer
if (invokeOp.operands().getFirst() instanceof Block.Parameter p) {
return p; // return the parameter that is the global buffer
}
}else{
// or else
invokeOp = (JavaOp.InvokeOp) HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class);
}
if (invokeOp != null) {
return invokeOp.result();
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is null what do we return? Maybe we dont expect this ever to be the case. In which case shoudl we throw?

}else{
// or else?
}
remappedVals.put(key, value);
case null, default -> {}
}
return null;
}

// retrieves "root" value of an op, which is how we track accesses
// we will map the return value of this method to the accessType
private static Value getRootValue(Op op) {
// the op is a field load, an invoke, or something that reduces to one or the other
// first, check if we can retrieve a fieldloadop from the given op
Op fieldOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class);
if (fieldOp != null) return fieldOp.operands().getFirst(); // if so, we use its first operand to map to accesses

// we then check if there's an invokeop that has no operands (meaning a shared or private buffer that was created)
// or if there's an invokeop with a parameter as its first operation (this is a global buffer)
Op invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
while (invokeOp != null && !invokeOp.operands().isEmpty()) {
if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer
invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class);
}
return (invokeOp == null) ? null : invokeOp.result(); // return the shared/private buffer invokeop that creates the buffer
// retrieves root value of op before updating the access map
private static void updateAccessType(Op op, AccessType currentAccess) {
updateAccessType(getRootValue(op), currentAccess);
}

// updates accessMap
private static void updateAccessType(Value value, AccessType currentAccess) {
Value remappedValue = remappedVals.getOrDefault(value, value);
AccessType storedAccess = accessMap.get(remappedValue);
// updates the access map
private static void updateAccessType(Value value, AccessType currentAccess) {
AccessType storedAccess = accessMap.get(value);
if (storedAccess == null) {
accessMap.put(remappedValue, currentAccess);
accessMap.put(value, currentAccess);
} else if (currentAccess != storedAccess && storedAccess != AccessType.RW) {
accessMap.put(remappedValue, AccessType.RW);
} else {
// this is the same access type as what's already stored
accessMap.put(value, AccessType.RW);
} // otherwise this is the same access type as what's already stored
}

public static void printAccessList(CoreOp.FuncOp funcOp, List<AccessType> accessList) {
StringBuilder output = new StringBuilder();
output.append("func ").append(funcOp.funcName()).append(" has parameters");
for (AccessType at : accessList) {
output.append(" ").append(at);
}
System.out.println(output);
}
}
2 changes: 1 addition & 1 deletion hat/core/src/main/java/hat/buffer/S32Array.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ static S32Array createFrom(ArenaAndLookupCarrier cc, int[] arr){
return ints;
}

@Reflect default int[] arrayView() {
default int[] arrayView() {
return this.copyTo(new int[this.length()]);
}
}