From 6f496ab1715dc4ef5ee67f8cda8cbadec230b838 Mon Sep 17 00:00:00 2001 From: Seth Hoenig Date: Sat, 28 Mar 2026 13:22:22 -0500 Subject: [PATCH] verbs: implement cas functionality via gets and cas commands --- e2e_test.go | 62 +++++++++++++++++ iopool/pool.go | 8 +-- verbs.go | 176 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 241 insertions(+), 5 deletions(-) diff --git a/e2e_test.go b/e2e_test.go index 4dea157..54eb063 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -507,3 +507,65 @@ func TestE2E_StatsItems(t *testing.T) { must.Positive(t, data[0].Number) must.Positive(t, data[0].MemRequested) } + +func TestE2E_CAS(t *testing.T) { + t.Parallel() + + address, done := memctest.LaunchTCP(t, nil) + t.Cleanup(done) + + c := New([]string{address}) + defer ignore.Close(c) + + t.Run("success", func(t *testing.T) { + err := Set(c, "key1", "value1") + must.NoError(t, err) + + v, cas, verr := Gets[string](c, "key1") + must.NoError(t, verr) + must.Eq(t, "value1", v) + must.Positive(t, uint64(cas)) + + err = CompareAndSwap(c, "key1", cas, "value1.updated") + must.NoError(t, err) + + v, err = Get[string](c, "key1") + must.NoError(t, err) + must.Eq(t, "value1.updated", v) + }) + + t.Run("conflict", func(t *testing.T) { + err := Set(c, "key2", "original") + must.NoError(t, err) + + _, cas1, verr := Gets[string](c, "key2") + must.NoError(t, verr) + + _, _, verr = Gets[string](c, "key2") + must.NoError(t, verr) + + err = CompareAndSwap(c, "key2", cas1, "first-update") + must.NoError(t, err) + + err = CompareAndSwap(c, "key2", cas1, "stale-update") + must.ErrorIs(t, err, ErrConflict) + + v, err := Get[string](c, "key2") + must.NoError(t, err) + must.Eq(t, "first-update", v) + }) + + t.Run("not found", func(t *testing.T) { + err := Set(c, "key3", "value3") + must.NoError(t, err) + + _, cas, verr := Gets[string](c, "key3") + must.NoError(t, verr) + + err = Delete(c, "key3") + must.NoError(t, err) + + err = CompareAndSwap(c, "key3", cas, "newvalue") + must.ErrorIs(t, err, ErrNotFound) + }) +} diff --git a/iopool/pool.go b/iopool/pool.go index c7e7e9f..4465dc3 100644 --- a/iopool/pool.go +++ b/iopool/pool.go @@ -44,10 +44,10 @@ type Buffer struct { func newBuffer(conn Connection) *Buffer { return &Buffer{ - bufio.NewReader(conn), - bufio.NewWriter(conn), - conn, - new(atomic.Bool), + Reader: bufio.NewReader(conn), + Writer: bufio.NewWriter(conn), + Closer: conn, + failure: new(atomic.Bool), } } diff --git a/verbs.go b/verbs.go index a61c1b2..07e4ffe 100644 --- a/verbs.go +++ b/verbs.go @@ -28,6 +28,10 @@ var ( ErrCommandIssue = errors.New("memc: got command error response") ) +// CAS represents a Compare-And-Swap token used for optimistic locking. +// It is returned by Gets and must be provided to CompareAndSwap to atomically update a value. +type CAS uint64 + // Options contains configuration parameters that may be applied when executing // a verb like Get, Set, etc. type Options struct { @@ -437,6 +441,88 @@ func Add[T any](c *Client, key string, item T, opts ...Option) error { }) } +// CompareAndSwap will store the item using the given key, but only if the CAS +// token matches the current value's CAS token. This provides atomic +// compare-and-swap functionality for optimistic locking. +// +// If the key does not exist, ErrNotFound is returned. +// +// If the CAS token does not match (meaning the value was modified since it was +// retrieved with Gets), ErrConflict is returned. +// +// Uses Client c to connect to a memcached instance, and automatically handles +// connection pooling and reuse. +// +// One or more Option(s) may be applied to configure things such as the value +// expiration TTL or its associated flags. +func CompareAndSwap[T any](c *Client, key string, cas CAS, item T, opts ...Option) error { + if err := check(key); err != nil { + return err + } + + options := &Options{ + expiration: c.expiration, + flags: 0, + } + + for _, opt := range opts { + opt(options) + } + + return c.do(key, func(conn *iopool.Buffer) error { + encoding, encerr := encode(item) + if encerr != nil { + return encerr + } + + expiration, experr := c.seconds(options.expiration) + if experr != nil { + return experr + } + + // write the header components with CAS token + if _, err := fmt.Fprintf( + conn, + "cas %s %d %d %d %d\r\n", + key, options.flags, expiration, len(encoding), cas, + ); err != nil { + return err + } + + // write the payload + if _, err := conn.Write(encoding); err != nil { + return err + } + + // write clrf + if _, err := io.WriteString(conn, "\r\n"); err != nil { + return err + } + + // flush the buffer + if err := conn.Flush(); err != nil { + return err + } + + // read response + line, lerr := conn.ReadSlice('\n') + if lerr != nil { + return lerr + } + + switch string(line) { + case "STORED\r\n": + return nil + case "NOT_FOUND\r\n": + return ErrNotFound + case "EXISTS\r\n": + return ErrConflict + default: + return fmt.Errorf("memc: unexpected response to cas: %q", string(line)) + } + }) +} + // Get the value associated with the given key. // // Uses Client c to connect to a memcached instance, and automatically handles @@ -472,6 +558,51 @@ func Get[T any](c *Client, key string) (T, error) { return result, err } +// Gets the value associated with the given key, along with its CAS token. +// +// The CAS token can be used with CompareAndSwap to atomically update the value, +// providing optimistic locking. If the value has been modified since it was +// retrieved, CompareAndSwap will return an ErrConflict error. +// +// Uses Client c to connect to a memcached instance, and automatically handles +// connection pooling and reuse. +func Gets[T any](c *Client, key string) (T, CAS, error) { + var result T + var casToken CAS + + if err := check(key); err != nil { + return result, 0, err + } + + err := c.do(key, func(conn *iopool.Buffer) error { + // write the header components + if _, err := fmt.Fprintf(conn, "gets %s\r\n", key); err != nil { + return err + } + + // flush the connection, forcing bytes over the wire + if err := conn.Flush(); err != nil { + return err + } + + // read the response payload with CAS token + payload, cas, err := getPayloadWithCAS(conn.Reader) + if err != nil { + return err + } + + result, err = decode[T](payload) + if err != nil { + return err + } + + casToken = CAS(cas) + return nil + }) + + return result, casToken, err +} + func getPayload(r *bufio.Reader) ([]byte, error) { b, err := r.ReadSlice('\n') if err != nil { @@ -483,7 +614,6 @@ func getPayload(r *bufio.Reader) ([]byte, error) { return nil, ErrCacheMiss } - // TODO: does not handle CAS value for now expect := "VALUE %s %d %d\r\n" var ( key string @@ -515,6 +645,50 @@ func getPayload(r *bufio.Reader) ([]byte, error) { return payload, err } +func getPayloadWithCAS(r *bufio.Reader) ([]byte, uint64, error) { + b, err := r.ReadSlice('\n') + if err != nil { + return nil, 0, err + } + + // key was not found, is a cache miss + if string(b) == "END\r\n" { + return nil, 0, ErrCacheMiss + } + + // handle CAS value - format is "VALUE key flags bytes cas\r\n" + expect := "VALUE %s %d %d %d\r\n" + var ( + key string + flags int + size int + cas uint64 + ) + + // scan the header line, giving us a payload size and CAS token + if _, err = fmt.Sscanf(string(b), expect, &key, &flags, &size, &cas); err != nil { + return nil, 0, err + } + + // read the data into our payload + payload := make([]byte, size+2) // including \r\n + if _, err = io.ReadFull(r, payload); err != nil { + return nil, 0, err + } + payload = payload[0:size] // chop \r\n + + // read the trailing line ("END\r\n") + b, err = r.ReadSlice('\n') + if err != nil { + return nil, 0, err + } + if string(b) != "END\r\n" { + return nil, 0, unexpected(b) + } + + return payload, cas, nil +} + // Delete will remove the value associated with key from memcached. // // Uses Client c to connect to a memcached instance, and automatically handles