From 6a41036b7eee2aa1838cd5f1b09c4dce728e4657 Mon Sep 17 00:00:00 2001 From: Fuad Hasan Date: Thu, 23 Apr 2026 14:05:34 +0600 Subject: [PATCH 1/2] Fix ownership moves through composite values --- .../analysis/semantics/ownership/ownership.go | 71 +++++++++++++++- .../semantics/ownership/ownership_test.go | 81 +++++++++++++++++++ 2 files changed, 148 insertions(+), 4 deletions(-) diff --git a/internal/analysis/semantics/ownership/ownership.go b/internal/analysis/semantics/ownership/ownership.go index 30474a19..1ba8f330 100644 --- a/internal/analysis/semantics/ownership/ownership.go +++ b/internal/analysis/semantics/ownership/ownership.go @@ -537,6 +537,7 @@ func (a *ownershipAnalyzer) collectValueLocalUses(used cfg.LocalSet, value mir.V a.collectValueLocalUses(used, v.Left) case *mir.CompositeValue: for _, item := range v.Items { + a.collectValueLocalUses(used, item.Key) a.collectValueLocalUses(used, item.Value) } case *mir.InterfaceValue: @@ -554,7 +555,9 @@ func (a *ownershipAnalyzer) checkComputedValue(scope *valueScope, instr *mir.Com a.checkValue(scope, instr.Value) if info, ok := a.tempInfoForValue(instr.Value); ok { a.temps[instr.TargetID] = info + return } + a.consumeMoveValue(scope, instr.Value, a.localType(instr.TargetID)) } func (a *ownershipAnalyzer) checkValue(scope *valueScope, value mir.Value) { @@ -602,6 +605,7 @@ func (a *ownershipAnalyzer) checkValue(scope *valueScope, value mir.Value) { a.checkValue(scope, v.Left) case *mir.CompositeValue: for _, item := range v.Items { + a.checkValue(scope, item.Key) a.checkValue(scope, item.Value) } case *mir.InterfaceValue: @@ -845,10 +849,17 @@ func (a *ownershipAnalyzer) consumeMoveValue(scope *valueScope, value mir.Value, if value == nil { return } + if unary, ok := value.(*mir.UnaryValue); ok && unary.Op == "copy" { + return + } if ifaceValue, ok := value.(*mir.InterfaceValue); ok { a.consumeMoveValue(scope, ifaceValue.Value, ifaceValue.ConcreteType) return } + if composite, ok := value.(*mir.CompositeValue); ok { + a.consumeCompositeMoveValue(scope, composite, typ) + return + } if local, ok := value.(*mir.LocalValue); ok && scope != nil { if slot, ok := scope.Lookup(local.LocalID); ok && slot != nil && a.isMoveType(slot.concrete) { a.consumeLocalPath(scope, local.LocalID, "", value.Loc()) @@ -859,10 +870,6 @@ func (a *ownershipAnalyzer) consumeMoveValue(scope *valueScope, value mir.Value, return } switch v := value.(type) { - case *mir.UnaryValue: - if v.Op == "copy" { - return - } case *mir.LocalValue: if info, ok := a.temps[v.LocalID]; ok { if !info.root.isLocal() { @@ -881,6 +888,62 @@ func (a *ownershipAnalyzer) consumeMoveValue(scope *valueScope, value mir.Value, a.consumeLocalPath(scope, root, path, value.Loc()) } +func (a *ownershipAnalyzer) consumeCompositeMoveValue(scope *valueScope, value *mir.CompositeValue, typ typeinfo.Type) { + if value == nil { + return + } + for i, item := range value.Items { + keyType, valueType := a.compositeItemTypes(typ, item, i) + a.consumeMoveValue(scope, item.Key, keyType) + a.consumeMoveValue(scope, item.Value, valueType) + } +} + +func (a *ownershipAnalyzer) compositeItemTypes(typ typeinfo.Type, item mir.CompositeItem, index int) (typeinfo.Type, typeinfo.Type) { + base := a.compositePayloadType(typ) + if base == nil || typeinfo.IsInvalid(base) || typeinfo.IsUnknown(base) { + return valueType(item.Key), valueType(item.Value) + } + switch t := base.(type) { + case *typeinfo.StructType: + if item.Name != "" { + if field := t.Fields[item.Name]; field != nil { + return nil, field.Type + } + } + if index >= 0 && index < len(t.OrderedFields) && t.OrderedFields[index] != nil { + return nil, t.OrderedFields[index].Type + } + case *typeinfo.TupleType: + if index >= 0 && index < len(t.Elems) { + return nil, t.Elems[index] + } + case *typeinfo.ArrayType: + return nil, t.Inner + case *typeinfo.SliceType: + return nil, t.Inner + case *typeinfo.MapType: + return t.Key, t.Value + } + return valueType(item.Key), valueType(item.Value) +} + +func (a *ownershipAnalyzer) compositePayloadType(typ typeinfo.Type) typeinfo.Type { + for typ != nil && !typeinfo.IsInvalid(typ) && !typeinfo.IsUnknown(typ) { + switch t := a.underlying(typ).(type) { + case *typeinfo.ApproxType: + typ = t.Inner + case *typeinfo.OptionalType: + typ = t.Inner + case *typeinfo.ErrorUnionType: + typ = t.Value + default: + return t + } + } + return typ +} + func (a *ownershipAnalyzer) consumeLocalPath(scope *valueScope, root int, path string, loc source.Location) { if scope == nil || root < 0 { return diff --git a/internal/analysis/semantics/ownership/ownership_test.go b/internal/analysis/semantics/ownership/ownership_test.go index 34bc3331..164dfce0 100644 --- a/internal/analysis/semantics/ownership/ownership_test.go +++ b/internal/analysis/semantics/ownership/ownership_test.go @@ -765,6 +765,70 @@ fn main(n: Node) -> i32 { t.Fatalf("expected %s diagnostic, got %#v", diagnostics.ErrUseAfterMove, result.Diagnostics.Diagnostics()) } +func TestOwnershipPhaseRejectsMoveAfterStructLiteralFieldMove(t *testing.T) { + assertOwnershipUseAfterMove(t, ` +type Node struct {} + +type Box struct { + Item: *Node +} + +fn take(box: Box) -> void { + box +} + +fn main(node: *Node) -> *Node { + take(.{ .Item = node }) + return node +} +`) +} + +func TestOwnershipPhaseRejectsMoveAfterTupleLiteralElementMove(t *testing.T) { + assertOwnershipUseAfterMove(t, ` +type Node struct {} + +fn take(pair: (*Node, i32)) -> void { + pair +} + +fn main(node: *Node) -> *Node { + take((node, 1)) + return node +} +`) +} + +func TestOwnershipPhaseRejectsMoveAfterSliceLiteralElementMove(t *testing.T) { + assertOwnershipUseAfterMove(t, ` +type Node struct {} + +fn take(nodes: []*Node) -> void { + nodes +} + +fn main(node: *Node) -> *Node { + take([]*Node{node}) + return node +} +`) +} + +func TestOwnershipPhaseRejectsMoveAfterVariadicArgumentMove(t *testing.T) { + assertOwnershipUseAfterMove(t, ` +type Node struct {} + +fn take(nodes: ...*Node) -> void { + nodes +} + +fn main(node: *Node) -> *Node { + take(node) + return node +} +`) +} + func TestOwnershipPhaseAllowsPlainValueReceiverMethodReuse(t *testing.T) { root := t.TempDir() mustWriteOwnership(t, filepath.Join(root, "main.fer"), ` @@ -1049,6 +1113,23 @@ fn main() -> void { t.Fatalf("expected %s diagnostic, got %#v", diagnostics.ErrUseAfterMove, result.Diagnostics.Diagnostics()) } +func assertOwnershipUseAfterMove(t *testing.T, source string) { + t.Helper() + root := t.TempDir() + mustWriteOwnership(t, filepath.Join(root, "main.fer"), source) + + result := compiler.New(root, ".fer", diagnostics.NewDiagnosticBag("")).ParseEntry(filepath.Join(root, "main.fer")) + if result.Entry == nil || result.Entry.Phase < phase.PhaseOwnershipAnalyzed { + t.Fatalf("expected ownership analyzed phase, got %#v", result.Entry) + } + for _, diag := range result.Diagnostics.Diagnostics() { + if diag.Code == diagnostics.ErrUseAfterMove { + return + } + } + t.Fatalf("expected %s diagnostic, got %#v", diagnostics.ErrUseAfterMove, result.Diagnostics.Diagnostics()) +} + func mustWriteOwnership(t *testing.T, path, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { From 6d2104c81b8fc370426a249831127a65a45f21fb Mon Sep 17 00:00:00 2001 From: Fuad Hasan Date: Thu, 23 Apr 2026 16:34:30 +0600 Subject: [PATCH 2/2] Introduce atomic storage and task runtime Add AST/parser support for 'atomic' types and let/stmt flags. Implement atomic storage TypeInfo, synthesize and const-eval handling in the typechecker, and HIR/MIR support for atomic loads/ops. Add LLVM lowering for atomic instructions and runtime symbols. Introduce std/task surface and C runtime using pthreads on non-Windows (platform no-op on Windows). Pass -pthread to compiler/linker on non-Windows. Add tests and repro programs for atomics and task usage. --- bundler/build_tools.go | 3 + ferret_libs_dev/std/task.fer | 35 ++++++ .../semantics/typechecker/const_eval.go | 4 +- .../analysis/semantics/typechecker/core.go | 4 +- .../analysis/semantics/typechecker/stmts.go | 5 +- .../semantics/typechecker/symbol_types.go | 24 +++- .../semantics/typechecker/syntax_types.go | 2 + .../semantics/typechecker/typechecker.go | 61 ++++++++- .../semantics/typechecker/typechecker_test.go | 74 +++++++++++ .../analysis/semantics/typeinfo/rewrite.go | 5 + internal/analysis/semantics/typeinfo/types.go | 62 ++++++++++ .../analysis/semantics/typeinfo/types_test.go | 46 ++++++- internal/backend/aggregate_layout.go | 2 + internal/backend/common/common.go | 2 + internal/backend/llvm/link.go | 7 ++ internal/backend/llvm/llvm.go | 108 +++++++++++++--- internal/backend/llvm/lower_test.go | 109 ++++++++++++++++ internal/backend/type_helpers.go | 4 + internal/backend/typeinfo.go | 2 + internal/frontend/ast/clone.go | 6 +- internal/frontend/ast/debug.go | 6 +- internal/frontend/ast/decl.go | 1 + internal/frontend/ast/dump.go | 3 + internal/frontend/ast/format.go | 2 + internal/frontend/ast/stmt.go | 1 + internal/frontend/ast/type.go | 8 ++ internal/frontend/ast/walk.go | 2 + internal/frontend/parser/decl.go | 5 + internal/frontend/parser/parser.go | 2 +- internal/frontend/parser/stmt.go | 5 + internal/frontend/parser/type.go | 3 + internal/frontend/parser/validate.go | 2 + internal/ir/hir/generate.go | 23 +++- internal/ir/hir/specialize.go | 13 ++ internal/ir/mir/debug.go | 2 + internal/ir/mir/format.go | 4 + internal/ir/mir/lower.go | 117 ++++++++++++++++++ internal/ir/mir/model.go | 16 +++ internal/ir/mir/validate.go | 5 + internal/ir/mir/walk.go | 7 ++ internal/tokens/token.go | 3 + runtime/ferret_runtime.h | 9 ++ runtime/ferret_runtime_task.c | 75 +++++++++++ tests/repro/task_atomic_i32.fer | 23 ++++ tests/repro/task_race_raw.fer | 21 ++++ tests/repro/variadic.fer | 8 ++ 46 files changed, 895 insertions(+), 36 deletions(-) create mode 100644 ferret_libs_dev/std/task.fer create mode 100644 runtime/ferret_runtime_task.c create mode 100644 tests/repro/task_atomic_i32.fer create mode 100644 tests/repro/task_race_raw.fer create mode 100644 tests/repro/variadic.fer diff --git a/bundler/build_tools.go b/bundler/build_tools.go index 39f94dfe..d9e25452 100644 --- a/bundler/build_tools.go +++ b/bundler/build_tools.go @@ -93,6 +93,9 @@ func buildRuntimeLib(runtimeDir, libsDir string, bits int) error { for _, src := range entries { obj := filepath.Join(objDir, strings.TrimSuffix(filepath.Base(src), ".c")+".o") args := []string{"-std=c11", "-O2", "-Wall", "-Wextra", "-I", runtimeDir, "-c", src, "-o", obj} + if runtime.GOOS != "windows" { + args = append(args, "-pthread") + } if bits == abi.Bits32 { args = append(args[:1], append([]string{"-m32"}, args[1:]...)...) } diff --git a/ferret_libs_dev/std/task.fer b/ferret_libs_dev/std/task.fer new file mode 100644 index 00000000..36645d88 --- /dev/null +++ b/ferret_libs_dev/std/task.fer @@ -0,0 +1,35 @@ +// std/task - explicit task/thread runtime surface. +// +// Current implementation is backed by OS threads. + +import "std/mem" + +type taskInner struct {} + +type Handle struct { + raw: *taskInner +} + +#[extern("ferret_task_run_raw")] +fn run_ref(entry: fn(T) -> void, arg: T) -> ^taskInner; + +#[extern("ferret_task_wait")] +fn wait_raw(handle: ^void) -> void; + +fn Run(callback: fn(T) -> void, arg: T) -> Handle { + unsafe { + return .{ .raw = mem::Adopt(run_ref(callback, arg)) } + } +} + +fn Handle::Wait(self) -> void { + unsafe { + wait_raw(mem::Expose(self.raw) as ^void) + } +} + +fn WaitAll(handles: ...Handle) -> void { + for handles | handle | { + handle.Wait() + } +} diff --git a/internal/analysis/semantics/typechecker/const_eval.go b/internal/analysis/semantics/typechecker/const_eval.go index 97faad50..07f7f9ed 100644 --- a/internal/analysis/semantics/typechecker/const_eval.go +++ b/internal/analysis/semantics/typechecker/const_eval.go @@ -19,9 +19,9 @@ func allowsConstValueCache(node ast.Node) bool { case *ast.ConstDecl, *ast.ConstStmt: return true case *ast.LetDecl: - return n != nil && !n.IsMut + return n != nil && !n.IsMut && !n.IsAtomic case *ast.LetStmt: - return n != nil && !n.IsMut + return n != nil && !n.IsMut && !n.IsAtomic default: return false } diff --git a/internal/analysis/semantics/typechecker/core.go b/internal/analysis/semantics/typechecker/core.go index bfe688d3..4f0c679e 100644 --- a/internal/analysis/semantics/typechecker/core.go +++ b/internal/analysis/semantics/typechecker/core.go @@ -146,9 +146,9 @@ func (c *checker) symbolMutable(sym *symbols.Symbol) bool { } switch node := sym.Node.(type) { case *ast.LetDecl: - return node != nil && node.IsMut + return node != nil && (node.IsMut || node.IsAtomic) case *ast.LetStmt: - return node != nil && node.IsMut + return node != nil && (node.IsMut || node.IsAtomic) case *ast.ConstDecl, *ast.ConstStmt: return false default: diff --git a/internal/analysis/semantics/typechecker/stmts.go b/internal/analysis/semantics/typechecker/stmts.go index e292a540..0ceb495c 100644 --- a/internal/analysis/semantics/typechecker/stmts.go +++ b/internal/analysis/semantics/typechecker/stmts.go @@ -43,7 +43,10 @@ func (c *checker) checkStmt(scope *refineScope, stmt ast.Stmt) { if finalType == nil { finalType = typeinfo.UnknownType{} } - if s.Type != nil && declared != nil && !typeinfo.Equal(declared, finalType) { + if s.IsAtomic { + finalType = c.atomicStorageType(s.Loc(), finalType) + } + if !s.IsAtomic && s.Type != nil && declared != nil && !typeinfo.Equal(declared, finalType) { c.info.BindNode(s.Type, finalType) } if declared != nil && s.Value != nil && !c.checkExprAssignable(scope, s.Value, declared, value) { diff --git a/internal/analysis/semantics/typechecker/symbol_types.go b/internal/analysis/semantics/typechecker/symbol_types.go index e5f15a86..b47074f1 100644 --- a/internal/analysis/semantics/typechecker/symbol_types.go +++ b/internal/analysis/semantics/typechecker/symbol_types.go @@ -48,10 +48,18 @@ func (c *checker) synthesizeSymbolType(mod *context.Module, sym *symbols.Symbol) switch n := sym.Node.(type) { case *ast.LetDecl: if n.Type != nil { - return c.typeFromSyntax(mod, n.Type) + typ := c.typeFromSyntax(mod, n.Type) + if n.IsAtomic { + return c.atomicStorageType(n.Loc(), typ) + } + return typ } if n.Value != nil { - return c.typeOfExpr(nil, n.Value, nil) + typ := c.typeOfExpr(nil, n.Value, nil) + if n.IsAtomic { + return c.atomicStorageType(n.Loc(), typ) + } + return typ } case *ast.ConstDecl: if n.Type != nil { @@ -69,10 +77,18 @@ func (c *checker) synthesizeSymbolType(mod *context.Module, sym *symbols.Symbol) } case *ast.LetStmt: if n.Type != nil { - return c.typeFromSyntax(mod, n.Type) + typ := c.typeFromSyntax(mod, n.Type) + if n.IsAtomic { + return c.atomicStorageType(n.Loc(), typ) + } + return typ } if n.Value != nil { - return c.typeOfExpr(nil, n.Value, nil) + typ := c.typeOfExpr(nil, n.Value, nil) + if n.IsAtomic { + return c.atomicStorageType(n.Loc(), typ) + } + return typ } } if sym.Node == nil { diff --git a/internal/analysis/semantics/typechecker/syntax_types.go b/internal/analysis/semantics/typechecker/syntax_types.go index ef0b07d9..7be67c6d 100644 --- a/internal/analysis/semantics/typechecker/syntax_types.go +++ b/internal/analysis/semantics/typechecker/syntax_types.go @@ -133,6 +133,8 @@ func (c *checker) typeFromSyntax(mod *context.Module, expr ast.TypeExpr) typeinf return &typeinfo.PointerType{Inner: inner} case *ast.RefType: return &typeinfo.RefType{Mutable: t.Mutable, Inner: c.typeFromSyntax(mod, t.Inner)} + case *ast.AtomicType: + return c.atomicStorageType(t.Loc(), c.typeFromSyntax(mod, t.Inner)) case *ast.RawPtrType: inner := c.typeFromSyntax(mod, t.Inner) if t.Inner == nil || typeinfo.IsBuiltinNamed(inner, "void") { diff --git a/internal/analysis/semantics/typechecker/typechecker.go b/internal/analysis/semantics/typechecker/typechecker.go index 87fce285..5372ffa5 100644 --- a/internal/analysis/semantics/typechecker/typechecker.go +++ b/internal/analysis/semantics/typechecker/typechecker.go @@ -32,7 +32,10 @@ func (c *checker) checkDecl(decl ast.Decl) { if finalType == nil { finalType = typeinfo.UnknownType{} } - if d.Type != nil && declared != nil && !typeinfo.Equal(declared, finalType) { + if d.IsAtomic { + finalType = c.atomicStorageType(d.Loc(), finalType) + } + if !d.IsAtomic && d.Type != nil && declared != nil && !typeinfo.Equal(declared, finalType) { c.info.BindNode(d.Type, finalType) } if declared != nil && d.Value != nil { @@ -172,6 +175,35 @@ func (c *checker) resolveBindingValueType(scope *refineScope, declared, value ty return value } +func (c *checker) atomicStorageType(loc source.Location, typ typeinfo.Type) typeinfo.Type { + if typ == nil { + return nil + } + if _, ok := typ.(*typeinfo.AtomicType); ok { + return typ + } + if typeinfo.IsInvalid(typ) || typeinfo.IsUnknown(typ) { + return &typeinfo.AtomicType{Inner: typ} + } + if !typeinfo.IsAtomicStorageAllowed(typ) { + c.ctx.Diagnostics.Add( + diagnostics.NewError(fmt.Sprintf("atomic storage requires bool, integer, raw pointer, or enum type, got %s", typ.String())). + WithCode(diagnostics.ErrTypeMismatch). + WithPrimaryLabel(&loc, "unsupported atomic type"). + WithHelp("use mutex-protected shared state for non-atomic types"), + ) + return typeinfo.InvalidType{} + } + return &typeinfo.AtomicType{Inner: typ} +} + +func atomicValueType(typ typeinfo.Type) typeinfo.Type { + if atomic, ok := typ.(*typeinfo.AtomicType); ok && atomic != nil { + return atomic.Inner + } + return typ +} + func (c *checker) checkModuleBindingType(loc source.Location, typ typeinfo.Type) { if c == nil || typ == nil { return @@ -623,6 +655,10 @@ func (c *checker) getTypeOfPrefix(scope *refineScope, expr *ast.PrefixExpr, expe case "*": switch ptr := right.(type) { case *typeinfo.RefType: + if atomic, ok := ptr.Inner.(*typeinfo.AtomicType); ok && atomic != nil { + c.info.BindNode(expr, atomic.Inner) + return atomic.Inner + } c.info.BindNode(expr, ptr.Inner) return ptr.Inner case *typeinfo.RawPtrType: @@ -643,9 +679,17 @@ func (c *checker) getTypeOfPrefix(scope *refineScope, expr *ast.PrefixExpr, expe ) return typeinfo.InvalidType{} } + if atomic, ok := ptr.Inner.(*typeinfo.AtomicType); ok && atomic != nil { + c.info.BindNode(expr, atomic.Inner) + return atomic.Inner + } c.info.BindNode(expr, ptr.Inner) return ptr.Inner case *typeinfo.PointerType: + if atomic, ok := ptr.Inner.(*typeinfo.AtomicType); ok && atomic != nil { + c.info.BindNode(expr, atomic.Inner) + return atomic.Inner + } c.info.BindNode(expr, ptr.Inner) return ptr.Inner } @@ -954,17 +998,17 @@ func (c *checker) checkRangePatternAgainstMatchValue(scope *refineScope, valueTy func (c *checker) binaryResult(op string, left, right typeinfo.Type) (typeinfo.Type, bool) { switch op { case "+", "-", "*", "/", "%": - if result := typeinfo.CommonNumericType(left, right); result != nil { + if result := typeinfo.CommonNumericType(atomicValueType(left), atomicValueType(right)); result != nil { return result, true } return nil, true case "==", "!=": - if typeinfo.Assignable(left, right) || typeinfo.Assignable(right, left) || typeinfo.CommonNumericType(left, right) != nil { + if typeinfo.Assignable(left, right) || typeinfo.Assignable(right, left) || typeinfo.CommonNumericType(atomicValueType(left), atomicValueType(right)) != nil { return &typeinfo.BuiltinType{Name: "bool"}, true } return nil, true case "<", "<=", ">", ">=": - if typeinfo.CommonNumericType(left, right) != nil { + if typeinfo.CommonNumericType(atomicValueType(left), atomicValueType(right)) != nil { return &typeinfo.BuiltinType{Name: "bool"}, true } return nil, true @@ -1911,6 +1955,8 @@ func (c *checker) collectTypeParams(typ typeinfo.Type, visit func(*typeinfo.Type c.collectTypeParams(t.Inner, visit) case *typeinfo.RefType: c.collectTypeParams(t.Inner, visit) + case *typeinfo.AtomicType: + c.collectTypeParams(t.Inner, visit) case *typeinfo.RawPtrType: c.collectTypeParams(t.Inner, visit) case *typeinfo.OptionalType: @@ -2011,6 +2057,10 @@ func (c *checker) inferTypeParamBindings(pattern, actual typeinfo.Type, bindings if got, ok := actual.(*typeinfo.RefType); ok && got.Mutable == p.Mutable { c.inferTypeParamBindings(p.Inner, got.Inner, bindings) } + case *typeinfo.AtomicType: + if got, ok := actual.(*typeinfo.AtomicType); ok { + c.inferTypeParamBindings(p.Inner, got.Inner, bindings) + } case *typeinfo.RawPtrType: if got, ok := actual.(*typeinfo.RawPtrType); ok { c.inferTypeParamBindings(p.Inner, got.Inner, bindings) @@ -3295,6 +3345,9 @@ func (c *checker) exprAccess(scope *refineScope, expr ast.Expr) (addressable boo rightType := c.typeOfExpr(scope, e.Right, nil) switch t := c.underlying(rightType).(type) { case *typeinfo.RefType: + if _, ok := t.Inner.(*typeinfo.AtomicType); ok { + return true, true + } return true, t.Mutable case *typeinfo.PointerType: _, rightMutable := c.exprAccess(scope, e.Right) diff --git a/internal/analysis/semantics/typechecker/typechecker_test.go b/internal/analysis/semantics/typechecker/typechecker_test.go index 87f9cfbf..dc6d248a 100644 --- a/internal/analysis/semantics/typechecker/typechecker_test.go +++ b/internal/analysis/semantics/typechecker/typechecker_test.go @@ -6195,6 +6195,80 @@ fn main() -> i32 { } } +func TestTypecheckerAllowsBoolRawPointerAndEnumAtomicStorage(t *testing.T) { + root := t.TempDir() + mustWriteType(t, filepath.Join(root, "main.fer"), ` +type Mode enum { + off, + on +} + +fn readFlag(value: &atomic bool) -> bool { + let current = *value + return current +} + +fn readPtr(value: &atomic ^void) -> ^void { + let current = *value + return current +} + +fn readMode(value: &atomic Mode) -> Mode { + let current = *value + return current +} + +fn main(raw: ^void) -> void { + let atomic flag = true + flag = false + let flagValue: bool = readFlag(&flag) + + let atomic ptr = raw + ptr = raw + let ptrValue: ^void = readPtr(&ptr) + + let atomic mode = Mode::off + mode = Mode::on + let modeValue: Mode = readMode(&mode) + + _ = flagValue + _ = ptrValue + _ = modeValue +} +`) + + result := compiler.New(root, ".fer", diagnostics.NewDiagnosticBag("")).ParseEntry(filepath.Join(root, "main.fer")) + if result.Diagnostics.HasErrors() { + t.Fatalf("unexpected diagnostics: %#v", result.Diagnostics.Diagnostics()) + } +} + +func TestTypecheckerRejectsNonAtomicStorageTypes(t *testing.T) { + root := t.TempDir() + mustWriteType(t, filepath.Join(root, "main.fer"), ` +fn main() -> void { + let atomic text = "hi" + let atomic items = []i32{1, 2, 3} + _ = text + _ = items +} +`) + + result := compiler.New(root, ".fer", diagnostics.NewDiagnosticBag("")).ParseEntry(filepath.Join(root, "main.fer")) + if !result.Diagnostics.HasErrors() { + t.Fatal("expected atomic storage diagnostics") + } + found := 0 + for _, diag := range result.Diagnostics.Diagnostics() { + if diag.Code == diagnostics.ErrTypeMismatch && strings.Contains(diag.Message, "atomic storage requires bool, integer, raw pointer, or enum type") { + found++ + } + } + if found < 2 { + t.Fatalf("expected atomic storage type mismatch diagnostics, got %#v", result.Diagnostics.Diagnostics()) + } +} + func TestTypecheckerRejectsImmutableValuePassedToInterfaceMutReceiverCall(t *testing.T) { root := t.TempDir() mustWriteType(t, filepath.Join(root, "main.fer"), ` diff --git a/internal/analysis/semantics/typeinfo/rewrite.go b/internal/analysis/semantics/typeinfo/rewrite.go index f2e989e5..5692ddab 100644 --- a/internal/analysis/semantics/typeinfo/rewrite.go +++ b/internal/analysis/semantics/typeinfo/rewrite.go @@ -41,6 +41,11 @@ func rewriteType(typ Type, pre, post func(Type) Type, seen map[Type]Type) Type { seen[typ] = copy copy.Inner = rewriteType(t.Inner, pre, post, seen) out = copy + case *AtomicType: + copy := &AtomicType{} + seen[typ] = copy + copy.Inner = rewriteType(t.Inner, pre, post, seen) + out = copy case *RawPtrType: copy := &RawPtrType{Const: t.Const} seen[typ] = copy diff --git a/internal/analysis/semantics/typeinfo/types.go b/internal/analysis/semantics/typeinfo/types.go index af172703..64155d85 100644 --- a/internal/analysis/semantics/typeinfo/types.go +++ b/internal/analysis/semantics/typeinfo/types.go @@ -112,6 +112,17 @@ func (t *RefType) String() string { return prefix + typeString(t.Inner) } +type AtomicType struct { + Inner Type +} + +func (t *AtomicType) String() string { + if t == nil { + return "" + } + return "atomic " + typeString(t.Inner) +} + type RawPtrType struct { Const bool Inner Type @@ -397,6 +408,51 @@ func IsUnknown(t Type) bool { return ok } +func IsAtomicStorageAllowed(typ Type) bool { + switch base := atomicStorageBaseType(typ).(type) { + case *BuiltinType: + if base.Name == "bool" { + return true + } + family, _, ok := NumericInfo(base) + return ok && family != NumericFloat + case *RawPtrType: + return true + case *EnumType: + return true + default: + return false + } +} + +func SupportsAtomicAdd(typ Type) bool { + base := atomicStorageBaseType(typ) + family, _, ok := NumericInfo(base) + return ok && family != NumericFloat +} + +func atomicStorageBaseType(typ Type) Type { + for typ != nil { + switch t := typ.(type) { + case *AtomicType: + typ = t.Inner + case *ApproxType: + typ = t.Inner + case *NamedType: + if t != nil && t.Decl != nil { + switch t.Decl.Type.(type) { + case *ast.EnumType: + return &EnumType{} + } + } + return typ + default: + return typ + } + } + return nil +} + func IsBuiltinNamed(t Type, name string) bool { if name == "str" { _, ok := t.(*StringType) @@ -473,6 +529,9 @@ func Equal(a, b Type) bool { case *RefType: bt, ok := b.(*RefType) return ok && at.Mutable == bt.Mutable && Equal(at.Inner, bt.Inner) + case *AtomicType: + bt, ok := b.(*AtomicType) + return ok && Equal(at.Inner, bt.Inner) case *RawPtrType: bt, ok := b.(*RawPtrType) return ok && at.Const == bt.Const && Equal(at.Inner, bt.Inner) @@ -542,6 +601,9 @@ func Assignable(dst, src Type) bool { if IsImplicitNumericWidening(dst, src) { return true } + if atomic, ok := dst.(*AtomicType); ok && src != nil { + return Assignable(atomic.Inner, src) + } if opt, ok := dst.(*OptionalType); ok && src != nil { return Assignable(opt.Inner, src) } diff --git a/internal/analysis/semantics/typeinfo/types_test.go b/internal/analysis/semantics/typeinfo/types_test.go index 353cf4e5..b07f21ba 100644 --- a/internal/analysis/semantics/typeinfo/types_test.go +++ b/internal/analysis/semantics/typeinfo/types_test.go @@ -1,6 +1,10 @@ package typeinfo -import "testing" +import ( + "testing" + + "compiler/internal/frontend/ast" +) func TestRefAndRawTypeString(t *testing.T) { immutable := &RefType{Inner: &BuiltinType{Name: "i32"}} @@ -135,3 +139,43 @@ func TestAssignableAllowsImplicitNumericWideningAndIntToFloat(t *testing.T) { t.Fatal("expected f32 to remain non-assignable to i32") } } + +func TestIsAtomicStorageAllowed(t *testing.T) { + mode := &NamedType{ + Name: "Mode", + Decl: &ast.TypeDecl{Type: &ast.EnumType{}}, + } + cases := []struct { + name string + typ Type + want bool + }{ + {name: "bool", typ: &BuiltinType{Name: "bool"}, want: true}, + {name: "int", typ: &BuiltinType{Name: "i32"}, want: true}, + {name: "float", typ: &BuiltinType{Name: "f32"}, want: false}, + {name: "rawptr", typ: &RawPtrType{Inner: &BuiltinType{Name: "u8"}}, want: true}, + {name: "enum", typ: mode, want: true}, + {name: "str", typ: &StringType{}, want: false}, + {name: "slice", typ: &SliceType{Inner: &BuiltinType{Name: "u8"}}, want: false}, + } + for _, tc := range cases { + if got := IsAtomicStorageAllowed(tc.typ); got != tc.want { + t.Fatalf("%s: expected %v, got %v", tc.name, tc.want, got) + } + } +} + +func TestSupportsAtomicAdd(t *testing.T) { + if !SupportsAtomicAdd(&BuiltinType{Name: "i32"}) { + t.Fatal("expected i32 atomic add support") + } + if SupportsAtomicAdd(&BuiltinType{Name: "f32"}) { + t.Fatal("expected f32 atomic add to remain unsupported") + } + if SupportsAtomicAdd(&BuiltinType{Name: "bool"}) { + t.Fatal("expected bool atomic add to remain unsupported") + } + if SupportsAtomicAdd(&RawPtrType{Inner: &BuiltinType{Name: "u8"}}) { + t.Fatal("expected raw pointer atomic add to remain unsupported") + } +} diff --git a/internal/backend/aggregate_layout.go b/internal/backend/aggregate_layout.go index 1626e0d7..bc48c784 100644 --- a/internal/backend/aggregate_layout.go +++ b/internal/backend/aggregate_layout.go @@ -180,6 +180,8 @@ func sharedScalarSizeAlign(typ typeinfo.Type) (int64, int64, error) { case *typeinfo.PointerType, *typeinfo.RefType, *typeinfo.RawPtrType, *typeinfo.FuncType, *typeinfo.MapType: ptrSize := abi.PointerBytes() return ptrSize, ptrSize, nil + case *typeinfo.AtomicType: + return sharedScalarSizeAlign(t.Inner) } return 0, 0, fmt.Errorf("not a primitive type") } diff --git a/internal/backend/common/common.go b/internal/backend/common/common.go index b4ba45d4..46722b04 100644 --- a/internal/backend/common/common.go +++ b/internal/backend/common/common.go @@ -353,6 +353,8 @@ func RuntimeTypeKey(typ typeinfo.Type) string { return "ref_mut__" + RuntimeTypeKey(t.Inner) } return "ref__" + RuntimeTypeKey(t.Inner) + case *typeinfo.AtomicType: + return "atomic__" + RuntimeTypeKey(t.Inner) case *typeinfo.RawPtrType: return "rawptr__" + RuntimeTypeKey(t.Inner) case *typeinfo.SliceType: diff --git a/internal/backend/llvm/link.go b/internal/backend/llvm/link.go index c4b22690..9fb5df96 100644 --- a/internal/backend/llvm/link.go +++ b/internal/backend/llvm/link.go @@ -64,6 +64,9 @@ func CompileIR(llvmIR, outputPath string, opts CompileOptions) error { } args = append(args, "-O0", "-fno-omit-frame-pointer") } + if runtime.GOOS != "windows" { + args = append(args, "-pthread") + } args = append(args, irFile) args = append(args, runtimeLib) args = append(args, "-o", outputPath) @@ -128,6 +131,10 @@ func llvmBaseType(typ typeinfo.Type) (string, error) { } case *typeinfo.PointerType, *typeinfo.RefType, *typeinfo.RawPtrType, *typeinfo.FuncType, *typeinfo.MapType: return "ptr", nil + case *typeinfo.EnumType, *typeinfo.ErrorSetType: + return "i32", nil + case *typeinfo.AtomicType: + return llvmBaseType(base.Inner) case *typeinfo.OptionalType: if backend.OptionalUsesNiche(base.Inner) { return llvmBaseType(base.Inner) diff --git a/internal/backend/llvm/llvm.go b/internal/backend/llvm/llvm.go index 5de0917c..7008e789 100644 --- a/internal/backend/llvm/llvm.go +++ b/internal/backend/llvm/llvm.go @@ -371,6 +371,8 @@ func (d *debugState) getType(state *moduleState, typ typeinfo.Type) int { case *typeinfo.RefType: innerID := d.getType(state, t.Inner) return d.getPointerType(innerID) + case *typeinfo.AtomicType: + return d.getType(state, t.Inner) case *typeinfo.RawPtrType: innerID := d.getType(state, t.Inner) return d.getPointerType(innerID) @@ -1744,6 +1746,8 @@ func lowerInstr(state *moduleState, instr mir.Instr) (string, error) { return lowerStoreField(state, i) case *mir.StoreInstr: return lowerStorePlace(state, i) + case *mir.AtomicAddInstr: + return lowerAtomicAdd(state, i) case *mir.EvalInstr: if call, ok := i.Value.(*mir.CallValue); ok { return lowerCall(state, "", nil, call) @@ -1855,6 +1859,9 @@ func lowerSSAAssign(state *moduleState, name string, typ typeinfo.Type, value mi if load, ok := value.(*mir.LoadValue); ok { return lowerLoadValue(state, name, typ, load) } + if load, ok := value.(*mir.AtomicLoadValue); ok { + return lowerAtomicLoadValue(state, name, typ, load) + } if bin, ok := value.(*mir.BinaryValue); ok { if line, handled, err := lowerAggregateCompare(state, name, typ, bin); handled || err != nil { return line, err @@ -2102,6 +2109,9 @@ func lowerLoadValue(state *moduleState, targetName string, targetType typeinfo.T if err != nil { return "", err } + if _, ok := targetType.(*typeinfo.AtomicType); ok { + return lowerAtomicLoadFromPointer(state, targetName, targetType, ptr) + } irType, err := llvmBaseType(targetType) if err != nil { return "", err @@ -2109,6 +2119,43 @@ func lowerLoadValue(state *moduleState, targetName string, targetType typeinfo.T return fmt.Sprintf("%s = load %s, ptr %s", llvmLocalName(targetName), irType, ptr), nil } +func lowerAtomicLoadValue(state *moduleState, targetName string, targetType typeinfo.Type, load *mir.AtomicLoadValue) (string, error) { + ptr, err := lowerValue(state, load.Pointer) + if err != nil { + return "", err + } + return lowerAtomicLoadFromPointer(state, targetName, targetType, ptr) +} + +func lowerAtomicLoadFromPointer(state *moduleState, targetName string, targetType typeinfo.Type, ptr string) (string, error) { + typ := atomicInnerType(targetType) + irType, err := llvmBaseType(typ) + if err != nil { + return "", err + } + return llvmLoadLine(llvmLocalName(targetName), irType, ptr, true), nil +} + +func lowerAtomicAdd(state *moduleState, instr *mir.AtomicAddInstr) (string, error) { + if instr == nil { + return "", nil + } + ptr, err := lowerValue(state, instr.Pointer) + if err != nil { + return "", err + } + delta, err := lowerValue(state, instr.Delta) + if err != nil { + return "", err + } + irType, err := llvmBaseType(instr.Type) + if err != nil { + return "", err + } + tmp := freshTemp(state, "atomic_add") + return fmt.Sprintf("%s = atomicrmw add ptr %s, %s %s seq_cst", tmp, ptr, irType, delta), nil +} + func lowerFieldLoad(state *moduleState, targetName string, targetType typeinfo.Type, field *mir.FieldLoadValue) (string, error) { lines, addr, _, err := lowerFieldAddress(state, field.Base, field.FieldIndex) if err != nil { @@ -2387,6 +2434,14 @@ func lowerStorePlace(state *moduleState, instr *mir.StoreInstr) (string, error) if err != nil { return "", err } + if atomic, ok := targetType.(*typeinfo.AtomicType); ok && atomic != nil { + irType, err := llvmBaseType(atomic.Inner) + if err != nil { + return "", err + } + lines = append(lines, fmt.Sprintf("store atomic %s %s, ptr %s seq_cst, align %d", irType, val, addr, irTypeAlign(irType))) + return strings.Join(lines, "\n"), nil + } irType, err := llvmBaseType(instr.Value.Type()) if err != nil { return "", err @@ -3778,20 +3833,24 @@ func lowerInterfaceConcretePointer(state *moduleState, value mir.Value, concrete return lowerInterfaceConcretePointer(state, v.Right, concreteType) } case *mir.LocalValue: - if agg, ok := state.aggLocals[v.LocalID]; ok { - return nil, llvmLocalName(agg.PtrName), nil - } - if sc, ok := state.scalarLocals[v.LocalID]; ok { - return nil, sc.AllocaName, nil + if typeinfo.Equal(concreteType, v.Type()) { + if agg, ok := state.aggLocals[v.LocalID]; ok { + return nil, llvmLocalName(agg.PtrName), nil + } + if sc, ok := state.scalarLocals[v.LocalID]; ok { + return nil, sc.AllocaName, nil + } } case *mir.NameValue: if len(v.Path) == 1 { if local := becommon.FindLocalByName(state.fn, v.Path[0]); local != nil { - if agg, ok := state.aggLocals[local.ID]; ok { - return nil, llvmLocalName(agg.PtrName), nil - } - if sc, ok := state.scalarLocals[local.ID]; ok { - return nil, sc.AllocaName, nil + if typeinfo.Equal(concreteType, v.Type()) { + if agg, ok := state.aggLocals[local.ID]; ok { + return nil, llvmLocalName(agg.PtrName), nil + } + if sc, ok := state.scalarLocals[local.ID]; ok { + return nil, sc.AllocaName, nil + } } } } @@ -4538,8 +4597,8 @@ func lowerValue(state *moduleState, value mir.Value) (string, error) { } if sc, ok := state.scalarLocals[v.LocalID]; ok { tmp := freshTemp(state, "ld") - state.pendingLines = append(state.pendingLines, - fmt.Sprintf("%s = load %s, ptr %s", tmp, sc.IRType, sc.AllocaName)) + _, isAtomic := v.Type().(*typeinfo.AtomicType) + state.pendingLines = append(state.pendingLines, llvmLoadLine(tmp, sc.IRType, sc.AllocaName, isAtomic)) return tmp, nil } return llvmLocalName(becommon.LocalNameByID(state.fn, v.LocalID)), nil @@ -4566,8 +4625,8 @@ func lowerValue(state *moduleState, value mir.Value) (string, error) { } if sc, ok := state.scalarLocals[local.ID]; ok { tmp := freshTemp(state, "ld") - state.pendingLines = append(state.pendingLines, - fmt.Sprintf("%s = load %s, ptr %s", tmp, sc.IRType, sc.AllocaName)) + _, isAtomic := v.Type().(*typeinfo.AtomicType) + state.pendingLines = append(state.pendingLines, llvmLoadLine(tmp, sc.IRType, sc.AllocaName, isAtomic)) return tmp, nil } return llvmLocalName(local.Name), nil @@ -4587,7 +4646,8 @@ func lowerValue(state *moduleState, value mir.Value) (string, error) { if v.LinkName != "" { sym = "@" + becommon.SanitizeLinkName(v.LinkName) } - state.pendingLines = append(state.pendingLines, fmt.Sprintf("%s = load %s, ptr %s", tmp, irType, sym)) + _, isAtomic := v.Type().(*typeinfo.AtomicType) + state.pendingLines = append(state.pendingLines, llvmLoadLine(tmp, irType, sym, isAtomic)) return tmp, nil } } @@ -4656,6 +4716,20 @@ func lowerValue(state *moduleState, value mir.Value) (string, error) { } } +func atomicInnerType(typ typeinfo.Type) typeinfo.Type { + if atomic, ok := typ.(*typeinfo.AtomicType); ok && atomic != nil { + return atomic.Inner + } + return typ +} + +func llvmLoadLine(target, irType, ptr string, atomic bool) string { + if atomic { + return fmt.Sprintf("%s = load atomic %s, ptr %s seq_cst, align %d", target, irType, ptr, irTypeAlign(irType)) + } + return fmt.Sprintf("%s = load %s, ptr %s", target, irType, ptr) +} + func lowerMapCompositeValue(state *moduleState, comp *mir.CompositeValue, targetType typeinfo.Type) (string, error) { mapType, ok := backend.ResolveMapType(targetType) if !ok { @@ -5503,9 +5577,13 @@ func aggregateSizeAlignOfPrimitive(typ typeinfo.Type) (int64, int64, error) { case "f64": return 8, 8, nil } + case *typeinfo.EnumType, *typeinfo.ErrorSetType: + return 4, 4, nil case *typeinfo.PointerType, *typeinfo.RefType, *typeinfo.RawPtrType, *typeinfo.FuncType, *typeinfo.MapType: ptrSize := abi.PointerBytes() return ptrSize, ptrSize, nil + case *typeinfo.AtomicType: + return aggregateSizeAlignOfPrimitive(t.Inner) } return 0, 0, fmt.Errorf("not a primitive type") } diff --git a/internal/backend/llvm/lower_test.go b/internal/backend/llvm/lower_test.go index 15238097..a0b13370 100644 --- a/internal/backend/llvm/lower_test.go +++ b/internal/backend/llvm/lower_test.go @@ -2547,6 +2547,115 @@ fn apply(f: fn(i32) -> i32, x: i32) -> i32 { } } +func TestLowerTaskRunAndAtomicToLLVM(t *testing.T) { + root := t.TempDir() + mustWrite(t, filepath.Join(root, "main.fer"), ` +import "std/task" + +fn worker(value: &atomic i32) -> void { + (*value)++ +} + +fn main() -> void { + let atomic value = 0 + let handle = task::Run(worker, &value) + handle.Wait() + println(value) +} +`) + result := compiler.ParsePath(filepath.Join(root, "main.fer")) + if result.Diagnostics.HasErrors() { + t.Fatalf("unexpected diagnostics: %#v", result.Diagnostics.Diagnostics()) + } + lowerer, err := registry.New(backend.TargetLLVM) + if err != nil { + t.Fatalf("unexpected llvm error: %v", err) + } + artifact, err := lowerer.LowerModule(testUnit(result)) + if err != nil { + t.Fatalf("lower llvm: %v", err) + } + text := artifact.Text + for _, want := range []string{ + "@ferret_task_run_raw", + "@ferret_task_wait", + "atomicrmw add", + "load atomic i32", + } { + if !strings.Contains(text, want) { + t.Fatalf("expected %s in llvm output:\n%s", want, text) + } + } +} + +func TestLowerAtomicBoolRawPtrAndEnumToLLVM(t *testing.T) { + root := t.TempDir() + mustWrite(t, filepath.Join(root, "main.fer"), ` +type Mode enum { + off, + on +} + +fn readFlag(value: &atomic bool) -> bool { + let current = *value + return current +} + +fn readPtr(value: &atomic ^void) -> ^void { + let current = *value + return current +} + +fn readMode(value: &atomic Mode) -> Mode { + let current = *value + return current +} + +fn main(raw: ^void) -> void { + let atomic flag = true + flag = false + let flagValue: bool = readFlag(&flag) + + let atomic ptr = raw + ptr = raw + let ptrValue: ^void = readPtr(&ptr) + + let atomic mode = Mode::off + mode = Mode::on + let modeValue: Mode = readMode(&mode) + + _ = flagValue + _ = ptrValue + _ = modeValue +} +`) + result := compiler.ParsePath(filepath.Join(root, "main.fer")) + if result.Diagnostics.HasErrors() { + t.Fatalf("unexpected diagnostics: %#v", result.Diagnostics.Diagnostics()) + } + lowerer, err := registry.New(backend.TargetLLVM) + if err != nil { + t.Fatalf("unexpected llvm error: %v", err) + } + artifact, err := lowerer.LowerModule(testUnit(result)) + if err != nil { + t.Fatalf("lower llvm: %v", err) + } + text := artifact.Text + for _, want := range []string{ + "store atomic i8", + "load atomic i8", + "store atomic ptr", + "load atomic ptr", + "store atomic i32", + "load atomic i32", + } { + if !strings.Contains(text, want) { + t.Fatalf("expected %s in llvm output:\n%s", want, text) + } + } +} + func TestLowerFunctionRouteTableStructToLLVM(t *testing.T) { root := t.TempDir() mustWrite(t, filepath.Join(root, "main.fer"), ` diff --git a/internal/backend/type_helpers.go b/internal/backend/type_helpers.go index 2752ed19..6e3e8f10 100644 --- a/internal/backend/type_helpers.go +++ b/internal/backend/type_helpers.go @@ -56,6 +56,8 @@ func ResolveMapType(typ typeinfo.Type) (*typeinfo.MapType, bool) { return ResolveMapType(t.Inner) case *typeinfo.RefType: return ResolveMapType(t.Inner) + case *typeinfo.AtomicType: + return ResolveMapType(t.Inner) case *typeinfo.NamedType: if mt, ok := namedMapAliasType(t); ok { return mt, true @@ -144,6 +146,8 @@ func aliasSyntaxType(expr ast.TypeExpr, bindings map[string]typeinfo.Type) typei return &typeinfo.PointerType{Inner: aliasSyntaxType(t.Inner, bindings)} case *ast.RefType: return &typeinfo.RefType{Mutable: t.Mutable, Inner: aliasSyntaxType(t.Inner, bindings)} + case *ast.AtomicType: + return &typeinfo.AtomicType{Inner: aliasSyntaxType(t.Inner, bindings)} case *ast.RawPtrType: return &typeinfo.RawPtrType{Const: t.Const, Inner: aliasSyntaxType(t.Inner, bindings)} case *ast.OptionalType: diff --git a/internal/backend/typeinfo.go b/internal/backend/typeinfo.go index 598c8b61..4047f845 100644 --- a/internal/backend/typeinfo.go +++ b/internal/backend/typeinfo.go @@ -135,6 +135,8 @@ func DescribeRuntimeType(typ typeinfo.Type) RuntimeTypeDescriptor { } case *typeinfo.PointerType, *typeinfo.RefType, *typeinfo.RawPtrType: desc.Flags |= RuntimeTypeFlagPointer + case *typeinfo.AtomicType: + return DescribeRuntimeType(t.Inner) case *typeinfo.InterfaceType: desc.Flags |= RuntimeTypeFlagInterface case *typeinfo.SliceType: diff --git a/internal/frontend/ast/clone.go b/internal/frontend/ast/clone.go index c1a4eca1..6132d3e0 100644 --- a/internal/frontend/ast/clone.go +++ b/internal/frontend/ast/clone.go @@ -68,6 +68,10 @@ func CloneExprWithNodeMapAndSubstitute(expr Expr, substitute func(Node) Expr) (E out := &RefType{Mutable: t.Mutable, Inner: cloneType(t.Inner), Location: t.Location} mapping[t] = out return out + case *AtomicType: + out := &AtomicType{Inner: cloneType(t.Inner), Location: t.Location} + mapping[t] = out + return out case *RawPtrType: out := &RawPtrType{Const: t.Const, Inner: cloneType(t.Inner), Location: t.Location} mapping[t] = out @@ -248,7 +252,7 @@ func CloneExprWithNodeMapAndSubstitute(expr Expr, substitute func(Node) Expr) (E case *BlockStmt: return cloneBlock(s) case *LetStmt: - out := &LetStmt{Name: cloneIdent(s.Name), IsMut: s.IsMut, Type: cloneType(s.Type), Value: cloneExpr(s.Value), Location: s.Location} + out := &LetStmt{Name: cloneIdent(s.Name), IsMut: s.IsMut, IsAtomic: s.IsAtomic, Type: cloneType(s.Type), Value: cloneExpr(s.Value), Location: s.Location} mapping[s] = out return out case *ConstStmt: diff --git a/internal/frontend/ast/debug.go b/internal/frontend/ast/debug.go index addd7241..7c9f3ceb 100644 --- a/internal/frontend/ast/debug.go +++ b/internal/frontend/ast/debug.go @@ -38,7 +38,7 @@ func DebugModule(mod *Module) map[string]any { func debugDecl(decl Decl) any { switch d := decl.(type) { case *LetDecl: - return map[string]any{"kind": "LetDecl", "name": debugExpr(d.Name), "attrs": debugAttrs(d.Attrs), "is_mut": d.IsMut, "type": debugType(d.Type), "value": debugExpr(d.Value), "loc": debugLoc(d.Location)} + return map[string]any{"kind": "LetDecl", "name": debugExpr(d.Name), "attrs": debugAttrs(d.Attrs), "is_mut": d.IsMut, "is_atomic": d.IsAtomic, "type": debugType(d.Type), "value": debugExpr(d.Value), "loc": debugLoc(d.Location)} case *ConstDecl: return map[string]any{"kind": "ConstDecl", "name": debugExpr(d.Name), "attrs": debugAttrs(d.Attrs), "type": debugType(d.Type), "value": debugExpr(d.Value), "loc": debugLoc(d.Location)} case *TypeDecl: @@ -130,7 +130,7 @@ func debugStmt(stmt Stmt) any { } return map[string]any{"kind": "BlockStmt", "stmts": stmts, "comptime": s.Comptime, "loc": debugLoc(s.Location)} case *LetStmt: - return map[string]any{"kind": "LetStmt", "name": debugExpr(s.Name), "is_mut": s.IsMut, "type": debugType(s.Type), "value": debugExpr(s.Value), "loc": debugLoc(s.Location)} + return map[string]any{"kind": "LetStmt", "name": debugExpr(s.Name), "is_mut": s.IsMut, "is_atomic": s.IsAtomic, "type": debugType(s.Type), "value": debugExpr(s.Value), "loc": debugLoc(s.Location)} case *ConstStmt: return map[string]any{"kind": "ConstStmt", "name": debugExpr(s.Name), "type": debugType(s.Type), "value": debugExpr(s.Value), "loc": debugLoc(s.Location)} case *ReturnStmt: @@ -285,6 +285,8 @@ func debugType(typ TypeExpr) any { return map[string]any{"kind": "PointerType", "inner": debugType(t.Inner), "loc": debugLoc(t.Location)} case *RefType: return map[string]any{"kind": "RefType", "mutable": t.Mutable, "inner": debugType(t.Inner), "loc": debugLoc(t.Location)} + case *AtomicType: + return map[string]any{"kind": "AtomicType", "inner": debugType(t.Inner), "loc": debugLoc(t.Location)} case *RawPtrType: return map[string]any{"kind": "RawPtrType", "const": t.Const, "inner": debugType(t.Inner), "loc": debugLoc(t.Location)} case *SelfType: diff --git a/internal/frontend/ast/decl.go b/internal/frontend/ast/decl.go index 139c08bf..92c827d5 100644 --- a/internal/frontend/ast/decl.go +++ b/internal/frontend/ast/decl.go @@ -28,6 +28,7 @@ type LetDecl struct { Doc *CommentGroup Attrs []Attribute IsMut bool + IsAtomic bool Type TypeExpr Value Expr Location source.Location diff --git a/internal/frontend/ast/dump.go b/internal/frontend/ast/dump.go index f765374d..f1cfee6a 100644 --- a/internal/frontend/ast/dump.go +++ b/internal/frontend/ast/dump.go @@ -9,6 +9,9 @@ func DeclSummary(decl Decl) string { case *ConstDecl: return fmt.Sprintf("const %s", d.Name.Text()) case *LetDecl: + if d.IsAtomic { + return fmt.Sprintf("let atomic %s", d.Name.Text()) + } if d.IsMut { return fmt.Sprintf("let mut %s", d.Name.Text()) } diff --git a/internal/frontend/ast/format.go b/internal/frontend/ast/format.go index 23a474e8..c6d32206 100644 --- a/internal/frontend/ast/format.go +++ b/internal/frontend/ast/format.go @@ -41,6 +41,8 @@ func TypeString(typ TypeExpr) string { return "&mut " + TypeString(t.Inner) } return "&" + TypeString(t.Inner) + case *AtomicType: + return "atomic " + TypeString(t.Inner) case *RawPtrType: if t.Const { return "^const " + TypeString(t.Inner) diff --git a/internal/frontend/ast/stmt.go b/internal/frontend/ast/stmt.go index 7523db4c..61ff9baf 100644 --- a/internal/frontend/ast/stmt.go +++ b/internal/frontend/ast/stmt.go @@ -15,6 +15,7 @@ type LetStmt struct { Name *Ident Doc *CommentGroup IsMut bool + IsAtomic bool Type TypeExpr Value Expr Location source.Location diff --git a/internal/frontend/ast/type.go b/internal/frontend/ast/type.go index e7b90fd6..210695fb 100644 --- a/internal/frontend/ast/type.go +++ b/internal/frontend/ast/type.go @@ -45,6 +45,14 @@ type RefType struct { func (*RefType) typeNode() {} func (t *RefType) Loc() source.Location { return t.Location } +type AtomicType struct { + Inner TypeExpr + Location source.Location +} + +func (*AtomicType) typeNode() {} +func (t *AtomicType) Loc() source.Location { return t.Location } + type RawPtrType struct { Const bool Inner TypeExpr diff --git a/internal/frontend/ast/walk.go b/internal/frontend/ast/walk.go index 667b450b..ea995745 100644 --- a/internal/frontend/ast/walk.go +++ b/internal/frontend/ast/walk.go @@ -23,6 +23,8 @@ func WalkType(typ TypeExpr, visit func(TypeExpr) bool) { WalkType(t.Inner, visit) case *RefType: WalkType(t.Inner, visit) + case *AtomicType: + WalkType(t.Inner, visit) case *RawPtrType: WalkType(t.Inner, visit) case *OptionalType: diff --git a/internal/frontend/parser/decl.go b/internal/frontend/parser/decl.go index 8e27ae75..6f77eaaa 100644 --- a/internal/frontend/parser/decl.go +++ b/internal/frontend/parser/decl.go @@ -47,7 +47,11 @@ func (p *Parser) parseTypeDecl(doc *ast.CommentGroup, attrs []ast.Attribute) ast func (p *Parser) parseLetDecl(doc *ast.CommentGroup, attrs []ast.Attribute) ast.Decl { start := p.expect(tokens.LET, "expected 'let'").Start + isAtomic := p.match(tokens.ATOMIC) isMut := p.match(tokens.MUT) + if isAtomic && isMut { + p.errorAt(p.locOfToken(p.previous()), "atomic bindings do not use 'mut'") + } nameTok := p.expect(tokens.IDENT, "expected variable name") name := nameTok.Literal var typ ast.TypeExpr @@ -64,6 +68,7 @@ func (p *Parser) parseLetDecl(doc *ast.CommentGroup, attrs []ast.Attribute) ast. Doc: doc, Attrs: attrs, IsMut: isMut, + IsAtomic: isAtomic, Type: typ, Value: value, Location: p.locFrom(start), diff --git a/internal/frontend/parser/parser.go b/internal/frontend/parser/parser.go index 8e1f11d5..f666bc49 100644 --- a/internal/frontend/parser/parser.go +++ b/internal/frontend/parser/parser.go @@ -687,7 +687,7 @@ func (p *Parser) parseNamePath() []string { func (p *Parser) startsType() bool { switch p.current().Kind { - case tokens.IDENT, tokens.FN, tokens.QUESTION, tokens.AMP, tokens.CARET, tokens.ASTERISK, tokens.TILDE, + case tokens.IDENT, tokens.FN, tokens.QUESTION, tokens.AMP, tokens.CARET, tokens.ASTERISK, tokens.TILDE, tokens.ATOMIC, tokens.LBRACK, tokens.LPAREN, tokens.STRUCT, tokens.INTERFACE, tokens.ENUM, tokens.UNION, tokens.ERROR: return true diff --git a/internal/frontend/parser/stmt.go b/internal/frontend/parser/stmt.go index 191347e9..10a250fc 100644 --- a/internal/frontend/parser/stmt.go +++ b/internal/frontend/parser/stmt.go @@ -120,7 +120,11 @@ func (p *Parser) parseStmt() ast.Stmt { func (p *Parser) parseLetStmt(doc *ast.CommentGroup) ast.Stmt { start := p.advance().Start + isAtomic := p.match(tokens.ATOMIC) isMut := p.match(tokens.MUT) + if isAtomic && isMut { + p.errorAt(p.locOfToken(p.previous()), "atomic bindings do not use 'mut'") + } nameTok := p.expect(tokens.IDENT, "expected variable name") var typ ast.TypeExpr if p.match(tokens.COLON) { @@ -134,6 +138,7 @@ func (p *Parser) parseLetStmt(doc *ast.CommentGroup) ast.Stmt { Name: &ast.Ident{Path: []string{nameTok.Literal}, Location: p.locOfToken(nameTok)}, Doc: doc, IsMut: isMut, + IsAtomic: isAtomic, Type: typ, Value: value, Location: p.locFrom(start), diff --git a/internal/frontend/parser/type.go b/internal/frontend/parser/type.go index dc991ce5..569eecc7 100644 --- a/internal/frontend/parser/type.go +++ b/internal/frontend/parser/type.go @@ -24,6 +24,9 @@ func (p *Parser) parseType() ast.TypeExpr { } ref.Inner = p.parseType() return ref + case tokens.ATOMIC: + p.advance() + return &ast.AtomicType{Inner: p.parseType(), Location: p.locFrom(start)} case tokens.CARET: p.advance() raw := &ast.RawPtrType{Location: p.locFrom(start)} diff --git a/internal/frontend/parser/validate.go b/internal/frontend/parser/validate.go index 17147918..5d6a5bc4 100644 --- a/internal/frontend/parser/validate.go +++ b/internal/frontend/parser/validate.go @@ -414,6 +414,8 @@ func renderType(typ ast.TypeExpr) string { prefix = "&mut " } return prefix + renderType(t.Inner) + case *ast.AtomicType: + return "atomic " + renderType(t.Inner) case *ast.RawPtrType: if t.Const { return "^const " + renderType(t.Inner) diff --git a/internal/ir/hir/generate.go b/internal/ir/hir/generate.go index 9a410060..9bc8860b 100644 --- a/internal/ir/hir/generate.go +++ b/internal/ir/hir/generate.go @@ -350,11 +350,15 @@ func (g *generator) generateLetDecl(d *ast.LetDecl) *Global { if d == nil { return nil } + typ := effectiveType(g.types, d.Type, d.Value) + if d.IsAtomic { + typ = atomicBindingHIRType(typ) + } return &Global{ Name: d.Name.Text(), - Mutable: d.IsMut, + Mutable: d.IsMut || d.IsAtomic, Constant: false, - Type: effectiveType(g.types, d.Type, d.Value), + Type: typ, Value: g.generateExpr(d.Value), Location: d.Location, Source: d, @@ -588,13 +592,16 @@ func (g *generator) generateStmt(stmt ast.Stmt) Stmt { return g.generateBlock(s) case *ast.LetStmt: targetType := effectiveType(g.types, s.Type, s.Value) + if s.IsAtomic { + targetType = atomicBindingHIRType(targetType) + } if _, ok := targetType.(*typeinfo.FuncType); ok && !s.IsMut { if _, ok := s.Value.(*ast.LambdaExpr); ok { _ = g.generateExpr(s.Value) return nil } } - out := &LetStmt{Name: g.maybeMangledLocalName(s.Name), LocalID: g.maybeLocalID(s.Name), Mutable: s.IsMut, Type: targetType, Value: g.generateExprForTarget(s.Value, targetType)} + out := &LetStmt{Name: g.maybeMangledLocalName(s.Name), LocalID: g.maybeLocalID(s.Name), Mutable: s.IsMut || s.IsAtomic, Type: targetType, Value: g.generateExprForTarget(s.Value, targetType)} out.Location = s.Location return out case *ast.ConstStmt: @@ -1205,6 +1212,16 @@ func effectiveType(types *typeinfo.ModuleInfo, syntax ast.TypeExpr, value ast.Ex return typeinfo.UnknownType{} } +func atomicBindingHIRType(typ typeinfo.Type) typeinfo.Type { + if typ == nil { + return nil + } + if _, ok := typ.(*typeinfo.AtomicType); ok { + return typ + } + return &typeinfo.AtomicType{Inner: typ} +} + type structLiteralField struct { Name string Default ast.Expr diff --git a/internal/ir/hir/specialize.go b/internal/ir/hir/specialize.go index ec4ca9f6..fed699af 100644 --- a/internal/ir/hir/specialize.go +++ b/internal/ir/hir/specialize.go @@ -836,6 +836,10 @@ func inferTypeBindings(pattern, actual typeinfo.Type, bindings map[*typeinfo.Typ if got, ok := actual.(*typeinfo.RefType); ok { inferTypeBindings(p.Inner, got.Inner, bindings) } + case *typeinfo.AtomicType: + if got, ok := actual.(*typeinfo.AtomicType); ok { + inferTypeBindings(p.Inner, got.Inner, bindings) + } case *typeinfo.RawPtrType: if got, ok := actual.(*typeinfo.RawPtrType); ok { inferTypeBindings(p.Inner, got.Inner, bindings) @@ -1372,6 +1376,13 @@ func (s *specializer) specializeNamedTypeRefs(typ typeinfo.Type, seen map[typein seen[t] = struct{}{} t.Inner = s.specializeNamedTypeRefs(t.Inner, seen) return t + case *typeinfo.AtomicType: + if _, ok := seen[t]; ok { + return t + } + seen[t] = struct{}{} + t.Inner = s.specializeNamedTypeRefs(t.Inner, seen) + return t case *typeinfo.RawPtrType: if _, ok := seen[t]; ok { return t @@ -1545,6 +1556,8 @@ func typeHasTypeParam(typ typeinfo.Type) bool { return typeHasTypeParam(t.Inner) case *typeinfo.RefType: return typeHasTypeParam(t.Inner) + case *typeinfo.AtomicType: + return typeHasTypeParam(t.Inner) case *typeinfo.RawPtrType: return typeHasTypeParam(t.Inner) case *typeinfo.OptionalType: diff --git a/internal/ir/mir/debug.go b/internal/ir/mir/debug.go index e975fc61..5ea6133d 100644 --- a/internal/ir/mir/debug.go +++ b/internal/ir/mir/debug.go @@ -135,6 +135,8 @@ func debugValue(value Value) any { return map[string]any{"kind": "addr_of", "mutable": v.Mutable, "raw": v.Raw, "source": debugValue(v.Source), "type": typeString(v.Type())} case *LoadValue: return map[string]any{"kind": "load", "pointer": debugValue(v.Pointer), "type": typeString(v.Type())} + case *AtomicLoadValue: + return map[string]any{"kind": "atomic_load", "pointer": debugValue(v.Pointer), "type": typeString(v.Type())} case *BinaryValue: return map[string]any{"kind": "binary", "op": v.Op, "left": debugValue(v.Left), "right": debugValue(v.Right), "type": typeString(v.Type())} case *PostfixValue: diff --git a/internal/ir/mir/format.go b/internal/ir/mir/format.go index 485c87a6..49cc6e48 100644 --- a/internal/ir/mir/format.go +++ b/internal/ir/mir/format.go @@ -175,6 +175,8 @@ func formatInstr(instr Instr) string { return fmt.Sprintf("%s = %s", formatLocalRef(currentFnForFormat, i.TargetID), formatValue(i.Value)) case *StoreInstr: return fmt.Sprintf("store %s = %s", formatPlace(i.Target), formatValue(i.Value)) + case *AtomicAddInstr: + return fmt.Sprintf("atomic_add %s, %s", formatValue(i.Pointer), formatValue(i.Delta)) case *StoreFieldInstr: return fmt.Sprintf("store_field %s %d %s", wrapValue(i.Base), i.FieldIndex, formatValue(i.Value)) case *EvalInstr: @@ -292,6 +294,8 @@ func formatValue(value Value) string { return fmt.Sprintf("%s %s", kw, wrapValue(v.Source)) case *LoadValue: return fmt.Sprintf("load %s", wrapValue(v.Pointer)) + case *AtomicLoadValue: + return fmt.Sprintf("atomic_load %s", wrapValue(v.Pointer)) case *BinaryValue: return fmt.Sprintf("%s %s %s", binaryOpcode(v.Op), wrapValue(v.Left), wrapValue(v.Right)) case *PostfixValue: diff --git a/internal/ir/mir/lower.go b/internal/ir/mir/lower.go index 6e9501f7..04232882 100644 --- a/internal/ir/mir/lower.go +++ b/internal/ir/mir/lower.go @@ -285,6 +285,9 @@ func lowerInstr(lowerCtx *lowerContext, stmt hir.Stmt) Instr { if ident, ok := s.Left.(*hir.Ident); ok && len(ident.Path) == 1 && ident.Path[0] == "_" { return &EvalInstr{baseInstr: baseInstr{Location: s.Loc()}, Value: lowerValue(lowerCtx, s.Right)} } + if atomicAdd := lowerAtomicAddInstr(lowerCtx, s); atomicAdd != nil { + return atomicAdd + } if target := lowerAssignableTarget(lowerCtx, s.Left); target != nil { return &StoreInstr{baseInstr: baseInstr{Location: s.Loc()}, Target: target, Value: lowerCoercedValue(lowerCtx, s.Right, s.Left.Type())} } @@ -505,6 +508,9 @@ func lowerValue(lowerCtx *lowerContext, expr hir.Expr) Value { _, isRaw := e.Type().(*typeinfo.RawPtrType) return &AddrOfValue{baseValue: baseValue{Location: e.Loc(), ExprType: e.Type()}, Source: lowerAddrSource(lowerCtx, e.Right), Mutable: true, Raw: isRaw} case "*": + if atomic, ok := derefAtomicInnerType(e.Right.Type()); ok { + return &AtomicLoadValue{baseValue: baseValue{Location: e.Loc(), ExprType: atomic}, Pointer: lowerValue(lowerCtx, e.Right)} + } return &LoadValue{baseValue: baseValue{Location: e.Loc(), ExprType: e.Type()}, Pointer: lowerValue(lowerCtx, e.Right)} default: return &UnaryValue{baseValue: baseValue{Location: e.Loc(), ExprType: e.Type()}, Op: e.Op, Right: lowerValue(lowerCtx, e.Right)} @@ -618,6 +624,111 @@ func lowerValue(lowerCtx *lowerContext, expr hir.Expr) Value { } } +func lowerAtomicAddInstr(lowerCtx *lowerContext, stmt *hir.AssignStmt) Instr { + if stmt == nil { + return nil + } + atomic, ok := atomicAssignTargetType(stmt.Left) + if !ok || atomic == nil { + return nil + } + if !typeinfo.SupportsAtomicAdd(atomic.Inner) { + return nil + } + bin, ok := stmt.Right.(*hir.BinaryExpr) + if !ok || bin == nil { + return nil + } + var delta Value + switch bin.Op { + case "+": + if sameHIRExpr(stmt.Left, bin.Left) { + delta = lowerValue(lowerCtx, bin.Right) + } else if sameHIRExpr(stmt.Left, bin.Right) { + delta = lowerValue(lowerCtx, bin.Left) + } + } + if delta == nil { + return nil + } + return &AtomicAddInstr{ + baseInstr: baseInstr{Location: stmt.Loc()}, + Pointer: lowerAtomicPointer(lowerCtx, stmt.Left), + Delta: delta, + Type: atomic.Inner, + } +} + +func lowerAtomicPointer(lowerCtx *lowerContext, expr hir.Expr) Value { + if pref, ok := expr.(*hir.PrefixExpr); ok && pref != nil && pref.Op == "*" { + return lowerValue(lowerCtx, pref.Right) + } + return &AddrOfValue{ + baseValue: baseValue{Location: expr.Loc(), ExprType: &typeinfo.RefType{Inner: expr.Type()}}, + Source: lowerAddrSource(lowerCtx, expr), + } +} + +func atomicAssignTargetType(expr hir.Expr) (*typeinfo.AtomicType, bool) { + if expr == nil { + return nil, false + } + if atomic, ok := expr.Type().(*typeinfo.AtomicType); ok && atomic != nil { + return atomic, true + } + if pref, ok := expr.(*hir.PrefixExpr); ok && pref != nil && pref.Op == "*" { + switch t := pref.Right.Type().(type) { + case *typeinfo.RefType: + atomic, ok := t.Inner.(*typeinfo.AtomicType) + return atomic, ok && atomic != nil + case *typeinfo.PointerType: + atomic, ok := t.Inner.(*typeinfo.AtomicType) + return atomic, ok && atomic != nil + case *typeinfo.RawPtrType: + atomic, ok := t.Inner.(*typeinfo.AtomicType) + return atomic, ok && atomic != nil + } + } + return nil, false +} + +func derefAtomicInnerType(typ typeinfo.Type) (typeinfo.Type, bool) { + switch t := typ.(type) { + case *typeinfo.RefType: + if atomic, ok := t.Inner.(*typeinfo.AtomicType); ok && atomic != nil { + return atomic.Inner, true + } + case *typeinfo.PointerType: + if atomic, ok := t.Inner.(*typeinfo.AtomicType); ok && atomic != nil { + return atomic.Inner, true + } + case *typeinfo.RawPtrType: + if atomic, ok := t.Inner.(*typeinfo.AtomicType); ok && atomic != nil { + return atomic.Inner, true + } + } + return nil, false +} + +func sameHIRExpr(left, right hir.Expr) bool { + if left == nil || right == nil { + return false + } + if left.SourceExpr() != nil && left.SourceExpr() == right.SourceExpr() { + return true + } + switch l := left.(type) { + case *hir.PrefixExpr: + r, ok := right.(*hir.PrefixExpr) + return ok && l.Op == r.Op && sameHIRExpr(l.Right, r.Right) + case *hir.Ident: + r, ok := right.(*hir.Ident) + return ok && l.LocalID == r.LocalID && strings.Join(l.Path, "::") == strings.Join(r.Path, "::") + default: + return false + } +} + func lowerCallArgs(lowerCtx *lowerContext, loc source.Location, args []hir.Expr, fnType *typeinfo.FuncType, stringifyAnyArgs bool) []Value { unwrapSpread := func(arg hir.Expr) (hir.Expr, bool) { if pref, ok := arg.(*hir.PrefixExpr); ok && pref != nil && pref.Op == "..." { @@ -1601,6 +1712,12 @@ func lowerInterfaceCoercion(lowerCtx *lowerContext, source, target typeinfo.Type if source == nil || typeinfo.Equal(source, target) || lowerIsInterfaceType(source) { return nil, nil, false } + if len(targetIface) == 0 { + if atomic, ok := source.(*typeinfo.AtomicType); ok && atomic != nil { + return nil, atomic.Inner, true + } + return nil, source, true + } if lowerCtx == nil || lowerCtx.lookupMethod == nil { return nil, nil, false } diff --git a/internal/ir/mir/model.go b/internal/ir/mir/model.go index 86a69fdb..276cfbd6 100644 --- a/internal/ir/mir/model.go +++ b/internal/ir/mir/model.go @@ -232,6 +232,13 @@ type LoadValue struct { func (*LoadValue) valueNode() {} +type AtomicLoadValue struct { + baseValue + Pointer Value +} + +func (*AtomicLoadValue) valueNode() {} + type BinaryValue struct { baseValue Left Value @@ -395,6 +402,15 @@ type StoreInstr struct { func (*StoreInstr) instrNode() {} +type AtomicAddInstr struct { + baseInstr + Pointer Value + Delta Value + Type typeinfo.Type +} + +func (*AtomicAddInstr) instrNode() {} + type StoreFieldInstr struct { baseInstr Base Value diff --git a/internal/ir/mir/validate.go b/internal/ir/mir/validate.go index d9fd15de..183b623c 100644 --- a/internal/ir/mir/validate.go +++ b/internal/ir/mir/validate.go @@ -92,6 +92,9 @@ func validateInstrValueShape(bag *diagnostics.DiagnosticBag, instr Instr) bool { return requireNormalizedCompute(bag, i.Loc(), i.Value) case *StoreInstr: return requireNormalizedAssignable(bag, i.Loc(), i.Value, "store") + case *AtomicAddInstr: + ok := requireSimpleValue(bag, i.Loc(), i.Pointer, "atomic_add pointer") + return requireSimpleValue(bag, i.Loc(), i.Delta, "atomic_add delta") && ok case *StoreFieldInstr: ok := requireSimpleValue(bag, i.Loc(), i.Base, "store_field base") return requireSimpleValue(bag, i.Loc(), i.Value, "store_field") && ok @@ -177,6 +180,8 @@ func childrenAreSimple(value Value) bool { return isSimpleValue(v.Source) || childrenAreSimple(v.Source) case *LoadValue: return isSimpleValue(v.Pointer) + case *AtomicLoadValue: + return isSimpleValue(v.Pointer) case *BinaryValue: return isSimpleValue(v.Left) && isSimpleValue(v.Right) case *PostfixValue: diff --git a/internal/ir/mir/walk.go b/internal/ir/mir/walk.go index 94b7584c..9fca5e84 100644 --- a/internal/ir/mir/walk.go +++ b/internal/ir/mir/walk.go @@ -56,6 +56,11 @@ func WalkInstrValues(instr Instr, visit func(Value) error) error { return err } return WalkValue(i.Value, visit) + case *AtomicAddInstr: + if err := WalkValue(i.Pointer, visit); err != nil { + return err + } + return WalkValue(i.Delta, visit) case *StoreFieldInstr: if err := WalkValue(i.Base, visit); err != nil { return err @@ -133,6 +138,8 @@ func WalkValue(value Value, visit func(Value) error) error { return WalkValue(v.Source, visit) case *LoadValue: return WalkValue(v.Pointer, visit) + case *AtomicLoadValue: + return WalkValue(v.Pointer, visit) case *BinaryValue: if err := WalkValue(v.Left, visit); err != nil { return err diff --git a/internal/tokens/token.go b/internal/tokens/token.go index 7d40deb7..ef3923ef 100644 --- a/internal/tokens/token.go +++ b/internal/tokens/token.go @@ -95,6 +95,7 @@ const ( AS Kind = "AS" IS Kind = "IS" MUT Kind = "MUT" + ATOMIC Kind = "ATOMIC" COMPTIME Kind = "COMPTIME" LOCK Kind = "LOCK" DEFER Kind = "DEFER" @@ -128,6 +129,7 @@ var keywords = map[string]Kind{ "as": AS, "is": IS, "mut": MUT, + "atomic": ATOMIC, "comptime": COMPTIME, "lock": LOCK, "defer": DEFER, @@ -161,6 +163,7 @@ var keywordDocs = map[Kind]string{ AS: "Cast an expression to a target type.", IS: "Check whether a value conforms to a target type.", MUT: "Mark a binding or reference as mutable.", + ATOMIC: "Declare or name atomic storage.", COMPTIME: "Force compile-time evaluation.", LOCK: "Acquire a lock guard for the block scope.", DEFER: "Run a statement when the current scope exits.", diff --git a/runtime/ferret_runtime.h b/runtime/ferret_runtime.h index 40bfec15..737e7dc3 100644 --- a/runtime/ferret_runtime.h +++ b/runtime/ferret_runtime.h @@ -207,6 +207,15 @@ FerretStr ferret_global_recover(void); void ferret_global_print(const FerretSliceAny *values); +/* ------------------------------------------------------------------------- + * std/task surface. + * -------------------------------------------------------------------------*/ + +typedef void (*FerretTaskEntryRaw)(ferret_raw); + +ferret_raw ferret_task_run_raw(FerretTaskEntryRaw entry, ferret_raw arg); +void ferret_task_wait(ferret_raw handle); + /* ------------------------------------------------------------------------- * std/io surface. * -------------------------------------------------------------------------*/ diff --git a/runtime/ferret_runtime_task.c b/runtime/ferret_runtime_task.c new file mode 100644 index 00000000..9b75a242 --- /dev/null +++ b/runtime/ferret_runtime_task.c @@ -0,0 +1,75 @@ +#include "ferret_runtime_internal.h" + +#if defined(_WIN32) +ferret_raw ferret_task_run_raw(FerretTaskEntryRaw entry, ferret_raw arg) { + (void)entry; + (void)arg; + return NULL; +} + +void ferret_task_wait(ferret_raw raw_handle) { + (void)raw_handle; +} +#else +#include +#include + +typedef struct { + pthread_t thread; +} FerretTaskHandle; + +typedef struct { + FerretTaskEntryRaw entry; + ferret_raw arg; +} FerretTaskStart; + +static void *ferret_task_trampoline(void *raw) { + FerretTaskStart *start = (FerretTaskStart *)raw; + FerretTaskEntryRaw entry; + ferret_raw arg; + + if (start == NULL) { + return NULL; + } + entry = start->entry; + arg = start->arg; + free(start); + if (entry != NULL) { + entry(arg); + } + return NULL; +} + +ferret_raw ferret_task_run_raw(FerretTaskEntryRaw entry, ferret_raw arg) { + FerretTaskHandle *handle; + FerretTaskStart *start; + + if (entry == NULL) { + return NULL; + } + handle = (FerretTaskHandle *)malloc(sizeof(FerretTaskHandle)); + start = (FerretTaskStart *)malloc(sizeof(FerretTaskStart)); + if (handle == NULL || start == NULL) { + free(handle); + free(start); + return NULL; + } + start->entry = entry; + start->arg = arg; + if (pthread_create(&handle->thread, NULL, ferret_task_trampoline, start) != 0) { + free(start); + free(handle); + return NULL; + } + return (ferret_raw)handle; +} + +void ferret_task_wait(ferret_raw raw_handle) { + FerretTaskHandle *handle = (FerretTaskHandle *)raw_handle; + if (handle == NULL) { + return; + } + (void)pthread_join(handle->thread, NULL); + free(handle); +} +#endif diff --git a/tests/repro/task_atomic_i32.fer b/tests/repro/task_atomic_i32.fer new file mode 100644 index 00000000..becdcafd --- /dev/null +++ b/tests/repro/task_atomic_i32.fer @@ -0,0 +1,23 @@ +import "std/task" + +fn worker(value: &atomic i32) -> void { + let mut i = 0 + while i < 200000 { + (*value)++ + i += 1 + } +} + +fn main() -> void { + let atomic counter = 0 + let a = task::Run(worker, &counter) + let b = task::Run(worker, &counter) + let c = task::Run(worker, &counter) + let d = task::Run(worker, &counter) + + task::WaitAll(a, b, c, d) + //b.Wait() + println(counter) + + //let atomic val: struct{ x: i32 } = .{ .x = 2 } +} diff --git a/tests/repro/task_race_raw.fer b/tests/repro/task_race_raw.fer new file mode 100644 index 00000000..a616f2b5 --- /dev/null +++ b/tests/repro/task_race_raw.fer @@ -0,0 +1,21 @@ +import "std/task" + +fn worker(value: &mut i32) -> void { + let mut counter = 0 + while counter < 200000 { + (*value)++ + counter += 1 + } +} + +fn main() -> void { + let mut res = 0 + let t1 = task::Run(worker, &mut res) + let t2 = task::Run(worker, &mut res) + let t3 = task::Run(worker, &mut res) + let t4 = task::Run(worker, &mut res) + + task::WaitAll(t1, t2, t3, t4) + //t1.Wait() + println(res) +} diff --git a/tests/repro/variadic.fer b/tests/repro/variadic.fer new file mode 100644 index 00000000..1f1346eb --- /dev/null +++ b/tests/repro/variadic.fer @@ -0,0 +1,8 @@ + +import "std/mem" + +fn main() { + let allocator = mem::CAllocator{} + + let v1 = allocator.AllocZeroed(4) +} \ No newline at end of file