diff --git a/config/config.go b/config/config.go index f56b3201..6031636a 100644 --- a/config/config.go +++ b/config/config.go @@ -72,8 +72,8 @@ type ( // any request with body (each call to request's body will result in status.ErrBodyTooLarge). // In order to disable the setting, use the math.MaxUInt64 value. MaxSize uint64 - //// DecodingBufferSize is a size of a buffer, used to store decoded request's body - //DecodingBufferSize int + // Form is either application/x-www-form-urlencoded or multipart/form-data. Due to their common + // nature, they are easy to be generalized. Form BodyForm } @@ -96,6 +96,10 @@ type ( // 2) If a stream is unsized (1) and the previous write used more than ~98.44% of its // capacity (2), the capacity doubles. WriteBufferSize NETWriteBufferSize + // SmallBody limits how big must a response body be in order to be compressed, if the + // auto compression option is enabled. This setting doesn't affect enforced compression + // options and unsized streams. + SmallBody int64 } ) @@ -158,6 +162,7 @@ func Default() *Config { Default: 2 * 1024, Maximal: 64 * 1024, }, + SmallBody: 4 * 1024, }, } } diff --git a/examples/compression/compression.go b/examples/compression/compression.go index dbef728d..cc62e7b6 100644 --- a/examples/compression/compression.go +++ b/examples/compression/compression.go @@ -23,7 +23,7 @@ func Shout(r *http.Request) *http.Response { func main() { app := indigo.New(":8080"). - Codec(codec.NewGZIP()). + Codec(codec.Suit()...). OnBind(func(addr string) { fmt.Println("Listening on", addr) }) diff --git a/examples/demo/demo.go b/examples/demo/demo.go index b876ea73..aa2ba005 100644 --- a/examples/demo/demo.go +++ b/examples/demo/demo.go @@ -62,7 +62,7 @@ func main() { app := indigo.New(":8080"). TLS(":8443", indigo.LocalCert()). Tune(s). - Codec(codec.NewGZIP()). + Codec(codec.Suit()...). OnBind(func(addr string) { log.Printf("running on %s\n", addr) }) @@ -71,10 +71,14 @@ func main() { Use(middleware.LogRequests()). 
Alias("/", "/static/index.html", method.GET). Alias("/favicon.ico", "/static/favicon.ico", method.GET). - Static("/static", "examples/demo/static") + Static("/static", "examples/demo/static", middleware.Autocompress) r.Get("/stress", Stressful, middleware.Recover) + r.Post("/echo", func(request *http.Request) *http.Response { + return http.Stream(request, request.Body) + }) + r.Resource("/"). Post(IndexSay) diff --git a/http/body.go b/http/body.go index edd0f2e8..39186dee 100644 --- a/http/body.go +++ b/http/body.go @@ -32,8 +32,6 @@ type Body struct { form form.Form } -// TODO: body entity can be passed by value (?) - func NewBody(src Fetcher) *Body { return &Body{ Fetcher: src, @@ -79,7 +77,7 @@ func (b *Body) Bytes() ([]byte, error) { } newSize := int(b.request.cfg.Body.Form.BufferPrealloc) - if !b.request.Encoding.Chunked { + if !b.request.Chunked { newSize = min(b.request.ContentLength, int(b.request.cfg.Body.MaxSize)) } @@ -178,6 +176,14 @@ func (b *Body) Form() (f form.Form, err error) { } } +func (b *Body) Len() int { + if b.request.Chunked { + return -1 + } + + return b.request.ContentLength +} + // Discard sinkholes the rest of the body. Should not be used unless you know what you're doing. 
func (b *Body) Discard() error { for b.error == nil { diff --git a/http/codec/adapter.go b/http/codec/adapter.go deleted file mode 100644 index 15386a56..00000000 --- a/http/codec/adapter.go +++ /dev/null @@ -1,37 +0,0 @@ -package codec - -import ( - "github.com/indigo-web/indigo/http" -) - -type readerAdapter struct { - fetcher http.Fetcher - err error - data []byte -} - -func newAdapter() *readerAdapter { - return new(readerAdapter) -} - -func (r *readerAdapter) Read(b []byte) (n int, err error) { - if len(r.data) == 0 { - if r.err != nil { - return 0, r.err - } - - r.data, r.err = r.fetcher.Fetch() - } - - n = copy(b, r.data) - r.data = r.data[n:] - if len(r.data) == 0 { - err = r.err - } - - return n, err -} - -func (r *readerAdapter) Reset(fetcher http.Fetcher) { - *r = readerAdapter{fetcher: fetcher} -} diff --git a/http/codec/base.go b/http/codec/base.go new file mode 100644 index 00000000..af3c7753 --- /dev/null +++ b/http/codec/base.go @@ -0,0 +1,146 @@ +package codec + +import ( + "io" + + "github.com/indigo-web/indigo/http" +) + +var _ Codec = baseCodec{} + +type instantiator = func() Instance + +type baseCodec struct { + token string + newInst instantiator +} + +func newBaseCodec(token string, newInst instantiator) baseCodec { + return baseCodec{ + token: token, + newInst: newInst, + } +} + +func (b baseCodec) Token() string { + return b.token +} + +func (b baseCodec) New() Instance { + return b.newInst() +} + +var _ Instance = new(baseInstance) + +type ( + decoderResetter = func(io.Reader, *readerAdapter) error + + writeResetter interface { + io.WriteCloser + Reset(dst io.Writer) + } +) + +type baseInstance struct { + reset decoderResetter + adapter *readerAdapter + w writeResetter // compressor + r io.Reader // decompressor + dst io.Closer + buff []byte +} + +func newBaseInstance(encoder writeResetter, decoder io.Reader, reset decoderResetter) instantiator { + return func() Instance { + return &baseInstance{ + reset: reset, + adapter: newAdapter(), + 
w: encoder, + r: decoder, + } + } +} + +func (b *baseInstance) ResetCompressor(w io.Writer) { + b.w.Reset(w) + b.dst = nil + + if c, ok := w.(io.Closer); ok { + b.dst = c + } +} + +func (b *baseInstance) Write(p []byte) (n int, err error) { + return b.w.Write(p) +} + +func (b *baseInstance) Close() error { + if err := b.w.Close(); err != nil { + return err + } + + if b.dst != nil { + return b.dst.Close() + } + + return nil +} + +func (b *baseInstance) ResetDecompressor(source http.Fetcher, bufferSize int) error { + if cap(b.buff) < bufferSize { + b.buff = make([]byte, bufferSize) + } + + b.adapter.Reset(source) + + return b.reset(b.r, b.adapter) +} + +func (b *baseInstance) Fetch() ([]byte, error) { + n, err := b.r.Read(b.buff) + return b.buff[:n], err +} + +func genericResetter(r io.Reader, adapter *readerAdapter) error { + type resetter interface { + Reset(r io.Reader) error + } + + if reset, ok := r.(resetter); ok { + return reset.Reset(adapter) + } + + return nil +} + +type readerAdapter struct { + fetcher http.Fetcher + err error + data []byte +} + +func newAdapter() *readerAdapter { + return new(readerAdapter) +} + +func (r *readerAdapter) Read(b []byte) (n int, err error) { + if len(r.data) == 0 { + if r.err != nil { + return 0, r.err + } + + r.data, r.err = r.fetcher.Fetch() + } + + n = copy(b, r.data) + r.data = r.data[n:] + if len(r.data) == 0 { + err = r.err + } + + return n, err +} + +func (r *readerAdapter) Reset(fetcher http.Fetcher) { + *r = readerAdapter{fetcher: fetcher} +} diff --git a/http/codec/codec.go b/http/codec/codec.go index 7cf36847..95ca5b27 100644 --- a/http/codec/codec.go +++ b/http/codec/codec.go @@ -23,6 +23,14 @@ type Compressor interface { } type Decompressor interface { - ResetDecompressor(source http.Fetcher) error http.Fetcher + ResetDecompressor(source http.Fetcher, bufferSize int) error +} + +// Suit is a collection of out-of-the-box supported codecs. 
It contains: +// - gzip +// - deflate +// - zstd +func Suit() []Codec { + return []Codec{NewGZIP(), NewDeflate(), NewZSTD()} } diff --git a/http/codec/deflate.go b/http/codec/deflate.go index 1fae45a0..d2c54402 100644 --- a/http/codec/deflate.go +++ b/http/codec/deflate.go @@ -1 +1,21 @@ package codec + +import ( + "io" + + "github.com/klauspost/compress/flate" +) + +func NewDeflate() Codec { + writer, err := flate.NewWriter(nil, 5) + if err != nil { + panic(err) + } + + reader := flate.NewReader(nil) + instantiator := newBaseInstance(writer, reader, func(r io.Reader, a *readerAdapter) error { + return r.(flate.Resetter).Reset(a, nil) + }) + + return newBaseCodec("deflate", instantiator) +} diff --git a/http/codec/deflate_test.go b/http/codec/deflate_test.go new file mode 100644 index 00000000..0b6eec51 --- /dev/null +++ b/http/codec/deflate_test.go @@ -0,0 +1,9 @@ +package codec + +import ( + "testing" +) + +func TestFlate(t *testing.T) { + testCodec(t, NewDeflate().New()) +} diff --git a/http/codec/gzip.go b/http/codec/gzip.go index bd440ede..a1a2910e 100644 --- a/http/codec/gzip.go +++ b/http/codec/gzip.go @@ -1,82 +1,13 @@ package codec import ( - "io" - - "github.com/indigo-web/indigo/http" "github.com/klauspost/compress/gzip" ) -// TODO: pass this via parameters? 
-const decompressorBufferSize = 4096 - -var _ Codec = new(GZIP) - -type GZIP struct{} - -func NewGZIP() GZIP { - return GZIP{} -} - -func (GZIP) Token() string { - return "gzip" -} - -func (g GZIP) New() Instance { - return newGZIPCodec(make([]byte, decompressorBufferSize)) -} - -var _ Instance = new(gzipCodec) - -type gzipCodec struct { - adapter *readerAdapter - w *gzip.Writer // compressor - r gzip.Reader // decompressor - wout io.Closer - buff []byte -} - -func newGZIPCodec(buff []byte) *gzipCodec { - return &gzipCodec{ - adapter: newAdapter(), - w: gzip.NewWriter(nil), - buff: buff, - } -} - -func (g *gzipCodec) ResetCompressor(w io.Writer) { - g.w.Reset(w) - - if c, ok := w.(io.Closer); ok { - g.wout = c - } -} - -func (g *gzipCodec) Write(p []byte) (n int, err error) { - // TODO: the compressor spams with Write() calls. This will cause significant performance downgrade, - // TODO: as each individual Write() call results in transferring the passed data over the network. - // TODO: Buffer this somewhere to at least 4096 (by default). Make the behaviour disable-able. 
- return g.w.Write(p) -} - -func (g *gzipCodec) Close() error { - if err := g.w.Close(); err != nil { - return err - } - - if g.wout != nil { - return g.wout.Close() - } - - return nil -} - -func (g *gzipCodec) ResetDecompressor(source http.Fetcher) error { - g.adapter.Reset(source) - return g.r.Reset(g.adapter) -} +func NewGZIP() Codec { + writer := gzip.NewWriter(nil) + reader := new(gzip.Reader) + instantiator := newBaseInstance(writer, reader, genericResetter) -func (g *gzipCodec) Fetch() ([]byte, error) { - n, err := g.r.Read(g.buff) - return g.buff[:n], err + return newBaseCodec("gzip", instantiator) } diff --git a/http/codec/gzip_test.go b/http/codec/gzip_test.go index cd017b2f..ba13b285 100644 --- a/http/codec/gzip_test.go +++ b/http/codec/gzip_test.go @@ -10,46 +10,27 @@ import ( "github.com/stretchr/testify/require" ) -func TestGZIP(t *testing.T) { - t.Run("default", func(t *testing.T) { - text, err := gunzip(gzipped("Hello, world!")) - require.NoError(t, err) - require.Equal(t, "Hello, world!", text) - }) +func compress(inst Instance, text string) []byte { + loopback := dummy.NewMockClient().Journaling() + inst.ResetCompressor(loopback) - t.Run("scattered", func(t *testing.T) { - text := strings.Repeat("Hello, world! Lorem ipsum! ", 100) - scattered := scatter(gzipped(text), 2) - result, err := gunzip(scattered...) 
- require.NoError(t, err) - require.Equal(t, text, result) - }) -} - -func gzipped(text string) []byte { - c := NewGZIP().New() - sinkhole := dummy.NewMockClient().Journaling() - c.ResetCompressor(sinkhole) - - if _, err := c.Write([]byte(text)); err != nil { + if _, err := inst.Write([]byte(text)); err != nil { panic(err) } - if err := c.Close(); err != nil { + if err := inst.Close(); err != nil { panic(err) } - return sinkhole.Written() + return loopback.Written() } -func gunzip(gzipped ...[]byte) (string, error) { - dc := NewGZIP().New() - err := dc.ResetDecompressor(dummy.NewMockClient(gzipped...)) - if err != nil { +func decompress(inst Instance, data ...[]byte) (string, error) { + if err := inst.ResetDecompressor(dummy.NewMockClient(data...), 512); err != nil { return "", err } - return fetchAll(dc) + return fetchAll(inst) } func fetchAll(source http.Fetcher) (string, error) { @@ -75,3 +56,23 @@ func scatter(b []byte, step int) (pieces [][]byte) { return pieces } + +func testCodec(t *testing.T, inst Instance) { + t.Run("identity", func(t *testing.T) { + result, err := decompress(inst, compress(inst, "Hello, world!")) + require.NoError(t, err) + require.Equal(t, "Hello, world!", result) + }) + + t.Run("stream", func(t *testing.T) { + text := strings.Repeat("Hello, world! Lorem ipsum! ", 100) + scattered := scatter(compress(inst, text), 2) + result, err := decompress(inst, scattered...) 
+ require.NoError(t, err) + require.Equal(t, text, result) + }) +} + +func TestGZIP(t *testing.T) { + testCodec(t, NewGZIP().New()) +} diff --git a/http/codec/zstd.go b/http/codec/zstd.go new file mode 100644 index 00000000..1b7312f1 --- /dev/null +++ b/http/codec/zstd.go @@ -0,0 +1,21 @@ +package codec + +import ( + "github.com/klauspost/compress/zstd" +) + +func NewZSTD() Codec { + w, err := zstd.NewWriter(nil) + if err != nil { + panic(err) + } + + r, err := zstd.NewReader(nil) + if err != nil { + panic(err) + } + + instantiator := newBaseInstance(w, r, genericResetter) + + return newBaseCodec("zstd", instantiator) +} diff --git a/http/codec/zstd_test.go b/http/codec/zstd_test.go new file mode 100644 index 00000000..24651638 --- /dev/null +++ b/http/codec/zstd_test.go @@ -0,0 +1,7 @@ +package codec + +import "testing" + +func TestZSTD(t *testing.T) { + testCodec(t, NewZSTD().New()) +} diff --git a/http/form/form.go b/http/form/form.go index 8ab62769..3ff30791 100644 --- a/http/form/form.go +++ b/http/form/form.go @@ -16,7 +16,9 @@ func (f Form) Name(name string) iter.Seq[Data] { return func(yield func(Data) bool) { for _, entry := range f { if entry.Name == name { - yield(entry) + if !yield(entry) { + break + } } } } @@ -26,7 +28,9 @@ func (f Form) File(name string) iter.Seq[Data] { return func(yield func(Data) bool) { for _, entry := range f { if entry.Filename == name { - yield(entry) + if !yield(entry) { + break + } } } } diff --git a/http/method/method.go b/http/method/method.go index 2e0fa1c3..61c358f6 100644 --- a/http/method/method.go +++ b/http/method/method.go @@ -32,8 +32,6 @@ var List = []Method{ } func Parse(str string) Method { - // TODO: this doesn't seem to differ much from just an ordinary wall of if's in terms of performance, - // TODO: whose however would effectively reduce the visual complexity of this crap. 
switch len(str) { case 3: if str == "GET" { diff --git a/http/mime/charset.go b/http/mime/charset.go index fd5bac79..09d594fd 100644 --- a/http/mime/charset.go +++ b/http/mime/charset.go @@ -9,5 +9,5 @@ const ( ASCII Charset = "ascii" CP1251 Charset = "cp1251" CP1252 Charset = "cp1252" - // TODO: adding more widespreaded charsets would be nice + // feel free to add more widespread charsets! ) diff --git a/http/mime/common.go b/http/mime/common.go index 363b088e..c75d7919 100644 --- a/http/mime/common.go +++ b/http/mime/common.go @@ -39,6 +39,7 @@ const ( SQL MIME = "application/sql" TZIF MIME = "application/tzif" XFDF MIME = "application/xfdf" + HTTP MIME = "message/http" ) // Complies returns whether two MIMEs are compatible. Empty MIME is considered diff --git a/http/request.go b/http/request.go index 46c44775..f0b77239 100644 --- a/http/request.go +++ b/http/request.go @@ -8,6 +8,7 @@ import ( "github.com/indigo-web/indigo/http/cookie" "github.com/indigo-web/indigo/http/method" "github.com/indigo-web/indigo/http/proto" + "github.com/indigo-web/indigo/internal/strutil" "github.com/indigo-web/indigo/kv" "github.com/indigo-web/indigo/transport" ) @@ -84,9 +85,8 @@ func (r *Request) Cookies() (cookie.Jar, error) { r.jar.Clear() - // in RFC 6265, 5.4 cookies are explicitly prohibited from being split into - // list, yet in HTTP/2 it's allowed. I have concerns of some user-agents may - // despite sending them as a list, even via HTTP/1.1 + // even though RFC 6265, 5.4 prohibits the Cookie header from being split into a list, + // some user-agents still might do so in order to fit each value into the 8K limit. for value := range r.Headers.Values("cookie") { if err := cookie.Parse(r.jar, value); err != nil { return nil, err @@ -98,15 +98,14 @@ func (r *Request) Cookies() (cookie.Jar, error) { // Respond returns Response object. // -// WARNING: this method clears the response builder under the hood. 
As it is passed -// by reference, it'll be cleared EVERYWHERE along a handler +// WARNING: the Response is cleared before being returned. Considering it is stored by pointer, +// this action might affect your application if it was stored anywhere before. func (r *Request) Respond() *Response { return r.response.Clear() } -// Hijack the connection. Request body will be implicitly read (so if you need it you -// should read it before) to the end. After handler exits, the connection will -// be closed, so the connection can be hijacked at most once +// Hijack hijacks an underlying connection. The request body is implicitly discarded before +// exposing the transport. After the handler function terminates, the connection is closed automatically. func (r *Request) Hijack() (transport.Client, error) { if err := r.Body.Discard(); err != nil { return nil, err @@ -117,7 +116,7 @@ func (r *Request) Hijack() (transport.Client, error) { return r.client, nil } -// Hijacked tells whether the connection was hijacked or not +// Hijacked tells whether the connection was hijacked. func (r *Request) Hijacked() bool { return r.hijacked } @@ -147,10 +146,17 @@ type Environment struct { } type commonHeaders struct { - // Encoding holds an information about encoding, that was used to make the request - Encoding Encodings - // ContentLength holds the Content-Length header value. It isn't recommended to rely solely on this - // value, as it can be whatever (but most likely zero) if Request.Encoding.Chunked is true. + // AcceptEncoding is the list of accepted by client tokens. + AcceptEncoding []string + // TransferEncoding is the list of tokens used for this request from `Transfer-Encoding` header value. + TransferEncoding []string + // Chunked describes whether the Transfer attribute is not empty and ends with the `chunked` + // encoding. + Chunked bool + // ContentEncoding is the list of tokens used for this request from `Content-Encoding` header value. 
+ ContentEncoding []string + // ContentLength holds the Content-Length header value. It can be non-zero even if the request uses + // chunked transfer encoding. In that case, it serves more of a hint to approximate the body size. ContentLength int // ContentType holds the Content-Type header value. ContentType string @@ -162,14 +168,23 @@ type commonHeaders struct { Upgrade proto.Protocol } -type Encodings struct { - // Accept is the list of accepted by client tokens. - Accept []string - // Transfer is the list of tokens used for this request from `Transfer-Encoding` header value. - Transfer []string - // Content is the list of tokens used for this request from `Content-Encoding` header value. - Content []string - // Chunked describes whether the Transfer attribute is not empty and ends with the `chunked` - // encoding. - Chunked bool +// PreferredEncoding chooses a preferred encoding from AcceptEncoding, respecting quality markers. +func (c *commonHeaders) PreferredEncoding() string { + if len(c.AcceptEncoding) == 0 { + return "identity" + } + + var prefer string + maxQ := -1 + + for _, str := range c.AcceptEncoding { + token, qualifier := strutil.CutHeader(str) + q := strutil.ParseQualifier(qualifier) + if q > maxQ { + maxQ = q + prefer = token + } + } + + return prefer } diff --git a/http/request_test.go b/http/request_test.go index 158adcdb..a97abb9f 100644 --- a/http/request_test.go +++ b/http/request_test.go @@ -14,41 +14,60 @@ func getRequest() *Request { return NewRequest(config.Default(), nil, dummy.NewNopClient(), kv.New(), kv.New(), kv.New()) } -func TestCookies(t *testing.T) { - t.Run("no cookies", func(t *testing.T) { - request := getRequest() - jar, err := request.Cookies() - require.NoError(t, err) - require.Zero(t, jar.Len()) - }) - - t.Run("happy path", func(t *testing.T) { - test := func(t *testing.T, request *Request) { +func TestRequest(t *testing.T) { + t.Run("cookies", func(t *testing.T) { + t.Run("none", func(t *testing.T) { + request := 
getRequest() jar, err := request.Cookies() require.NoError(t, err) - require.Equal(t, "world", jar.Value("hello")) - require.Equal(t, "hello", jar.Value("world")) - require.Equal(t, "funny", jar.Value("monke")) - require.Equal(t, 3, jar.Len(), "jar must contain exactly 3 pairs") - } + require.Zero(t, jar.Len()) + }) + + t.Run("happy path", func(t *testing.T) { + test := func(t *testing.T, request *Request) { + jar, err := request.Cookies() + require.NoError(t, err) + require.Equal(t, "world", jar.Value("hello")) + require.Equal(t, "hello", jar.Value("world")) + require.Equal(t, "funny", jar.Value("monke")) + require.Equal(t, 3, jar.Len(), "jar must contain exactly 3 pairs") + } - request := getRequest() - request.Headers.Add("Cookie", "hello=world; world=hello") - request.Headers.Add("Cookie", "monke=funny") - // repeat the test twice to make sure, that calling it again won't produce - // different result - test(t, request) - test(t, request) + request := getRequest() + request.Headers.Add("Cookie", "hello=world; world=hello") + request.Headers.Add("Cookie", "monke=funny") + // repeat the test twice to make sure, that calling it again won't produce + // different result + test(t, request) + test(t, request) + }) + + t.Run("malformed", func(t *testing.T) { + request := getRequest() + request.Headers.Add("Cookie", "a") + // repeat the test twice to make sure, that calling it again won't produce + // different result + _, err := request.Cookies() + require.EqualError(t, err, cookie.ErrBadCookie.Error()) + _, err = request.Cookies() + require.EqualError(t, err, cookie.ErrBadCookie.Error()) + }) }) - t.Run("malformed", func(t *testing.T) { - request := getRequest() - request.Headers.Add("Cookie", "a") - // repeat the test twice to make sure, that calling it again won't produce - // different result - _, err := request.Cookies() - require.EqualError(t, err, cookie.ErrBadCookie.Error()) - _, err = request.Cookies() - require.EqualError(t, err, cookie.ErrBadCookie.Error()) 
+ t.Run("preferred encoding", func(t *testing.T) { + testPreferredEncoding := func(want string, tokens ...string) func(t *testing.T) { + return func(t *testing.T) { + headers := &commonHeaders{AcceptEncoding: tokens} + require.Equal(t, want, headers.PreferredEncoding()) + } + } + + t.Run("accept none", testPreferredEncoding("identity")) + t.Run("no qualifiers", testPreferredEncoding("gzip", "gzip", "deflate")) + t.Run("with qualifiers", testPreferredEncoding( + "zstd", + "gzip", "deflate;q=0.5", "zstd;q=1.0", + )) + t.Run("invalid qualifier", testPreferredEncoding("gzip", "gzip", "zstd;q=0.05")) }) } diff --git a/http/response.go b/http/response.go index eb379494..57189729 100644 --- a/http/response.go +++ b/http/response.go @@ -40,24 +40,21 @@ func NewResponse() *Response { } } -// Code sets a Response code and a corresponding status. -// In case of unknown code, "Unknown Status Code" will be set as a status -// code. In this case you should call Status explicitly +// Code sets the response code. If the code is unrecognized, its default status string +// is "Nonstandard". Otherwise, it will be chosen automatically unless overridden. func (r *Response) Code(code status.Code) *Response { r.fields.Code = code return r } -// Status sets a custom status text. This text does not matter at all, and usually -// totally ignored by client, so there is actually no reasons to use this except some -// rare cases when you need to represent a Response status text somewhere +// Status sets a custom status text. func (r *Response) Status(status status.Status) *Response { r.fields.Status = status return r } // ContentType is a shorthand for Header("Content-Type", value) with an option of setting -// a charset. If more than 1 is set, only the first one is used. +// a charset if at least one is specified. All others are ignored. 
func (r *Response) ContentType(value mime.MIME, charset ...mime.Charset) *Response { if value == mime.Unset { return r @@ -70,10 +67,18 @@ func (r *Response) ContentType(value mime.MIME, charset ...mime.Charset) *Respon return r.Header("Content-Type", value) } -// Compress sets the Content-Encoding value and compresses the outcoming body. Passing the compression -// token that isn't recognized is a no-op. -func (r *Response) Compress(token string) *Response { +// Compress chooses and sets the best suiting compression based on client preferences. +func (r *Response) Compress() *Response { + r.fields.AutoCompress = true + r.fields.ContentEncoding = "" // to avoid conflicts, wins the last method applied. + return r +} + +// Compression enforces a specific codec to be used, even if it isn't in Accept-Encoding. +// The method is no-op if the token is not recognized. +func (r *Response) Compression(token string) *Response { r.fields.ContentEncoding = token + r.fields.AutoCompress = false return r } @@ -91,18 +96,13 @@ func (r *Response) Header(key string, values ...string) *Response { return r } -// Headers simply merges passed headers into Response. Also, it is the only -// way to specify a quality marker of value. In case headers were not initialized -// before, Response headers will be set to a passed map, so editing this map -// will affect Response +// Headers merges the map into the response headers. func (r *Response) Headers(headers map[string][]string) *Response { - resp := r - for k, v := range headers { - resp = resp.Header(k, v...) + r.Header(k, v...) } - return resp + return r } // String sets the response body. @@ -110,9 +110,10 @@ func (r *Response) String(body string) *Response { return r.Bytes(uf.S2B(body)) } -// Bytes sets the response body without copying it. +// Bytes sets the response body. Please note that the passed slice must not be modified +// after being passed. 
func (r *Response) Bytes(body []byte) *Response { - return r.SizedStream(r.body.Reset(body), int64(len(body))) + return r.Stream(r.body.Reset(body), int64(len(body))) } // Write implements io.Reader interface. It always returns n=len(b) and err=nil @@ -123,7 +124,8 @@ func (r *Response) Write(b []byte) (n int, err error) { return len(b), nil } -// TryFile tries to open a file for reading and returns a new Response with attachment. +// TryFile tries to open a file by the path for reading and sets it as an upload stream if succeeded. +// Otherwise, the error is returned. func (r *Response) TryFile(path string) (*Response, error) { fd, err := os.Open(path) if err != nil { @@ -141,35 +143,33 @@ func (r *Response) TryFile(path string) (*Response, error) { return r. ContentType(mime.Guess(path, mime.HTML)). - SizedStream(fd, stat.Size()), nil + Stream(fd, stat.Size()), nil } -// File opens a file for reading and returns a new Response with attachment, set to the file -// descriptor.fields. If error occurred, it'll be silently returned +// File opens a file by the path and sets it as an upload stream if succeeded. Otherwise, the error +// is silently written instead. func (r *Response) File(path string) *Response { resp, err := r.TryFile(path) - if err != nil { - return r.Error(err) - } - - return resp + return resp.Error(err) } -// Stream sets a reader to be the source of the response's body. -func (r *Response) Stream(reader io.Reader) *Response { - // TODO: we can check whether the reader implements Len() int interface and in that - // TODO: case elide the chunked transfer encoding - r.fields.Stream = reader +// Stream sets a reader to be the source of the response's body. If no size is provided AND the reader +// doesn't have the Len() int method, the stream is considered unsized and therefore will be streamed +// using chunked transfer encoding. Otherwise, plain transfer is used, unless a compression is applied. 
+// Specifying the size of -1 forces the stream to be considered unsized. +func (r *Response) Stream(reader io.Reader, size ...int64) *Response { + type Len interface { + Len() int + } + r.fields.StreamSize = -1 - return r -} + if len(size) > 0 { + r.fields.StreamSize = size[0] + } else if l, ok := reader.(Len); ok { + r.fields.StreamSize = int64(l.Len()) + } -// SizedStream receives a hint of the stream's future size. This helps, for example, uploading files, -// as in this case we can rely on io.WriterTo interface, which might use more effective kernel mechanisms -// available, e.g. sendfile(2) for Linux. Passing the size of -1 is effectively equivalent to just Stream(). -func (r *Response) SizedStream(reader io.Reader, size int64) *Response { r.fields.Stream = reader - r.fields.StreamSize = size return r } @@ -179,8 +179,7 @@ func (r *Response) Cookie(cookies ...cookie.Cookie) *Response { return r } -// TryJSON receives a model (must be a pointer to the structure) and returns a new Response -// object and an error +// TryJSON tries to serialize the model into JSON. func (r *Response) TryJSON(model any) (*Response, error) { stream := json.ConfigDefault.BorrowStream(r) stream.WriteVal(model) @@ -190,21 +189,17 @@ func (r *Response) TryJSON(model any) (*Response, error) { return r.ContentType(mime.JSON), err } -// JSON does the same as TryJSON does, except returned error is being implicitly wrapped -// by Error +// JSON serializes the model into JSON and sets the Content-Type to application/json if succeeded. +// Otherwise, the error is silently written instead. func (r *Response) JSON(model any) *Response { resp, err := r.TryJSON(model) - if err != nil { - return r.Error(err) - } - - return resp + return resp.Error(err) } -// Error returns a response builder with an error set. If passed err is nil, nothing will happen. -// If an instance of status.HTTPError is passed, error code will be automatically set. 
Custom -// codes can be passed, however only first will be used. By default, the error is -// status.ErrInternalServerError +// Error returns the response builder with an error set. The nil value for error is a no-op. +// If the error is an instance of status.HTTPError, its status code is used instead the default one. +// The default code is status.ErrInternalServerError, which can be overridden if at least one code is +// specified (all others are ignored). func (r *Response) Error(err error, code ...status.Code) *Response { if err == nil { return r @@ -216,7 +211,6 @@ func (r *Response) Error(err error, code ...status.Code) *Response { c := status.InternalServerError if len(code) > 0 { - // peek the first, ignore the rest c = code[0] } @@ -225,63 +219,81 @@ func (r *Response) Error(err error, code ...status.Code) *Response { String(err.Error()) } +// Buffered allows to enable or disable writes deferring. When enabled, data from body stream +// is read until there is enough space available in an underlying buffer. If the data must be +// flushed soon possible (e.g. polling or proxying), the option should be disabled. +// +// By default, the option is enabled. +func (r *Response) Buffered(flag bool) *Response { + r.fields.Buffered = flag + return r +} + // Expose gives direct access to internal builder fields. func (r *Response) Expose() *response.Fields { return &r.fields } -// Clear discards everything was done with Response object before. +// Clear discards all changes. func (r *Response) Clear() *Response { r.fields.Clear() return r } -// Respond is a shorthand for request.Respond(). May be used as a dummy handler. +// Respond is a shorthand for request.Respond(). Can be used as a dummy handler. func Respond(request *Request) *Response { return request.Respond() } -// Code is a shorthand for request.Respond().Code(...) +// Code sets the response code. If the code is unrecognized, its default status string +// is "Nonstandard". 
Otherwise, it will be chosen automatically unless overridden. func Code(request *Request, code status.Code) *Response { return request.Respond().Code(code) } -// String is a shorthand for request.Respond().String(...) +// ContentType is a shorthand for request.Respond().ContentType(...) +// +// ContentType itself is a shorthand for Header("Content-Type", value) +// with an option of setting a charset, if at least one is specified. All others are ignored. +func ContentType(request *Request, contentType mime.MIME, charset ...mime.Charset) *Response { + return request.Respond().ContentType(contentType, charset...) +} + +// String sets the response body. func String(request *Request, str string) *Response { return request.Respond().String(str) } -// Bytes is a shorthand for request.Respond().Bytes(...) +// Bytes sets the response body. Please note that the passed slice must not be modified +// after being passed. func Bytes(request *Request, b []byte) *Response { return request.Respond().Bytes(b) } -// File is a shorthand for request.Respond().File(...) +// File opens a file by the path and sets it as an upload stream if succeeded. Otherwise, the error +// is silently written instead. func File(request *Request, path string) *Response { return request.Respond().File(path) } -// Stream is a shorthand for request.Respond().Stream(...) -func Stream(request *Request, reader io.Reader) *Response { - return request.Respond().Stream(reader) +// Stream sets a reader to be the source of the response's body. If no size is provided AND the reader +// doesn't have the Len() int method, the stream is considered unsized and therefore will be streamed +// using chunked transfer encoding. Otherwise, plain transfer is used, unless a compression is applied. +// Specifying the size of -1 forces the stream to be considered unsized. +func Stream(request *Request, reader io.Reader, size ...int64) *Response { + return request.Respond().Stream(reader, size...) 
} -// SizedStream is a shorthand for request.Respond().SizedStream(...) -func SizedStream(request *Request, reader io.Reader, size int64) *Response { - return request.Respond().SizedStream(reader, size) -} - -// JSON is a shorthand for request.Respond().JSON(...) +// JSON serializes the model into JSON and sets the Content-Type to application/json if succeeded. +// Otherwise, the error is silently written instead. func JSON(request *Request, model any) *Response { return request.Respond().JSON(model) } -// Error is a shorthand for request.Respond().Error(...) -// -// Error returns the response builder with an error set. If passed err is nil, nothing will happen. -// If an instance of status.HTTPError is passed, its status code is automatically set. Otherwise, -// status.ErrInternalServerError is used. A custom code can be set. Passing multiple status codes -// will discard all except the first one. +// Error returns the response builder with an error set. The nil value for error is a no-op. +// If the error is an instance of status.HTTPError, its status code is used instead the default one. +// The default code is status.ErrInternalServerError, which can be overridden if at least one code is +// specified (all others are ignored). func Error(request *Request, err error, code ...status.Code) *Response { return request.Respond().Error(err, code...) 
} diff --git a/http/status/codes.go b/http/status/codes.go index e68a764e..c6cd98fd 100644 --- a/http/status/codes.go +++ b/http/status/codes.go @@ -179,7 +179,7 @@ var ( ) func StringCode(code Code) string { - if code >= maxCodeValue { + if code < minCodeValue || code >= maxCodeValue { return "" } diff --git a/http/status/codes_test.go b/http/status/codes_test.go index 55c7f355..48431388 100644 --- a/http/status/codes_test.go +++ b/http/status/codes_test.go @@ -12,6 +12,9 @@ func Test(t *testing.T) { for _, code := range KnownCodes { require.Equal(t, strconv.Itoa(int(code)), StringCode(code)) } + + // it used to panic when passing codes <100 + require.Equal(t, "", StringCode(99)) } func Benchmark(b *testing.B) { diff --git a/indi.go b/indi.go index 65c8fa96..054b1bea 100644 --- a/indi.go +++ b/indi.go @@ -11,7 +11,7 @@ import ( "github.com/indigo-web/indigo/transport" ) -const Version = "0.17.2" +const Version = "0.17.3" // App is just a struct with addr and shutdown channel that is currently // not used. Planning to replace it with context.WithCancel() @@ -63,8 +63,8 @@ func (a *App) OnStop(cb func()) *App { } // Codec appends a new codec into the list of supported. -func (a *App) Codec(codec codec.Codec) *App { - a.codecs = append(a.codecs, codec) +func (a *App) Codec(codecs ...codec.Codec) *App { + a.codecs = append(a.codecs, codecs...) return a } diff --git a/indigo_test.go b/indigo_test.go index a469db78..412b4125 100644 --- a/indigo_test.go +++ b/indigo_test.go @@ -182,6 +182,7 @@ func TestFirstPhase(t *testing.T) { app := New(addr) go func(app *App) { r := getInbuiltRouter(). + EnableTRACE(true). 
Use( middleware.CustomContext( context.WithValue(context.Background(), "easter", "egg"), @@ -205,7 +206,7 @@ func TestFirstPhase(t *testing.T) { <-ch // Ensure the server is ready to accept connections - waitForAvailability(t) + waitForAvailability(t, addr, altAddr, httpsAddr) t.Run("root get", func(t *testing.T) { resp, err := stdhttp.DefaultClient.Get(appURL + "/") @@ -401,13 +402,13 @@ func TestFirstPhase(t *testing.T) { testStatic(t, "pics.vfs", mime.Unset) }) - t.Run("trace", func(t *testing.T) { + t.Run("TRACE", func(t *testing.T) { request := &stdhttp.Request{ Method: stdhttp.MethodTrace, URL: &url.URL{ Scheme: "http", Host: addr, - Path: "/", + Path: "/any-endpoint-is-good", }, Proto: "HTTP/1.1", ProtoMajor: 1, @@ -421,33 +422,9 @@ func TestFirstPhase(t *testing.T) { resp, err := stdhttp.DefaultClient.Do(request) require.NoError(t, err) require.Equal(t, stdhttp.StatusOK, resp.StatusCode) - require.Contains(t, resp.Header, "Content-Type") - require.Equal(t, 1, len(resp.Header["Content-Type"]), "too many content-type values") - require.Equal(t, "message/http", resp.Header["Content-Type"][0]) - - dataBytes, err := io.ReadAll(resp.Body) - data := string(dataBytes) - require.NoError(t, err) - - wantRequestLine := "TRACE / HTTP/1.1\r\n" - require.Greater(t, len(data), len(wantRequestLine)) - require.Equal(t, wantRequestLine, data[:len(wantRequestLine)]) - - headerLines := strings.Split(data[len(wantRequestLine):], "\r\n") - // request is terminated with \r\n\r\n, so 2 last values in headerLines - // are empty strings. 
Remove them - headerLines = headerLines[:len(headerLines)-2] - wantHeaderLines := []string{ - "Hello: World!", - "Host: " + addr, - "User-Agent: Go-http-client/1.1", - "Accept-Encoding: gzip", - "Content-Length: 0", - } - - for _, line := range headerLines { - require.True(t, slices.Contains(wantHeaderLines, line), "unwanted header line: "+line) - } + require.Equal(t, []string{"message/http"}, resp.Header["Content-Type"]) + // the actual content isn't that important, considering it's already covered by + // tests in router/inbuilt/trace_test.go. More importantly, we got a 200 OK response }) t.Run("not allowed method", func(t *testing.T) { @@ -718,7 +695,7 @@ func TestSecondPhase(t *testing.T) { s.NET.ReadTimeout = 1 * time.Second _ = app. Tune(s). - Codec(codec.NewGZIP()). + Codec(codec.Suit()...). OnStart(func() { ch <- struct{}{} }). @@ -729,13 +706,35 @@ func TestSecondPhase(t *testing.T) { }(app) <-ch - waitForAvailability(t) + waitForAvailability(t, addr) + + t.Run("TRACE", func(t *testing.T) { + request := &stdhttp.Request{ + Method: stdhttp.MethodTrace, + URL: &url.URL{ + Scheme: "http", + Host: addr, + Path: "/any-endpoint-is-good", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: stdhttp.Header{ + "Hello": {"World!"}, + }, + Host: addr, + RemoteAddr: addr, + } + resp, err := stdhttp.DefaultClient.Do(request) + require.NoError(t, err) + require.Equal(t, stdhttp.StatusMethodNotAllowed, resp.StatusCode) + }) t.Run("accept encoding", func(t *testing.T) { resp, err := stdhttp.DefaultClient.Head(appURL + "/") require.NoError(t, err) require.Equal(t, stdhttp.StatusOK, resp.StatusCode) - require.Equal(t, []string{"gzip"}, resp.Header["Accept-Encoding"]) + require.Equal(t, []string{"gzip, deflate, zstd"}, resp.Header["Accept-Encoding"]) require.Empty(t, readFullBody(t, resp)) }) @@ -833,18 +832,64 @@ func TestSecondPhase(t *testing.T) { }) } -func waitForAvailability(t *testing.T) { - deadline := time.Now().Add(2 * time.Second) - for { - conn, 
err := net.Dial("tcp4", addr) - if err == nil { - _ = conn.Close() - break +func TestEscaping(t *testing.T) { + runTest := func(dynamic bool) func(t *testing.T) { + return func(t *testing.T) { + app := New(addr) + go func(app *App) { + r := inbuilt.New(). + Get("/foo%2fbar", func(request *http.Request) *http.Response { + return http.Code(request, 201) + }). + Get("/foo%3abar", func(request *http.Request) *http.Response { + return http.Code(request, 202) + }) + + if dynamic { + r.Get("/ :", http.Respond) // unreachable endpoint + } + + _ = app.Serve(r) + }(app) + + waitForAvailability(t, addr) + + test := func(path string, wantCode int) func(t *testing.T) { + return func(t *testing.T) { + resp, err := stdhttp.Get(appURL + path) + require.NoError(t, err) + require.Equal(t, wantCode, resp.StatusCode) + } + } + + t.Run("escaped slash", test("/foo%2fbar", 201)) + t.Run("escaped unnormalized slash", test("/foo%2Fbar", 201)) + t.Run("unescaped slash", test("/foo/bar", 404)) + + t.Run("escaped colon", test("/foo%3abar", 202)) + t.Run("escaped unnormalized colon", test("/foo%3Abar", 202)) + t.Run("unescaped colon", test("/foo:bar", 202)) } - if time.Now().After(deadline) { - t.Fatalf("server did not start listening on %s in time: %v", addr, err) + } + + t.Run("static", runTest(false)) + t.Run("dynamic", runTest(true)) +} + +func waitForAvailability(t *testing.T, addrs ...string) { + for _, addr := range addrs { + deadline := time.Now().Add(2 * time.Second) + for { + conn, err := net.Dial("tcp4", addr) + if err == nil { + _ = conn.Close() + break + } + if time.Now().After(deadline) { + t.Fatalf("server did not start listening on %s in time: %v", addr, err) + } + time.Sleep(50 * time.Millisecond) } - time.Sleep(50 * time.Millisecond) } } diff --git a/internal/buffer/buffer.go b/internal/buffer/buffer.go index 94ec078e..df6adcce 100644 --- a/internal/buffer/buffer.go +++ b/internal/buffer/buffer.go @@ -15,6 +15,11 @@ func New(initialSize, maxSize int) *Buffer { } } +// 
AppendBytes is a variadic version of Append. +func (b *Buffer) AppendBytes(bytes ...byte) (ok bool) { + return b.Append(bytes) +} + // Append writes data, checking whether the new amount of elements (bytes) doesn't exceed the // limit, otherwise discarding the data and returning false. func (b *Buffer) Append(elements []byte) (ok bool) { diff --git a/internal/codecutil/codec.go b/internal/codecutil/cache.go similarity index 56% rename from internal/codecutil/codec.go rename to internal/codecutil/cache.go index 0a5e1af4..f5818593 100644 --- a/internal/codecutil/codec.go +++ b/internal/codecutil/cache.go @@ -1,10 +1,9 @@ package codecutil import ( - "iter" + "strings" "github.com/indigo-web/indigo/http/codec" - "github.com/indigo-web/indigo/internal/strutil" ) type Cache struct { @@ -13,10 +12,9 @@ type Cache struct { instances []codec.Instance } -func NewCache(codecs []codec.Codec) Cache { +func NewCache(codecs []codec.Codec, acceptString string) Cache { return Cache{ - // TODO: we're still allocating a string on every connection. Which we actually can avoid. 
- accept: acceptEncodings(codecs), + accept: acceptString, codecs: codecs, instances: make([]codec.Instance, len(codecs)), } @@ -47,24 +45,22 @@ func (c Cache) Get(token string) codec.Instance { return inst } -func (c Cache) AcceptEncodings() string { +func (c Cache) AcceptEncoding() string { return c.accept } -func acceptEncodings(codecs []codec.Codec) string { +func AcceptEncoding(codecs []codec.Codec) string { if len(codecs) == 0 { return "identity" } - return strutil.Join(traverseTokens(codecs), ", ") -} + var b strings.Builder -func traverseTokens(codecs []codec.Codec) iter.Seq[string] { - return func(yield func(string) bool) { - for _, c := range codecs { - if !yield(c.Token()) { - break - } - } + b.WriteString(codecs[0].Token()) + for _, c := range codecs[1:] { + b.WriteString(", ") + b.WriteString(c.Token()) } + + return b.String() } diff --git a/internal/hexconv/hexconv.go b/internal/hexconv/hexconv.go index 391ad138..61f692b8 100644 --- a/internal/hexconv/hexconv.go +++ b/internal/hexconv/hexconv.go @@ -4,7 +4,7 @@ package hexconv // meanwhile no valid value uses more than a single digit. // // Tip: in order to check whether two or more hexadecimals are valid characters, consider -// if a|b > 0x0f { /* fail */ } +// if a|b == 0xFF { /* fail */ } var Halfbyte = [256]byte{ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, @@ -23,3 +23,6 @@ var Halfbyte = [256]byte{ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, } + +// Char maps a 4-bit value to its ASCII hexadecimal letter. 
+var Char = [16]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'} diff --git a/internal/httptest/parse/parse.go b/internal/httptest/parse/parse.go index 6e9a4dcb..2c60142b 100644 --- a/internal/httptest/parse/parse.go +++ b/internal/httptest/parse/parse.go @@ -12,7 +12,7 @@ import ( func HTTP11Request(data string) (*http.Request, error) { client := dummy.NewMockClient([]byte(data)) request := construct.Request(config.Default(), client) - suit := http1.New(config.Default(), nil, client, request, codecutil.NewCache(nil)) + suit := http1.New(config.Default(), nil, client, request, codecutil.NewCache(nil, "identity")) request.Body = http.NewBody(suit) for { diff --git a/internal/httptest/serialize/serialize.go b/internal/httptest/serialize/serialize.go index 44deee72..b01e48ef 100644 --- a/internal/httptest/serialize/serialize.go +++ b/internal/httptest/serialize/serialize.go @@ -1,8 +1,6 @@ package serialize import ( - "strconv" - "github.com/indigo-web/indigo/http" ) @@ -34,13 +32,6 @@ func Headers(request *http.Request) string { buff = header(buff, h) } - if !request.Encoding.Chunked { - buff = header(buff, http.Header{ - Key: "Content-Length", - Value: strconv.Itoa(request.ContentLength), - }) - } - buff = append(buff, '\r', '\n') return string(buff) diff --git a/internal/protocol/http1/body.go b/internal/protocol/http1/body.go index 7453ac60..228c21be 100644 --- a/internal/protocol/http1/body.go +++ b/internal/protocol/http1/body.go @@ -32,7 +32,7 @@ func (b *body) Fetch() ([]byte, error) { } func (b *body) Reset(request *http.Request) { - if request.Encoding.Chunked { + if request.Chunked { b.initChunked() b.reader = (*body).readChunked } else if request.Connection == "close" { diff --git a/internal/protocol/http1/body_test.go b/internal/protocol/http1/body_test.go index 4d579948..edfb186b 100644 --- a/internal/protocol/http1/body_test.go +++ b/internal/protocol/http1/body_test.go @@ -48,7 +48,7 @@ func 
getRequestWithBody(chunked bool, body ...[]byte) (*http.Request, *body) { req.Headers = hdrs req.ContentLength = contentLength - req.Encoding.Chunked = chunked + req.Chunked = chunked req.Body.Reset(req) return req, b diff --git a/internal/protocol/http1/chunked.go b/internal/protocol/http1/chunked.go index 394684a2..03418ac3 100644 --- a/internal/protocol/http1/chunked.go +++ b/internal/protocol/http1/chunked.go @@ -3,7 +3,6 @@ package http1 import ( "bytes" "io" - "strings" "github.com/indigo-web/indigo/http/status" "github.com/indigo-web/indigo/internal/hexconv" @@ -212,11 +211,3 @@ chunkTrailerFieldLine: goto trailer } } - -var ( - // chunkExtZeroFill is used to fill the gap between chunk length and chunk content. The count - // 64/4 represents 64 bits - the maximal uint size, and 4 - bits per hex value, therefore - // resulting in 15 characters (plus semicolon) total. - chunkExtZeroFill = ";" + strings.Repeat("0", 64/4-1) - chunkZeroTrailer = []byte("0\r\n\r\n") -) diff --git a/internal/protocol/http1/parser.go b/internal/protocol/http1/parser.go index ce073182..ff4213b6 100644 --- a/internal/protocol/http1/parser.go +++ b/internal/protocol/http1/parser.go @@ -53,10 +53,9 @@ type Parser struct { func NewParser(cfg *config.Config, request *http.Request, statusBuff, headers *buffer.Buffer) *Parser { return &Parser{ - cfg: cfg, - state: eMethod, - request: request, - // TODO: pass these through arguments instead of allocating in-place + cfg: cfg, + state: eMethod, + request: request, acceptEncodings: make([]string, 0, cfg.Headers.MaxAcceptEncodingTokens), encodings: make([]string, 0, cfg.Headers.MaxEncodingTokens), requestLine: statusBuff, @@ -150,18 +149,20 @@ path: for i := 0; i < len(data); i++ { switch char := data[i]; char { case '%': - if !requestLine.Append(data[checkpoint:i]) { - return true, nil, status.ErrURITooLong - } - if len(data[i+1:]) >= 2 { // fast path c := (hexconv.Halfbyte[data[i+1]] << 4) | hexconv.Halfbyte[data[i+2]] - if isProhibitedChar(c) 
{ + if strutil.IsASCIINonprintable(c) { return true, nil, status.ErrBadRequest } + if strutil.IsURLUnsafeChar(c) { + data[i+1] |= 0x20 + data[i+2] |= 0x20 + i += 2 + continue + } - if !requestLine.AppendByte(c) { + if !(requestLine.Append(data[checkpoint:i]) && requestLine.AppendByte(c)) { return true, nil, status.ErrURITooLong } @@ -169,6 +170,10 @@ path: checkpoint = i + 1 } else { // slow path + if !requestLine.Append(data[checkpoint:i]) { + return true, nil, status.ErrURITooLong + } + data = data[i+1:] goto pathDecode1Char } @@ -197,7 +202,7 @@ path: // compact and not bloat it with unnecessary states, simply reject such requests. return true, nil, status.ErrBadRequest default: - if isProhibitedChar(char) { + if strutil.IsASCIINonprintable(char) { return true, nil, status.ErrBadRequest } } @@ -229,9 +234,17 @@ pathDecode2Char: } char := (hexconv.Halfbyte[p.urlEncodedChar] << 4) | hexconv.Halfbyte[data[0]] - if isProhibitedChar(char) { + if strutil.IsASCIINonprintable(char) { return true, nil, status.ErrBadRequest } + if strutil.IsURLUnsafeChar(char) { + if !requestLine.AppendBytes('%', p.urlEncodedChar|0x20, data[0]|0x20) { + return true, nil, status.ErrURITooLong + } + + data = data[1:] + goto path + } if !requestLine.AppendByte(char) { return true, nil, status.ErrURITooLong @@ -252,7 +265,7 @@ paramsKey: if len(data[i+1:]) >= 2 { // fast path c := (hexconv.Halfbyte[data[i+1]] << 4) | hexconv.Halfbyte[data[i+2]] - if isProhibitedChar(c) { + if strutil.IsASCIINonprintable(c) { return true, nil, status.ErrBadParams } @@ -277,7 +290,7 @@ paramsKey: case '#': return true, nil, status.ErrBadRequest default: - if isProhibitedChar(char) { + if strutil.IsASCIINonprintable(char) { return true, nil, status.ErrBadParams } @@ -308,7 +321,7 @@ paramsKeyDecode2Char: } char := (hexconv.Halfbyte[p.urlEncodedChar] << 4) | hexconv.Halfbyte[data[0]] - if isProhibitedChar(char) { + if strutil.IsASCIINonprintable(char) { return true, nil, status.ErrBadParams } @@ -337,7 +350,7 @@ 
paramsValue: if len(data[i+1:]) >= 2 { // fast path c := (hexconv.Halfbyte[data[i+1]] << 4) | hexconv.Halfbyte[data[i+2]] - if isProhibitedChar(c) { + if strutil.IsASCIINonprintable(c) { return true, nil, status.ErrBadParams } @@ -399,7 +412,7 @@ paramsValueDecode2Char: } char := (hexconv.Halfbyte[p.urlEncodedChar] << 4) | hexconv.Halfbyte[data[0]] - if isProhibitedChar(char) { + if strutil.IsASCIINonprintable(char) { return true, nil, status.ErrBadParams } @@ -530,11 +543,11 @@ headerValue: } case 15: if strutil.CmpFoldFast(key, "Accept-Encoding") { - p.acceptEncodings, request.Encoding.Accept, err = splitTokens(p.acceptEncodings, value) + p.acceptEncodings, request.AcceptEncoding, err = splitTokens(p.acceptEncodings, value) } case 16: if strutil.CmpFoldFast(key, "Content-Encoding") { - p.encodings, request.Encoding.Content, err = splitTokens(p.encodings, value) + p.encodings, request.ContentEncoding, err = splitTokens(p.encodings, value) if err != nil { return true, nil, err } @@ -546,20 +559,19 @@ headerValue: } p.metTransferEncoding = true - - p.encodings, request.Encoding.Transfer, err = splitTokens(p.encodings, value) + p.encodings, request.TransferEncoding, err = splitTokens(p.encodings, value) if err != nil { return true, nil, err } - te := request.Encoding.Transfer - if len(te) > 0 { - if te[len(te)-1] != "chunked" { - return true, nil, status.ErrBadEncoding - } - - request.Encoding.Chunked = true + te := request.TransferEncoding + if len(te) == 0 || len(te) == 1 && te[0] == "identity" { + break + } else if len(te) > 0 && te[len(te)-1] != "chunked" { + return true, nil, status.ErrBadEncoding } + + request.Chunked = true } } @@ -592,17 +604,17 @@ contentLength: } p.contentLength = p.contentLength*10 + int64(char-'0') + if !p.headers.AppendByte(char) { + return true, nil, status.ErrHeaderFieldsTooLarge + } } p.state = eContentLength return false, nil, nil contentLengthEnd: - // guaranteed, that data at this point contains AT LEAST 1 byte. 
- // The proof is, that this code is reachable ONLY if loop has reached a non-digit - // ascii symbol. In case loop has finished peacefully, as no more data left, but also no - // character found to satisfy the exit condition, this code will never be reached request.ContentLength = int(p.contentLength) + request.Headers.Add("Content-Length", uf.B2S(p.headers.Finish())) switch data[0] { case '\r': @@ -630,14 +642,18 @@ contentLengthCR: } func (p *Parser) cleanup() { - p.metTransferEncoding = false - p.headersNumber = 0 p.requestLine.Clear() p.headers.Clear() - p.contentLength = 0 - p.acceptEncodings = p.acceptEncodings[:0] - p.encodings = p.encodings[:0] - p.state = eMethod + + *p = Parser{ + state: eMethod, + cfg: p.cfg, + request: p.request, + requestLine: p.requestLine, + headers: p.headers, + acceptEncodings: p.acceptEncodings[:0], + encodings: p.encodings[:0], + } } func splitTokens(buff []string, value string) (alteredBuff, toks []string, err error) { @@ -709,7 +725,3 @@ func stripCR(b []byte) []byte { return b } - -func isProhibitedChar(c byte) bool { - return c < 0x20 || c > 0x7e -} diff --git a/internal/protocol/http1/parser_test.go b/internal/protocol/http1/parser_test.go index 91ee708f..5f7eaf8f 100644 --- a/internal/protocol/http1/parser_test.go +++ b/internal/protocol/http1/parser_test.go @@ -151,21 +151,19 @@ func TestParser(t *testing.T) { } compareRequests(t, wanted, request) - request.Reset() }) - t.Run("simple GET with leading CRLF", func(t *testing.T) { + t.Run("leading CRLF", func(t *testing.T) { raw := "\r\n\r\nGET / HTTP/1.1\r\n\r\n" - parser, request := getParser(cfg) + parser, _ := getParser(cfg) done, extra, err := parser.Parse([]byte(raw)) // unfortunately, we don't support this. Such clients must die. 
require.Error(t, err, status.ErrBadRequest.Error()) require.True(t, done) require.Empty(t, extra) - request.Reset() }) - t.Run("normal GET", func(t *testing.T) { + t.Run("GET with headers", func(t *testing.T) { raw := "GET / HTTP/1.1\r\nHello: World!\r\nEaster: Egg\r\n\r\n" parser, request := getParser(cfg) done, extra, err := parser.Parse([]byte(raw)) @@ -183,7 +181,6 @@ func TestParser(t *testing.T) { } compareRequests(t, wanted, request) - request.Reset() }) t.Run("multiple header values", func(t *testing.T) { @@ -204,7 +201,6 @@ func TestParser(t *testing.T) { } compareRequests(t, wanted, request) - request.Reset() }) t.Run("only lf", func(t *testing.T) { @@ -225,7 +221,6 @@ func TestParser(t *testing.T) { } compareRequests(t, wanted, request) - request.Reset() }) t.Run("fuzz GET", func(t *testing.T) { @@ -268,7 +263,6 @@ func TestParser(t *testing.T) { } compareRequests(t, wanted, request) - request.Reset() }) t.Run("content length", func(t *testing.T) { @@ -279,6 +273,7 @@ func TestParser(t *testing.T) { require.True(t, done) require.Equal(t, "Hello, world!", string(extra)) require.Equal(t, 13, request.ContentLength) + require.Equal(t, "13", request.Headers.Value("content-length")) request.Reset() raw = "GET / HTTP/1.1\r\nContent-Length: 13\r\nHi-Hi: ha-ha\r\n\r\nHello, world!" 
@@ -287,9 +282,8 @@ func TestParser(t *testing.T) { require.True(t, done) require.Equal(t, "Hello, world!", string(extra)) require.Equal(t, 13, request.ContentLength) - require.True(t, request.Headers.Has("hi-hi")) + require.Equal(t, "13", request.Headers.Value("content-length")) require.Equal(t, "ha-ha", request.Headers.Value("hi-hi")) - request.Reset() }) t.Run("connection", func(t *testing.T) { @@ -300,7 +294,6 @@ func TestParser(t *testing.T) { require.True(t, done) require.Empty(t, string(extra)) require.Equal(t, "Keep-Alive", request.Connection) - request.Reset() }) t.Run("Transfer-Encoding and Content-Encoding", func(t *testing.T) { @@ -310,25 +303,29 @@ func TestParser(t *testing.T) { require.NoError(t, err) require.True(t, done) require.Empty(t, string(extra)) - require.Equal(t, []string{"chunked"}, request.Encoding.Transfer) - require.True(t, request.Encoding.Chunked) - require.Equal(t, []string{"gzip", "deflate"}, request.Encoding.Content) - request.Reset() + require.Equal(t, []string{"chunked"}, request.TransferEncoding) + require.True(t, request.Chunked) + require.Equal(t, []string{"gzip", "deflate"}, request.ContentEncoding) }) t.Run("urldecode", func(t *testing.T) { - parsePath := func(encodedPath string) (string, error) { - raw := fmt.Sprintf("GET %s ", encodedPath) + parseRequestLine := func(path string) (*http.Request, error) { + raw := fmt.Sprintf("GET %s ", path) parser, request := getParser(cfg) for i := 0; i < len(raw); i++ { _, _, err := parser.Parse([]byte{raw[i]}) if err != nil { - return "", err + return request, err } } - return request.Path, nil + return request, nil + } + + parsePath := func(path string) (string, error) { + request, err := parseRequestLine(path) + return request.Path, err } t.Run("path", func(t *testing.T) { @@ -342,6 +339,33 @@ func TestParser(t *testing.T) { require.EqualError(t, err, status.ErrBadRequest.Error()) }) + t.Run("unsafe", func(t *testing.T) { + test := func(t *testing.T, request *http.Request) { + 
require.Equal(t, "/ foo%2f:bar#?", request.Path) + require.Equal(t, "bar", request.Params.Value("foo+=")) + } + + t.Run("slowpath", func(t *testing.T) { + request, err := parseRequestLine("/%20foo%2f%3abar%23%3f?foo%2b%3d=bar") + require.NoError(t, err) + test(t, request) + }) + + t.Run("fastpath", func(t *testing.T) { + parser, request := getParser(config.Default()) + _, _, err := parser.Parse([]byte("GET /%20foo%2f%3abar%23%3f?foo%2b%3d=bar ")) + require.NoError(t, err) + test(t, request) + }) + + t.Run("normalize", func(t *testing.T) { + parser, request := getParser(config.Default()) + _, _, err := parser.Parse([]byte("GET /foo%2Fbar ")) + require.NoError(t, err) + require.Equal(t, "/foo%2fbar", request.Path) + }) + }) + t.Run("params", func(t *testing.T) { parseParams := func(params ...string) (http.Params, error) { parser, request := getParser(config.Default()) diff --git a/internal/protocol/http1/serializer.go b/internal/protocol/http1/serializer.go index 1c5e1948..50ca74aa 100644 --- a/internal/protocol/http1/serializer.go +++ b/internal/protocol/http1/serializer.go @@ -5,6 +5,7 @@ import ( "math/bits" "slices" "strconv" + "strings" "time" "github.com/indigo-web/indigo/config" @@ -16,6 +17,7 @@ import ( "github.com/indigo-web/indigo/http/proto" "github.com/indigo-web/indigo/http/status" "github.com/indigo-web/indigo/internal/codecutil" + "github.com/indigo-web/indigo/internal/hexconv" "github.com/indigo-web/indigo/internal/response" "github.com/indigo-web/indigo/internal/strutil" "github.com/indigo-web/indigo/kv" @@ -25,6 +27,7 @@ import ( type serializer struct { cfg *config.Config request *http.Request + response *response.Fields client transport.Client buff []byte streamReadBuff []byte @@ -40,12 +43,14 @@ func newSerializer( buff []byte, ) *serializer { return &serializer{ - cfg: cfg, - request: request, - client: client, - codecs: codecs, - buff: buff, - defaultHeaders: preprocessDefaultHeaders(cfg.Headers.Default, codecs.AcceptEncodings()), + cfg: cfg, 
+ request: request, + client: client, + codecs: codecs, + buff: buff, + defaultHeaders: newDefaultHeaders( + pairsFromMap(cfg.Headers.Default, codecs.AcceptEncoding()), + ), } } @@ -54,15 +59,16 @@ func (s *serializer) Upgrade() { s.appendProtocol(s.request.Protocol) s.buff = append(s.buff, "101 Switching Protocol\r\n"...) - s.appendKnownHeader("Connection: ", "upgrade") - s.appendKnownHeader("Upgrade: ", s.request.Upgrade.String()) + s.appendKnownHeader("Connection", "upgrade") + s.appendKnownHeader("Upgrade", s.request.Upgrade.String()) s.crlf() } func (s *serializer) Write(protocol proto.Protocol, response *http.Response) error { - s.appendProtocol(protocol) resp := response.Expose() + + s.appendProtocol(protocol) s.appendStatus(resp) s.appendHeaders(resp) @@ -75,16 +81,15 @@ func (s *serializer) Write(protocol proto.Protocol, response *http.Response) err return err } - err = s.flush() - s.cleanup() - - return err + return s.flush() } func (s *serializer) writeStream(resp *response.Fields) (err error) { + s.response = resp stream, length := resp.Stream, resp.StreamSize + unsized := length == -1 if length == 0 { - s.appendKnownHeader("Content-Length: ", "0") + s.appendKnownHeader("Content-Length", "0") s.crlf() return nil } @@ -94,25 +99,45 @@ func (s *serializer) writeStream(resp *response.Fields) (err error) { return status.ErrInternalServerError } + var closeConnection bool + defer func() { if c, ok := stream.(io.Closer); ok { if cerr := c.Close(); cerr != nil && err == nil { err = cerr } } + + if closeConnection && err == nil { + err = status.ErrCloseConnection + } }() var encoder io.WriteCloser - compressor := s.getCompressor(resp.ContentEncoding) - if length != -1 && compressor != nil { + compression := resp.ContentEncoding + if resp.AutoCompress && (unsized || length >= s.cfg.NET.SmallBody) { + // if the stream is sized and the size is below limit (i.e. is considered a small one), + // do not compress it. 
It won't give much gain anyway, yet the performance is impacted, + // especially if we otherwise could use a zero-copy mechanism + compression = s.request.PreferredEncoding() + } + + compressor := s.getCompressor(compression) + if !unsized && compressor != nil { // if sized stream is compressed, convert it to unsized length = -1 } if length == -1 { - encoder = chunkedWriter{s} - s.appendKnownHeader("Transfer-Encoding: ", "chunked") + if s.request.Protocol == proto.HTTP11 { + encoder = chunkedWriter{s} + s.appendKnownHeader("Transfer-Encoding", "chunked") + } else { + encoder = identityWriter{s} + s.appendKnownHeader("Connection", "close") + closeConnection = true + } } else { encoder = identityWriter{s} s.appendContentLength(length) @@ -131,7 +156,8 @@ func (s *serializer) writeStream(resp *response.Fields) (err error) { return err } - s.growToContain(int(length) + len(crlf)) // because CRLF wasn't written yet + // +len(crlf) because it wasn't written yet, therefore not yet included in the len(s.buff) + s.growToContain(len(crlf) + int(length)) } s.crlf() // finalize the headers block @@ -186,26 +212,28 @@ func (s *serializer) writeStream(resp *response.Fields) (err error) { } } -func (s *serializer) growToContain(n int) { - extra := min(s.cfg.NET.WriteBufferSize.Maximal-cap(s.buff), n) - if extra <= 0 { - // surplus is possible and in fact isn't a bug, because slices.Grow doesn't guarantee - // that the grown slice will have a specific capacity. It is guaranteed to be _enough_ - // to hold n new values. It's anyway not that bad to have a few extra bytes, after all. - return +func (s *serializer) grow(newsize int) { + // cap the size at its top value from the config. + newsize = min(s.cfg.NET.WriteBufferSize.Maximal, newsize) + // the growth can be triggered even the buffer is already at its maximal size. Do nothing then. 
+ if newsize > cap(s.buff) { + s.buff = make([]byte, 0, newsize) } +} - s.buff = slices.Grow(s.buff, extra) +func (s *serializer) growToContain(n int) { + newsize := min(s.cfg.NET.WriteBufferSize.Maximal-len(s.buff), n) + s.buff = slices.Grow(s.buff, newsize) } func (s *serializer) getCompressor(token string) codec.Compressor { - if len(token) == 0 { + if token == "" || token == "identity" { return nil } compressor := s.codecs.Get(token) if compressor != nil { - s.appendKnownHeader("Content-Encoding: ", token) + s.appendKnownHeader("Content-Encoding", token) } return compressor @@ -234,12 +262,14 @@ func (s *serializer) safeAppend(data []byte) error { return nil } -func (s *serializer) flush() (err error) { - if len(s.buff) > 0 { - _, err = s.client.Write(s.buff) - s.buff = s.buff[:0] +func (s *serializer) flush() error { + if len(s.buff) == 0 { + return nil } + _, err := s.client.Write(s.buff) + s.buff = s.buff[:0] + return err } @@ -277,12 +307,14 @@ func (s *serializer) appendHeaders(fields *response.Fields) { s.crlf() } - for _, header := range s.defaultHeaders { + for i, header := range s.defaultHeaders { if header.Excluded { + s.defaultHeaders[i].Excluded = false continue } - s.buff = append(s.buff, header.Full...) + s.appendHeader(header.Pair) + s.crlf() } } @@ -297,6 +329,7 @@ func (s *serializer) appendHeader(header kv.Pair) { // have a colon and a space included. func (s *serializer) appendKnownHeader(key, value string) { s.buff = append(s.buff, key...) + s.colonsp() s.buff = append(s.buff, value...) s.crlf() } @@ -390,35 +423,32 @@ func (s *serializer) crlf() { s.buff = append(s.buff, crlf...) 
} -func (s *serializer) cleanup() { - s.defaultHeaders.Reset() -} - type chunkedWriter struct { s *serializer } -func (c chunkedWriter) maxhex(n int) int { - return (bits.Len64(uint64(n))-1)>>2 + 1 -} - func (c chunkedWriter) ReadFrom(r io.Reader) (total int64, err error) { const crlflen = len(crlf) for { var ( - buff = c.s.buff[len(c.s.buff):cap(c.s.buff)] - maxHexLength = c.maxhex(len(buff)) - dataOffset = maxHexLength + crlflen + buff = c.s.buff[len(c.s.buff):cap(c.s.buff)] + maxHexLen = hexlen(len(buff)) + dataOffset = maxHexLen + crlflen ) n, err := r.Read(buff[dataOffset : len(buff)-crlflen]) if n > 0 { total += int64(n) - if err := c.writechunk(maxHexLength, n); err != nil { + if err := c.writechunk(maxHexLen, n); err != nil { return 0, err } + + if n+dataOffset+crlflen >= cap(c.s.buff)-cap(c.s.buff)>>6 { + // if the chunk solely occupies ~98.44% of the whole buffer capacity, double the size + c.s.grow(cap(c.s.buff) << 1) + } } switch err { @@ -436,19 +466,17 @@ func (c chunkedWriter) Write(b []byte) (n int, err error) { const crlflen = len(crlf) blen := len(b) - // TODO: add optional but enabled by default buffering chunks. - // otherwise, cap(b) = cap(c.s.buff) => 7 bytes leftover, which will be sent - // as an independent chunk. Highly inefficient. However, making buffering a default behaviour - // completely disables the possibility to implement longpolling based on chunked transfer encoding. - // Also undesired. But by default very little people do really use it like that or anyhow rely on - // lag-free chunk upload or on their consistency. Therefore, enabling buffering by default would result - // in a great choice. 
+ // knowing the size of b in advance, grow to contain it fully if needed + c.s.grow(hexlen(cap(c.s.buff)) + crlflen + blen + crlflen + 1) for len(b) > 0 { - buff := c.s.buff[len(c.s.buff):cap(c.s.buff)] - maxHexLen := c.maxhex(len(buff)) + var ( + buff = c.s.buff[len(c.s.buff):cap(c.s.buff)] + maxHexLen = hexlen(len(buff)) + dataOffset = maxHexLen + crlflen + ) - n = copy(buff[maxHexLen+crlflen:len(buff)-crlflen], b) + n = copy(buff[dataOffset:len(buff)-crlflen], b) if err = c.writechunk(maxHexLen, n); err != nil { return 0, err } @@ -462,18 +490,9 @@ func (c chunkedWriter) Write(b []byte) (n int, err error) { func (c chunkedWriter) writechunk(maxHexLen, datalen int) error { const crlflen = len(crlf) - // TODO: add optional but enabled by default buffering chunks. - // otherwise, cap(b) = cap(c.s.buff) => 7 bytes leftover, which will be sent - // as an independent chunk. Highly inefficient. However, making buffering a default behaviour - // completely disables the possibility to implement longpolling based on chunked transfer encoding. - // Also undesired. But by default very little people do really use it like that or anyhow rely on - // lag-free chunk upload or on their consistency. Therefore, enabling buffering by default would result - // in a great choice. - for { var ( buff = c.s.buff[len(c.s.buff):cap(c.s.buff)] - buffOffset = 0 dataOffset = maxHexLen + crlflen ) @@ -491,40 +510,24 @@ func (c chunkedWriter) writechunk(maxHexLen, datalen int) error { return err } - // TODO: buffer chunks here continue } - hexlen := len(strconv.AppendUint(buff[:0], uint64(datalen), 16)) // chunk length - - if len(c.s.buff) > 0 { - // if there was any data in the buffer before, we must fill the gap in between. - // The best way to do it is via an extension. - copy(buff[hexlen:maxHexLen], chunkExtZeroFill) - } else { - // otherwise, we can save a couple of bytes by simply truncating the unused prefix slots. 
- buffOffset = maxHexLen - hexlen - copy(buff[buffOffset:], buff[:hexlen]) + // write the zero-filled hex length + chunklen := datalen + for i := maxHexLen; i > 0; i-- { + buff[i-1] = hexconv.Char[chunklen&0b1111] + chunklen >>= 4 } copy(buff[maxHexLen:], crlf) // CRLF between length and data copy(buff[dataOffset+datalen:], crlf) // CRLF at the end of the data - restore := c.s.buff[:0] - c.s.buff = c.s.buff[buffOffset : len(c.s.buff)+dataOffset+datalen+crlflen] // extend buffer to include the written data - if err := c.s.flush(); err != nil { - return err - } - c.s.buff = restore - - if cap(c.s.buff)-dataOffset-datalen-crlflen <= cap(c.s.buff)>>6 { - // if free space left after the whole chunk was written is less than - // ~1.56% of the buffer total capacity, double the buffer size. - newsize := min(c.s.cfg.NET.WriteBufferSize.Maximal, cap(c.s.buff)<<1) - // the growth can be triggered even the buffer is already at its maximal size. Do nothing then. - if newsize > cap(c.s.buff) { - c.s.buff = make([]byte, 0, newsize) - } + // extend the buffer to include the written data + c.s.buff = c.s.buff[:len(c.s.buff)+dataOffset+datalen+crlflen] + + if !c.s.response.Buffered || len(c.s.buff) >= 3*cap(c.s.buff)>>2 { + return c.s.flush() } return nil @@ -532,7 +535,7 @@ func (c chunkedWriter) writechunk(maxHexLen, datalen int) error { } func (c chunkedWriter) Close() error { - if err := c.s.safeAppend(chunkZeroTrailer); err != nil { + if err := c.s.safeAppend([]byte("0\r\n\r\n")); err != nil { return err } @@ -544,66 +547,89 @@ type identityWriter struct { } func (i identityWriter) ReadFrom(r io.Reader) (total int64, err error) { - for { - n, err := r.Read(i.s.buff[len(i.s.buff):cap(i.s.buff)]) + streamSize := i.s.response.StreamSize + // identityWriter is used to write unsized streams if chunked transfer encoding + // isn't available (e.g. HTTP/1.0 clients). The stream is finalized by the connection + // close in this case. 
+ unsized := streamSize == -1 + + for total < streamSize || unsized { + boundary := cap(i.s.buff) + if !unsized { + boundary = min(boundary, int(streamSize-total)+len(i.s.buff)) + } + + n, err := r.Read(i.s.buff[len(i.s.buff):boundary]) total += int64(n) - i.s.buff = i.s.buff[0 : len(i.s.buff)+n] - if err := i.s.flush(); err != nil { - return 0, err + i.s.buff = i.s.buff[0 : len(i.s.buff)+n] + if !i.s.response.Buffered || len(i.s.buff) >= 3*cap(i.s.buff)>>2 { + // flush if unbuffered OR buffered and the buffer is >=3/4 full. + if ferr := i.s.flush(); ferr != nil { + return 0, ferr + } } switch err { case nil: case io.EOF: + if !unsized && total < streamSize { + // the stream is exhausted before it must have been. No good. + return total, status.ErrInternalServerError + } + return total, nil default: return 0, err } } + + return total, nil } func (i identityWriter) Write(p []byte) (int, error) { - err := i.s.safeAppend(p) - return len(p), err + return len(p), i.s.safeAppend(p) } func (i identityWriter) Close() error { return i.s.flush() } -func preprocessDefaultHeaders(headers map[string]string, acceptEncoding string) defaultHeaders { - processed := make(defaultHeaders, 0, len(headers)+1) +type excludablePair struct { + Excluded bool + kv.Pair +} + +func pairsFromMap(m map[string]string, acceptEncoding string) []excludablePair { + pairs := make([]excludablePair, 0, len(m)+1) + pairs = append(pairs, excludablePair{ + Pair: kv.Pair{Key: "Accept-Encoding", Value: acceptEncoding}, + }) - for key, value := range headers { - serialized := key + ": " + value + crlf - processed = append(processed, defaultHeader{ - // we let the GC release all the values of the map, as here we're using only - // the brand-new line without keeping the original string - Key: serialized[:len(key)], - Full: serialized, + for key, value := range m { + pairs = append(pairs, excludablePair{ + Pair: kv.Pair{Key: key, Value: value}, }) } - processed = append(processed, defaultHeader{ - Key: 
"Accept-Encoding", - Full: "Accept-Encoding: " + acceptEncoding + crlf, - }) - - return processed + return pairs } -type defaultHeader struct { - Excluded bool - Key string - Full string -} +type defaultHeaders []excludablePair -type defaultHeaders []defaultHeader +func newDefaultHeaders(pairs []excludablePair) defaultHeaders { + slices.SortFunc(pairs, func(a, b excludablePair) int { + return strings.Compare(a.Key, b.Key) + }) -func (d defaultHeaders) Exclude(key string) { - // TODO: binary search + return pairs +} +func (d defaultHeaders) Exclude(key string) { + // it's a perfect candidate for binary search, however in reality it introduced any visible + // benefit only starting at 10 and more default headers. If less, the penalty is also significant. + // The only optimization left to try out is stopping the iteration when `header.Key > key`, + // considering the headers are still sorted. for i, header := range d { if strutil.CmpFoldFast(header.Key, key) { header.Excluded = true @@ -613,8 +639,6 @@ func (d defaultHeaders) Exclude(key string) { } } -func (d defaultHeaders) Reset() { - for i := range d { - d[i].Excluded = false - } +func hexlen(n int) int { + return (bits.Len64(uint64(n))-1)>>2 + 1 } diff --git a/internal/protocol/http1/serializer_test.go b/internal/protocol/http1/serializer_test.go index bdcf1903..7be780ed 100644 --- a/internal/protocol/http1/serializer_test.go +++ b/internal/protocol/http1/serializer_test.go @@ -20,12 +20,13 @@ import ( "github.com/indigo-web/indigo/http/proto" "github.com/indigo-web/indigo/internal/codecutil" "github.com/indigo-web/indigo/internal/construct" + respfields "github.com/indigo-web/indigo/internal/response" "github.com/indigo-web/indigo/kv" "github.com/indigo-web/indigo/transport/dummy" "github.com/stretchr/testify/require" ) -var noCodecs = codecutil.NewCache(nil) +var noCodecs = codecutil.NewCache(nil, "identity") func BenchmarkSerializer(b *testing.B) { getRequest := func(cfg *config.Config, m method.Method) 
*http.Request { @@ -34,9 +35,10 @@ func BenchmarkSerializer(b *testing.B) { return request } - getSerializer := func(cfg *config.Config, m method.Method) *serializer { + getSerializer := func(cfg *config.Config, m method.Method, codecs ...codec.Codec) *serializer { buff := make([]byte, 0, cfg.NET.WriteBufferSize.Default) - return newSerializer(cfg, getRequest(cfg, method.GET), new(dummy.NopClient), noCodecs, buff) + cache := codecutil.NewCache(codecs, codecutil.AcceptEncoding(codecs)) + return newSerializer(cfg, getRequest(cfg, method.GET), new(dummy.NopClient), cache, buff) } getResponseWithHeaders := func(n int) *http.Response { @@ -107,100 +109,52 @@ func BenchmarkSerializer(b *testing.B) { }) b.Run("stream", func(b *testing.B) { - b.Run("sized 512b", func(b *testing.B) { - content := strings.Repeat("a", 512) - resp := http.NewResponse() - s := getSerializer(config.Default(), method.GET) - b.SetBytes(512) - b.ReportAllocs() - b.ResetTimer() - - for range b.N { - _ = s.Write(proto.HTTP11, resp.String(content)) - } - }) - - b.Run("unsized 32x16", func(b *testing.B) { - r := &circularReader{ - n: 32, - data: []byte(strings.Repeat("a", 16)), - } - resp := http.NewResponse() - s := getSerializer(config.Default(), method.GET) - b.SetBytes(512) - b.ReportAllocs() - b.ResetTimer() - - for range b.N { - r.n = 32 - _ = s.Write(proto.HTTP11, resp.Stream(r)) - } - }) - - b.Run("unsized 8x64", func(b *testing.B) { - r := &circularReader{ - n: 8, - data: []byte(strings.Repeat("a", 64)), - } - resp := http.NewResponse() - s := getSerializer(config.Default(), method.GET) - b.SetBytes(512) - b.ReportAllocs() - b.ResetTimer() - - for range b.N { - r.n = 8 - _ = s.Write(proto.HTTP11, resp.Stream(r)) - } - }) - - b.Run("sized 262144b", func(b *testing.B) { - // 262144 = 16 * 16384 = 8 * 32768 - content := strings.Repeat("a", 262144) - resp := http.NewResponse() - s := getSerializer(config.Default(), method.GET) - b.SetBytes(262144) - b.ReportAllocs() - b.ResetTimer() - - for range 
b.N { - _ = s.Write(proto.HTTP11, resp.String(content)) - } - }) - - b.Run("unsized 16x16384", func(b *testing.B) { - r := &circularReader{ - n: 16, - data: []byte(strings.Repeat("a", 16384)), + benchSized := func(n int) func(b *testing.B) { + return func(b *testing.B) { + content := strings.Repeat("a", n) + resp := http.NewResponse() + s := getSerializer(config.Default(), method.GET) + b.SetBytes(int64(n)) + b.ReportAllocs() + b.ResetTimer() + + for range b.N { + _ = s.Write(proto.HTTP11, resp.String(content)) + } } - resp := http.NewResponse() - s := getSerializer(config.Default(), method.GET) - b.SetBytes(262144) - b.ReportAllocs() - b.ResetTimer() + } - for range b.N { - r.n = 16 - _ = s.Write(proto.HTTP11, resp.Stream(r)) - } - }) + benchUnsized := func(n, chunklen int, compression ...string) func(b *testing.B) { + return func(b *testing.B) { + compress := "" + if len(compression) > 0 { + compress = compression[0] + } - b.Run("unsized 8x32768", func(b *testing.B) { - r := &circularReader{ - n: 8, - data: []byte(strings.Repeat("a", 32768)), + r := &circularReader{ + n: n, + data: []byte(strings.Repeat("a", chunklen)), + } + resp := http.NewResponse() + s := getSerializer(config.Default(), method.GET, codec.NewGZIP()) + b.SetBytes(int64(n * chunklen)) + b.ReportAllocs() + b.ResetTimer() + + for range b.N { + r.n = n + _ = s.Write(proto.HTTP11, resp.Stream(r).Compression(compress)) + } } - resp := http.NewResponse() - s := getSerializer(config.Default(), method.GET) - b.SetBytes(262144) - b.ReportAllocs() - b.ResetTimer() + } - for range b.N { - r.n = 8 - _ = s.Write(proto.HTTP11, resp.Stream(r)) - } - }) + b.Run("sized 512b", benchSized(512)) + b.Run("unsized 32x16", benchUnsized(32, 16)) + b.Run("unsized 8x64", benchUnsized(8, 64)) + b.Run("sized 262144b", benchSized(262144)) + b.Run("unsized 16x16384", benchUnsized(16, 16384)) + b.Run("gzipped 16x16384", benchUnsized(16, 16384, "gzip")) + b.Run("unsized 8x32768", benchUnsized(8, 32768)) }) } @@ -401,7 +355,7 
@@ func TestSerializer(t *testing.T) { t.Run("streams", func(t *testing.T) { request := newRequest(method.GET) - codecs := codecutil.NewCache([]codec.Codec{codec.NewGZIP()}) + codecs := codecutil.NewCache([]codec.Codec{codec.NewGZIP()}, codec.NewGZIP().Token()) s, w := getSerializer(nil, request, codecs) testSized := func(t *testing.T, method string, contentLength int, body string, contentEncoding ...string) { @@ -463,7 +417,7 @@ func TestSerializer(t *testing.T) { t.Run("unsized", func(t *testing.T) { w.Reset() request.Method = method.GET - resp := http.NewResponse().Stream(strings.NewReader(helloworld)) + resp := http.NewResponse().Stream(strings.NewReader(helloworld), -1) require.NoError(t, s.Write(proto.HTTP11, resp)) testUnsized(t, "GET", helloworld) }) @@ -479,7 +433,7 @@ func TestSerializer(t *testing.T) { t.Run("HEAD unsized", func(t *testing.T) { w.Reset() request.Method = method.HEAD - resp := http.NewResponse().Stream(strings.NewReader(helloworld)) + resp := http.NewResponse().Stream(strings.NewReader(helloworld), -1) require.NoError(t, s.Write(proto.HTTP11, resp)) testUnsized(t, "HEAD", "") }) @@ -489,7 +443,7 @@ func TestSerializer(t *testing.T) { t.Run("HEAD WriterTo", func(t *testing.T) { w.Reset() request.Method = method.HEAD - resp := http.NewResponse().SizedStream(strings.NewReader(helloworld), int64(len(helloworld))) + resp := http.NewResponse().Stream(strings.NewReader(helloworld)) require.NoError(t, s.Write(proto.HTTP11, resp)) testSized(t, "HEAD", len(helloworld), "") @@ -498,8 +452,11 @@ func TestSerializer(t *testing.T) { testGZIP := func(t *testing.T, resp *http.Response) { w.Reset() request.Method = method.GET + request.AcceptEncoding = []string{"gzip"} - require.NoError(t, s.Write(proto.HTTP11, resp.Compress("gzip"))) + // enforce gzip, because otherwise sized stream is unlikely to be compressed + // due to smallness. 
+ require.NoError(t, s.Write(proto.HTTP11, resp.Compression("gzip"))) wantBody := encodeGZIP(helloworld) testUnsized(t, "GET", string(wantBody), "gzip") } @@ -509,13 +466,13 @@ func TestSerializer(t *testing.T) { }) t.Run("unsized GZIP", func(t *testing.T) { - testGZIP(t, http.NewResponse().Stream(strings.NewReader(helloworld))) + testGZIP(t, http.NewResponse().Stream(strings.NewReader(helloworld), -1)) }) t.Run("sized WriterTo", func(t *testing.T) { w.Reset() request.Method = method.GET - resp := http.NewResponse().SizedStream(strings.NewReader(helloworld), int64(len(helloworld))) + resp := http.NewResponse().Stream(strings.NewReader(helloworld)) require.NoError(t, s.Write(proto.HTTP11, resp)) testSized(t, "GET", len(helloworld), helloworld) @@ -530,12 +487,19 @@ func TestSerializer(t *testing.T) { return s, string(w.Written()) } + testResp := func(t *testing.T, resp string) { + // the point is to make sure the written response is a valid HTTP + _, err := parseHTTP11Response("GET", []byte(resp)) + require.NoError(t, err) + } + t.Run("fill the buffer exactly full", func(t *testing.T) { // estimate headers length, so we can know how many bytes of body we need to trigger the growth const buffsize = 128 s, defaultResponse := writeResp(t, http.NewResponse(), buffsize, config.Default()) want := strings.Repeat("a", cap(s.buff)-len(defaultResponse)-1) - s, _ = writeResp(t, http.NewResponse().String(want), buffsize, config.Default()) + s, resp := writeResp(t, http.NewResponse().String(want), buffsize, config.Default()) + testResp(t, resp) require.Equal(t, cap(s.buff), buffsize) }) @@ -543,7 +507,8 @@ func TestSerializer(t *testing.T) { const buffsize = 128 s, defaultResponse := writeResp(t, http.NewResponse(), buffsize, config.Default()) want := strings.Repeat("a", cap(s.buff)-len(defaultResponse)+1) - s, _ = writeResp(t, http.NewResponse().String(want), buffsize, config.Default()) + s, resp := writeResp(t, http.NewResponse().String(want), buffsize, config.Default()) + 
testResp(t, resp) require.Greater(t, cap(s.buff), buffsize) }) @@ -552,7 +517,8 @@ func TestSerializer(t *testing.T) { cfg := config.Default() cfg.NET.WriteBufferSize.Maximal = buffsize b := strings.Repeat("a", buffsize) - s, _ := writeResp(t, http.NewResponse().String(b), buffsize-1, cfg) + s, resp := writeResp(t, http.NewResponse().String(b), buffsize-1, cfg) + testResp(t, resp) wantBuffsize := cap(slices.Grow(make([]byte, buffsize-1), 1)) require.Equal(t, wantBuffsize, cap(s.buff)) }) @@ -561,9 +527,10 @@ func TestSerializer(t *testing.T) { t.Run("writer", func(t *testing.T) { t.Run("identity", func(t *testing.T) { - t.Run("flush", func(t *testing.T) { + t.Run("flush buffered", func(t *testing.T) { s, w := getSerializer(nil, newRequest(method.GET), noCodecs) s.buff = make([]byte, 0, 16) + s.response = &respfields.Fields{Buffered: true} writer := identityWriter{s} _, err := writer.Write(bytes.Repeat([]byte("a"), 10)) require.NoError(t, err) @@ -580,6 +547,7 @@ func TestSerializer(t *testing.T) { client := dummy.NewMockClient().Journaling() buff := make([]byte, 0, cfg.NET.WriteBufferSize.Default) s := newSerializer(cfg, newRequest(method.GET), client, codecs, buff) + s.response = http.NewResponse().Expose() return s, client } @@ -609,18 +577,11 @@ func TestSerializer(t *testing.T) { require.NoError(t, writer.Close()) } - t.Run("elide zerofill", func(t *testing.T) { - s, w := init(config.Default(), noCodecs) - encodeChunked(t, s, "Hello, ", "world!") - want := "7\r\nHello, \r\n6\r\nworld!\r\n0\r\n\r\n" - require.Equal(t, want, string(w.Written())) - }) - - t.Run("use zerofill", func(t *testing.T) { + t.Run("preserve prior data", func(t *testing.T) { s, w := init(config.Default(), noCodecs) s.buff = append(s.buff, "Foo! "...) encodeChunked(t, s, "Hello, ", "world!") - want := "Foo! 7;0\r\nHello, \r\n6\r\nworld!\r\n0\r\n\r\n" + want := "Foo! 
007\r\nHello, \r\n006\r\nworld!\r\n0\r\n\r\n" require.Equal(t, want, string(w.Written())) }) @@ -640,15 +601,41 @@ func TestSerializer(t *testing.T) { cfg := config.Default() cfg.NET.WriteBufferSize.Default = writeBufferSize s, w := init(cfg, noCodecs) - encodeChunked(t, s, "Hello, world!") - want := "2\r\nHe\r\n9\r\nllo, worl\r\n2\r\nd!\r\n0\r\n\r\n" + encodeChunked2(t, s, "Hello, world!") + want := "2\r\nHe\r\n9\r\nllo, worl\r\n02\r\nd!\r\n0\r\n\r\n" require.Equal(t, want, string(w.Written())) }) t.Run("ReaderFrom", func(t *testing.T) { s, w := init(config.Default(), noCodecs) encodeChunked2(t, s, "Hello, ", "world!") - want := "7\r\nHello, \r\n6\r\nworld!\r\n0\r\n\r\n" + want := "007\r\nHello, \r\n006\r\nworld!\r\n0\r\n\r\n" + require.Equal(t, want, string(w.Written())) + }) + + t.Run("buffered", func(t *testing.T) { + const writeBufferSize = 32 + cfg := config.Default() + cfg.NET.WriteBufferSize.Default = writeBufferSize + cfg.NET.WriteBufferSize.Maximal = writeBufferSize + s, w := init(cfg, noCodecs) + s.response = &respfields.Fields{Buffered: true} + writer := chunkedWriter{s} + + writeChunks := func(t *testing.T, data ...string) { + for _, chunk := range data { + _, err := writer.Write([]byte(chunk)) + require.NoError(t, err) + } + } + + writeChunks(t, "a", "b", "c") + require.Empty(t, string(w.Written())) + + bigchunk := strings.Repeat("a", writeBufferSize) + writeChunks(t, bigchunk) + require.NoError(t, writer.Close()) + want := "01\r\na\r\n01\r\nb\r\n01\r\nc\r\n6\r\n" + bigchunk[:6] + "\r\n1a\r\n" + bigchunk[6:] + "\r\n0\r\n\r\n" require.Equal(t, want, string(w.Written())) }) }) diff --git a/internal/protocol/http1/suit.go b/internal/protocol/http1/suit.go index 99c84adc..9bbd4dae 100644 --- a/internal/protocol/http1/suit.go +++ b/internal/protocol/http1/suit.go @@ -97,7 +97,7 @@ func (s *Suit) serve(once bool) (ok bool) { request.Body.Reset(request) s.body.Reset(request) - transferEncoding := request.Encoding.Transfer + transferEncoding := 
request.TransferEncoding if !validateTransferEncodingTokens(transferEncoding) { resp := respond(request, s.router.OnError(request, status.ErrUnsupportedEncoding)) _ = s.Write(request.Protocol, resp) @@ -115,7 +115,7 @@ func (s *Suit) serve(once bool) (ok bool) { } } - if err = s.applyDecoders(request.Encoding.Content); err != nil { + if err = s.applyDecoders(request.ContentEncoding); err != nil { resp := respond(request, s.router.OnError(request, err)) _ = s.Write(request.Protocol, resp) return false @@ -192,6 +192,7 @@ func validateTransferEncodingTokens(tokens []string) bool { func (s *Suit) applyDecoders(tokens []string) error { request := s.Parser.request + bufferSize := s.Parser.cfg.NET.ReadBufferSize for i := len(tokens); i > 0; i-- { c := s.codecs.Get(tokens[i-1]) @@ -199,7 +200,7 @@ func (s *Suit) applyDecoders(tokens []string) error { return status.ErrUnsupportedEncoding } - if err := c.ResetDecompressor(request.Body.Fetcher); err != nil { + if err := c.ResetDecompressor(request.Body.Fetcher, bufferSize); err != nil { return status.ErrInternalServerError } diff --git a/internal/protocol/http1/suit_test.go b/internal/protocol/http1/suit_test.go index aac039ea..54e867ed 100644 --- a/internal/protocol/http1/suit_test.go +++ b/internal/protocol/http1/suit_test.go @@ -164,7 +164,7 @@ func getSuit(client transport.Client, codecs ...codec.Codec) (*Suit, *http.Reque cfg := config.Default() r := getInbuiltRouter() req := construct.Request(cfg, client) - suit := New(cfg, r, client, req, codecutil.NewCache(codecs)) + suit := New(cfg, r, client, req, codecutil.NewCache(codecs, codecutil.AcceptEncoding(codecs))) req.Body = http.NewBody(suit) return suit, req @@ -215,14 +215,10 @@ func TestSuit(t *testing.T) { request := construct.Request(config.Default(), dummy.NewNopClient()) request.Method = m request.Path = path - request.Headers = headers + request.Headers = headers.Add("Content-Length", strconv.Itoa(len(body))) - request.Encoding.Transfer = 
slices.Collect(headers.Values("Transfer-Encoding")) - request.Encoding.Content = slices.Collect(headers.Values("Content-Encoding")) - - if len(request.Encoding.Transfer) == 0 { - request.ContentLength = len(body) - } + request.TransferEncoding = slices.Collect(headers.Values("Transfer-Encoding")) + request.ContentEncoding = slices.Collect(headers.Values("Content-Encoding")) return serialize.Headers(request) + body } diff --git a/internal/response/fields.go b/internal/response/fields.go index 02993b64..af2cf788 100644 --- a/internal/response/fields.go +++ b/internal/response/fields.go @@ -10,8 +10,10 @@ import ( ) type Fields struct { + Buffered bool Code status.Code Status status.Status + AutoCompress bool ContentEncoding string Charset mime.Charset Stream io.Reader @@ -23,9 +25,10 @@ type Fields struct { func (f *Fields) Clear() { *f = Fields{ - Code: status.OK, - Buffer: f.Buffer[:0], - Headers: f.Headers[:0], - Cookies: f.Cookies[:0], + Code: status.OK, + Buffered: true, + Buffer: f.Buffer[:0], + Headers: f.Headers[:0], + Cookies: f.Cookies[:0], } } diff --git a/internal/strutil/helpers.go b/internal/strutil/helpers.go index 188a7eb8..8c3ebe31 100644 --- a/internal/strutil/helpers.go +++ b/internal/strutil/helpers.go @@ -5,7 +5,6 @@ import "strings" func LStripWS(str string) string { for i, c := range str { switch c { - // TODO: consider adding more whitespace characters? case ' ', '\t': default: return str[i:] @@ -31,6 +30,29 @@ func CutHeader(header string) (value, params string) { return header[:sep], LStripWS(header[sep+1:]) } +// ParseQualifier returns an int in range [0, 10] representing the qualifier value. +// All values below the 0.1 resolution are ignored. Invalid values result in 0. But +// keep in mind that 0 is also a valid value. 
+func ParseQualifier(q string) int { + const sampleQualifier = "q=p.q" + + if len(q) < len(sampleQualifier) { + return 0 + } + + // ignore all values below the 0.1 resolution + qualifier := ctoi(q[2])*10 + ctoi(q[4]) + if qualifier < 0 || qualifier > 10 { + qualifier = 0 + } + + return qualifier +} + +func ctoi(char byte) int { + return int(char - '0') +} + func Unquote(str string) string { if len(str) > 1 && str[0] == '"' && str[len(str)-1] == '"' { return str[1 : len(str)-1] @@ -38,3 +60,7 @@ func Unquote(str string) string { return str } + +func IsASCIINonprintable(c byte) bool { + return c < 0x20 || c > 0x7e +} diff --git a/internal/strutil/join.go b/internal/strutil/join.go deleted file mode 100644 index 953ab902..00000000 --- a/internal/strutil/join.go +++ /dev/null @@ -1,22 +0,0 @@ -package strutil - -import ( - "iter" - "strings" -) - -// Join works in the same way as the strings.Join does, except that it operates an iterator -// as opposed to greedy string slice. -func Join(elems iter.Seq[string], sep string) string { - var b strings.Builder - - for elem := range elems { - if b.Len() > 0 { - b.WriteString(sep) - } - - b.WriteString(elem) - } - - return b.String() -} diff --git a/internal/strutil/join_test.go b/internal/strutil/join_test.go deleted file mode 100644 index c96616f0..00000000 --- a/internal/strutil/join_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package strutil - -import ( - "iter" - "testing" - - "github.com/stretchr/testify/require" -) - -func asIterator(elems ...string) iter.Seq[string] { - return func(yield func(string) bool) { - for _, elem := range elems { - yield(elem) - } - } -} - -func TestJoin(t *testing.T) { - str := Join(asIterator(), ", ") - require.Empty(t, str) - - str = Join(asIterator("hello"), ", ") - require.Equal(t, "hello", str) - - str = Join(asIterator("hello", "world"), ", ") - require.Equal(t, "hello, world", str) - - str = Join(asIterator("hello", "world", "as usual"), ", ") - require.Equal(t, "hello, world, as usual", 
str) -} diff --git a/internal/strutil/url.go b/internal/strutil/url.go new file mode 100644 index 00000000..7e3f6416 --- /dev/null +++ b/internal/strutil/url.go @@ -0,0 +1,54 @@ +package strutil + +import ( + "strings" + + "github.com/indigo-web/indigo/internal/hexconv" +) + +// IsURLUnsafeChar tells whether it's safe to decode an urlencoded character. +func IsURLUnsafeChar(c byte) bool { + return c == '/' +} + +// URLDecode decodes an urlencoded string and tells whether the string was properly formed. +func URLDecode(str string) (string, bool) { + var b strings.Builder + b.Grow(len(str)) + s := str + + for len(s) > 0 { + percent := strings.IndexByte(s, '%') + if percent == -1 { + break + } + + b.WriteString(s[:percent]) + s = s[percent+1:] + if len(s) < 2 { + return "", false + } + + c1, c2 := s[0], s[1] + s = s[2:] + x, y := hexconv.Halfbyte[c1], hexconv.Halfbyte[c2] + if x|y == 0xFF { + return "", false + } + + char := (x << 4) | y + if IsASCIINonprintable(char) { + return "", false + } + if IsURLUnsafeChar(char) { + b.Write([]byte{'%', c1 | 0x20, c2 | 0x20}) + continue + } + + b.WriteByte(char) + } + + b.WriteString(s) + + return b.String(), true +} diff --git a/internal/strutil/url_test.go b/internal/strutil/url_test.go new file mode 100644 index 00000000..91ef51b8 --- /dev/null +++ b/internal/strutil/url_test.go @@ -0,0 +1,33 @@ +package strutil + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestURLDecode(t *testing.T) { + t.Run("base", func(t *testing.T) { + res, ok := URLDecode("%61") + require.True(t, ok) + require.Equal(t, "a", res) + + for i, tc := range []string{"abc", "%61bc", "a%62c", "ab%63", "%61%62%63"} { + res, ok = URLDecode(tc) + require.True(t, ok, i) + require.Equal(t, "abc", res, i) + } + }) + + t.Run("unsafe char", func(t *testing.T) { + res, ok := URLDecode("%61%2f") + require.True(t, ok) + require.Equal(t, "a%2f", res) + }) + + t.Run("unsafe normalization", func(t *testing.T) { + res, ok := URLDecode("%61%2F") 
+ require.True(t, ok) + require.Equal(t, "a%2f", res) + }) +} diff --git a/internal/timer/timer.go b/internal/timer/timer.go index dbc53c05..4a2a7084 100644 --- a/internal/timer/timer.go +++ b/internal/timer/timer.go @@ -24,8 +24,8 @@ func init() { go func() { for { - Time.Store(time.Now().UnixMilli()) time.Sleep(Resolution) + Time.Store(time.Now().UnixMilli()) } }() } diff --git a/kv/storage.go b/kv/storage.go index 86d5b305..681c93e0 100644 --- a/kv/storage.go +++ b/kv/storage.go @@ -54,11 +54,18 @@ func NewFromPairs(pairs []Pair) *Storage { // Add adds a new pair of key and value. func (s *Storage) Add(key, value string) *Storage { - // TODO: if `s.deleted` > 0, try to insert the key into the deleted fields - s.pairs = append(s.pairs, Pair{ - Key: key, - Value: value, - }) + if s.deleted == 0 { + s.pairs = append(s.pairs, Pair{key, value}) + return s + } + + for i, pair := range s.pairs { + if len(pair.Key) == 0 { + s.pairs[i] = Pair{key, value} + break + } + } + return s } diff --git a/router/inbuilt/aliases.go b/router/inbuilt/aliases.go deleted file mode 100644 index c8445ce3..00000000 --- a/router/inbuilt/aliases.go +++ /dev/null @@ -1,16 +0,0 @@ -package inbuilt - -import ( - "path" - - "github.com/indigo-web/indigo/http/method" - "github.com/indigo-web/indigo/router/inbuilt/mutator" -) - -// Alias makes an implicitly redirects to other endpoint by changing request path -// before a handler is called. In case of implicit redirect, original path is stored in -// Request.Env.AliasFrom. Optionally request methods can be set, such that only requests -// with those methods will be aliased. 
-func (r *Router) Alias(from, to string, forMethods ...method.Method) *Router { - return r.Mutator(mutator.Alias(path.Join(r.prefix, from), to, forMethods...)) -} diff --git a/router/inbuilt/defaulterrhandlers.go b/router/inbuilt/defaulterrhandlers.go deleted file mode 100644 index 3463e264..00000000 --- a/router/inbuilt/defaulterrhandlers.go +++ /dev/null @@ -1,23 +0,0 @@ -package inbuilt - -import ( - "github.com/indigo-web/indigo/http" - "github.com/indigo-web/indigo/http/status" -) - -func newErrorHandlers() errorHandlers { - return errorHandlers{ - AllErrors: defaultAllErrorsHandler, - status.MethodNotAllowed: defaultMethodNotAllowedHandler, - } -} - -func defaultAllErrorsHandler(request *http.Request) *http.Response { - return request.Respond().Error(request.Env.Error) -} - -func defaultMethodNotAllowedHandler(request *http.Request) *http.Response { - return request.Respond(). - Error(status.ErrMethodNotAllowed). - Header("Allow", request.Env.AllowedMethods) -} diff --git a/router/inbuilt/errorhandlers.go b/router/inbuilt/errorhandlers.go new file mode 100644 index 00000000..1e51653c --- /dev/null +++ b/router/inbuilt/errorhandlers.go @@ -0,0 +1,28 @@ +package inbuilt + +import ( + "github.com/indigo-web/indigo/http" + "github.com/indigo-web/indigo/http/method" + "github.com/indigo-web/indigo/http/status" +) + +func newErrorHandlers() errorHandlers { + return errorHandlers{ + AllErrors: genericErrorHandler, + status.MethodNotAllowed: generic405Handler, + } +} + +func genericErrorHandler(request *http.Request) *http.Response { + return http.Error(request, request.Env.Error) +} + +func generic405Handler(request *http.Request) *http.Response { + resp := request.Respond().Header("Allow", request.Env.AllowedMethods) + + if request.Method != method.OPTIONS { + resp.Code(status.MethodNotAllowed) + } + + return resp +} diff --git a/router/inbuilt/groups.go b/router/inbuilt/groups.go deleted file mode 100644 index 63b86ed5..00000000 --- a/router/inbuilt/groups.go +++ 
/dev/null @@ -1,40 +0,0 @@ -package inbuilt - -/* -This file is responsible for endpoint groups -*/ - -// Group creates a new router with pre-defined prefix for all paths. It'll automatically be -// merged into the head router on server start. Middlewares, applied on this router, will not -// affect the head router, but initially head router's middlewares will be inherited and will -// be called in the first order. Registering new error handlers will result in affecting error -// handlers among ALL the existing groups, including head router -func (r *Router) Group(prefix string) *Router { - subrouter := &Router{ - prefix: r.prefix + prefix, - registrar: newRegistrar(), - errHandlers: r.errHandlers, - } - - r.children = append(r.children, subrouter) - - return subrouter -} - -func (r *Router) prepare() error { - for _, child := range r.children { - if err := child.prepare(); err != nil { - return err - } - - if err := r.registrar.Merge(child.registrar); err != nil { - return err - } - - r.mutators = append(r.mutators, child.mutators...) - } - - r.applyMiddlewares() - - return nil -} diff --git a/router/inbuilt/inbuilt.go b/router/inbuilt/inbuilt.go index 45c94e39..cf8ccb9c 100644 --- a/router/inbuilt/inbuilt.go +++ b/router/inbuilt/inbuilt.go @@ -1,52 +1,184 @@ package inbuilt import ( + "path" + "github.com/indigo-web/indigo/http" "github.com/indigo-web/indigo/http/method" "github.com/indigo-web/indigo/http/status" "github.com/indigo-web/indigo/router" - "github.com/indigo-web/indigo/router/inbuilt/internal/radix" + "github.com/indigo-web/indigo/router/inbuilt/internal" + "github.com/indigo-web/indigo/router/inbuilt/mutator" "github.com/indigo-web/indigo/router/inbuilt/uri" ) +// Middleware works like a chain of nested calls, next may be even directly +// handler. 
But if we are not a closing middleware, we will call next +// middleware that is simply a partial middleware with already provided next +type Middleware func(next Handler, request *http.Request) *http.Response + var _ router.Builder = new(Router) -// Router is a built-in routing entity. It provides support for all the methods defined in -// the methods package, including shortcuts for those. It also supports dynamic routing -// (enabled automatically if dynamic path template is registered; otherwise more performant -// static-routing implementation is used). It also provides custom error handlers for any -// HTTP error that may occur during parsing the request or the routing of it by itself. -// By default, TRACE requests are supported (if no handler is attached, the request will be -// automatically processed), OPTIONS (including server-wide ones) and 405 Method Not Allowed -// errors in compliance with their HTTP semantics. +// Router is a recommended router for indigo. It features groups, middlewares, pre-middlewares, +// resources, automatic OPTIONS and TRACE response capabilities and dynamic routing (enabled +// automatically if any of routes is dynamic, otherwise more efficient map-based static routing +// is used.) type Router struct { - isRoot bool - prefix string - mutators []Mutator - middlewares []Middleware - registrar *registrar - children []*Router - errHandlers errorHandlers + enableTRACE bool + prefix string + mutators []Mutator + middlewares []Middleware + registrar *registrar + children []*Router + traceHandler Handler + errHandlers errorHandlers } // New constructs a new instance of inbuilt router func New() *Router { return &Router{ - isRoot: true, registrar: newRegistrar(), errHandlers: newErrorHandlers(), } } -// runtimeRouter is the actual router that'll be running. The reason to separate Router from runtimeRouter -// is the fact, that there is a lot of data that is used only at registering/initialization stage. 
+// AllErrors tells the Router.RouteError to use the passed error handler as a generic +// handler. A generic error handler is usually called only if no other was matched. +const AllErrors = status.Code(0) + +// Route registers a new endpoint. +func (r *Router) Route(method method.Method, path string, handler Handler, middlewares ...Middleware) *Router { + err := r.registrar.Add(r.prefix+path, method, compose(handler, middlewares)) + if err != nil { + panic(err) + } + + return r +} + +// TODO: update the error handling mechanism. It's way too tedious + +// RouteError adds an error handler for a corresponding HTTP error code. +// +// The following error codes may be registered: +// - AllErrors (called only if no other error handlers found) +// - status.BadRequest +// - status.NotFound +// - status.MethodNotAllowed +// - status.RequestEntityTooLarge +// - status.CloseConnection +// - status.RequestURITooLong +// - status.HeaderFieldsTooLarge +// - status.HTTPVersionNotSupported +// - status.UnsupportedMediaType +// - status.NotImplemented +// - status.RequestTimeout +// +// Note: if handler returned one of error codes above, error handler WON'T be called. +// Also, global middlewares, applied to the root router, will also be used for error handlers. +// However, global middlewares defined on groups won't be used. +// +// WARNING: calling this method from groups will affect ALL routers, including root +func (r *Router) RouteError(handler Handler, codes ...status.Code) *Router { + if len(codes) == 0 { + codes = append(codes, AllErrors) + } + + for _, code := range codes { + r.errHandlers[code] = handler + } + + return r +} + +// Use registers a new middleware in the group. +func (r *Router) Use(middlewares ...Middleware) *Router { + r.middlewares = append(r.middlewares, middlewares...) 
+ return r +} + +func (r *Router) applyMiddlewares() { + r.registrar.Apply(func(handler Handler) Handler { + return compose(handler, r.middlewares) + }) +} + +// Group creates a subrouter with its own scoping and path prefix. The scoping affects mainly +// middleware application rules: a new group inherits its parental middlewares, but middlewares, +// registered on the group, don't affect its parents ones. Parent middlewares are chained first, +// therefore will also be called earlier than middlewares registered directly on the group. +func (r *Router) Group(prefix string) *Router { + subrouter := &Router{ + prefix: r.prefix + prefix, + registrar: newRegistrar(), + errHandlers: r.errHandlers, + } + + r.children = append(r.children, subrouter) + + return subrouter +} + +func (r *Router) prepare() error { + for _, child := range r.children { + if err := child.prepare(); err != nil { + return err + } + + if err := r.registrar.Merge(child.registrar); err != nil { + return err + } + + r.mutators = append(r.mutators, child.mutators...) + } + + r.applyMiddlewares() + + return nil +} + +// Resource returns a new Resource object for a provided resource path. +func (r *Router) Resource(path string) Resource { + return Resource{ + group: r.Group(path), + } +} + +// Alias is an implicit redirect, made absolutely transparently before a specific handler is chosen. +// The original path is stored in Request.Env.AliasFrom. Optionally only specific methods can be set +// to be aliased. Otherwise, ANY requests matching alias will be aliased, which might not always be +// the desired behavior. +func (r *Router) Alias(from, to string, forMethods ...method.Method) *Router { + return r.Mutator(mutator.Alias(path.Join(r.prefix, from), to, forMethods...)) +} + +type Mutator = internal.Mutator + +// Mutator adds a new Mutator. Please note that groups scoping rules don't apply on them, only the +// execution order is affected. 
+func (r *Router) Mutator(mutator Mutator) *Router { + r.mutators = append(r.mutators, mutator) + return r +} + +// EnableTRACE allows the router to automatically respond to TRACE requests if there is no +// matching handler registered. To explore why it's better to keep the option disabled, see +// https://owasp.org/www-community/attacks/Cross_Site_Tracing +func (r *Router) EnableTRACE(flag bool) *Router { + r.enableTRACE = flag + return r +} + +// runtimeRouter is a compiled router. Router represents a "dummy" builder, while the actual +// action happens here. type runtimeRouter struct { - mutators []Mutator - traceBuff []byte - tree *radix.Node[endpoint] - routesMap routesMap - errHandlers errorHandlers - isStatic bool + enableTRACE bool + isStatic bool + tree radixTree + routesMap routesMap + errHandlers errorHandlers + serverOptions string + mutators []Mutator } func (r *Router) Build() router.Router { @@ -59,7 +191,7 @@ func (r *Router) Build() router.Router { isDynamic := r.registrar.IsDynamic() var ( rmap routesMap - tree *radix.Node[endpoint] + tree radixTree ) if isDynamic { tree = r.registrar.AsRadixTree() @@ -68,20 +200,20 @@ func (r *Router) Build() router.Router { } return &runtimeRouter{ - mutators: r.mutators, - tree: tree, - routesMap: rmap, - errHandlers: r.errHandlers, - isStatic: !isDynamic, + enableTRACE: r.enableTRACE, + isStatic: !isDynamic, + tree: tree, + routesMap: rmap, + errHandlers: r.errHandlers, + serverOptions: r.registrar.Options(r.enableTRACE), + mutators: r.mutators, } } // OnRequest processes the request func (r *runtimeRouter) OnRequest(request *http.Request) *http.Response { - r.runMutators(request) - - // TODO: should path normalization be implemented as a mutator? 
request.Path = uri.Normalize(request.Path) + r.runMutators(request) return r.onRequest(request) } @@ -120,10 +252,16 @@ func (r *runtimeRouter) OnError(request *http.Request, err error) *http.Response } func (r *runtimeRouter) onError(request *http.Request, err error) *http.Response { - if request.Method == method.TRACE && err == status.ErrMethodNotAllowed { - r.traceBuff = renderHTTPRequest(request, r.traceBuff) - - return traceResponse(request.Respond(), r.traceBuff) + switch { + case request.Method == method.OPTIONS && request.Path == "*": // server-wide options + return request.Respond().Header("Allow", r.serverOptions) + case request.Method == method.TRACE: + if !r.enableTRACE { + err = status.ErrMethodNotAllowed + break + } + + return traceHandler(request) } httpErr, ok := err.(status.HTTPError) @@ -146,8 +284,8 @@ func (r *runtimeRouter) onError(request *http.Request, err error) *http.Response } func (r *runtimeRouter) runMutators(request *http.Request) { - for _, mutator := range r.mutators { - mutator(request) + for _, mut := range r.mutators { + mut(request) } } @@ -166,6 +304,19 @@ func (r *Router) applyErrorHandlersMiddlewares() { } } +// compose produces an array of middlewares into the chain, represented by types.Handler +func compose(handler Handler, middlewares []Middleware) Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + handler = func(handler Handler, middleware Middleware) Handler { + return func(request *http.Request) *http.Response { + return middleware(handler, request) + } + }(handler, middlewares[i]) + } + + return handler +} + // getHandler looks up for a handler in the methodsMap. 
In case request method is HEAD, however // no matching handler is found, a handler for corresponding GET request will be retrieved func getHandler(reqMethod method.Method, mlut methodLUT) Handler { @@ -176,8 +327,3 @@ func getHandler(reqMethod method.Method, mlut methodLUT) Handler { return handler } - -// TODO: implement responding on such requests with a global list of all the available methods -func isServerWideOptions(req *http.Request) bool { - return req.Method == method.OPTIONS && req.Path == "*" -} diff --git a/router/inbuilt/inbuilt_test.go b/router/inbuilt/inbuilt_test.go index d76e9e52..61128fb7 100644 --- a/router/inbuilt/inbuilt_test.go +++ b/router/inbuilt/inbuilt_test.go @@ -8,16 +8,14 @@ import ( "testing" "github.com/indigo-web/indigo/config" + "github.com/indigo-web/indigo/http" + "github.com/indigo-web/indigo/http/method" + "github.com/indigo-web/indigo/http/status" "github.com/indigo-web/indigo/internal/construct" + "github.com/indigo-web/indigo/kv" "github.com/indigo-web/indigo/router" "github.com/indigo-web/indigo/transport/dummy" "github.com/stretchr/testify/assert" - - "github.com/indigo-web/indigo/http" - - "github.com/indigo-web/indigo/http/status" - - "github.com/indigo-web/indigo/http/method" "github.com/stretchr/testify/require" ) @@ -27,7 +25,15 @@ func readbody(t *testing.T, r io.Reader) string { return string(data) } -func BenchmarkRouter_OnRequest_Static(b *testing.B) { +func getRequest(m method.Method, path string) *http.Request { + request := construct.Request(config.Default(), dummy.NewNopClient()) + request.Method = m + request.Path = path + + return request +} + +func BenchmarkStatic(b *testing.B) { raw := New() GETRootRequest := getRequest(method.GET, "/") @@ -80,11 +86,11 @@ func BenchmarkRouter_OnRequest_Static(b *testing.B) { } func TestRoute(t *testing.T) { - raw := New() - raw.Route(method.GET, "/", http.Respond) - raw.Route(method.POST, "/", http.Respond) - raw.Route(method.POST, "/hello", http.Respond) - r := 
raw.Build() + r := New(). + Route(method.GET, "/", http.Respond). + Route(method.POST, "/", http.Respond). + Route(method.POST, "/hello", http.Respond). + Build() t.Run("GET /", func(t *testing.T) { request := getRequest(method.GET, "/") @@ -112,100 +118,197 @@ func TestRoute(t *testing.T) { require.Equal(t, status.OK, resp.Expose().Code) require.Nil(t, resp.Expose().Stream) }) + + t.Run("OPTIONS", func(t *testing.T) { + testOPTIONS := func(path, wantAllow string, enableTRACE bool) func(t *testing.T) { + return func(t *testing.T) { + r := New(). + EnableTRACE(enableTRACE). + Get("/", http.Respond). + Post("/", http.Respond). + Get("/hello", http.Respond). + Build() + + req := getRequest(method.OPTIONS, path) + resp := r.OnRequest(req) + require.Equal(t, 200, int(resp.Expose().Code)) + headers := kv.NewFromPairs(resp.Expose().Headers) + require.Equal(t, wantAllow, headers.Value("Allow")) + } + } + + t.Run("on specific endpoint", testOPTIONS("/", "GET, HEAD, POST", false)) + t.Run("server-wide", testOPTIONS("*", "GET, HEAD, OPTIONS", false)) + t.Run("server-wide with TRACE", testOPTIONS("*", "GET, HEAD, OPTIONS, TRACE", true)) + }) + + t.Run("escaping", func(t *testing.T) { + testRoute := func(r router.Router, path string) func(t *testing.T) { + return func(t *testing.T) { + request := getRequest(method.GET, path) + resp := r.OnRequest(request) + require.Equal(t, 200, int(resp.Expose().Code)) + } + } + + newRouter := func(dynamic bool) router.Router { + r := New(). + Get("/foo%2fbar", http.Respond). 
+ Get("/foo%3abar", http.Respond) + + if dynamic { + r.Get("/unreachable route/:", http.Respond) + } + + return r.Build() + } + + test := func(r router.Router) func(t *testing.T) { + return func(t *testing.T) { + t.Run("escaped slash", testRoute(r, "/foo%2fbar")) + t.Run("unescaped colon", testRoute(r, "/foo:bar")) + } + } + + t.Run("static", test(newRouter(false))) + t.Run("dynamic", test(newRouter(true))) + }) } func TestDynamic(t *testing.T) { - t.Run("first level", func(t *testing.T) { - raw := New(). - Get("/:name", func(request *http.Request) *http.Response { - return http.String(request, request.Vars.Value("name")) + testDynamic := func(t *testing.T, path, want, key string, templates ...string) { + r := New() + + for _, template := range templates { + r.Get(template, func(request *http.Request) *http.Response { + return request.Respond().Status(request.Vars.Value(key)) }) - r := raw.Build() + } - request := getRequest(method.GET, "/hello") - resp := r.OnRequest(request) - require.Equal(t, "hello", readbody(t, resp.Expose().Stream)) + request := getRequest(method.GET, path) + resp := r.Build().OnRequest(request) + require.Equal(t, want, resp.Expose().Status) + } + + t.Run("base", func(t *testing.T) { + testDynamic(t, "/Pavlo", "Pavlo", "name", "/", "/:name", "/hello") + testDynamic(t, "/user/123/edit", "123", "id", "/user/:id/edit") }) - t.Run("second level", func(t *testing.T) { - raw := New(). 
- Get("/hello/:name", func(request *http.Request) *http.Response { - return http.String(request, request.Vars.Value("name")) - }) - r := raw.Build() + t.Run("anonymous", func(t *testing.T) { + testDynamic(t, "/Pavlo", "", "", "/:") + testDynamic(t, "/Pavlo/hello", "", "", "/:/hello") + }) - request := getRequest(method.GET, "/hello/pavlo") - resp := r.OnRequest(request) - require.Equal(t, "pavlo", readbody(t, resp.Expose().Stream)) + t.Run("prefix", func(t *testing.T) { + testDynamic(t, "/user123", "123", "id", "/user:id", "/user:id/edit") + testDynamic(t, "/user123/edit", "123", "id", "/user:id", "/user:id/edit") }) +} - t.Run("in the middle", func(t *testing.T) { - r := New(). - Get("/api/:method/doc", func(request *http.Request) *http.Response { - return http.String(request, request.Vars.Value("method")) - }). - Build() +func TestMethodShorthands(t *testing.T) { + r := New() + testShorthand := func( + t *testing.T, router *Router, route func(string, Handler, ...Middleware) *Router, method method.Method, + ) { + route("/", http.Respond) + require.Contains(t, router.registrar.endpoints, "/") + require.NotNil(t, router.registrar.endpoints["/"][method]) + } - request := getRequest(method.GET, "/api/getUser/doc") - resp := r.OnRequest(request) - require.Equal(t, "getUser", readbody(t, resp.Expose().Stream)) - }) + testShorthand(t, r, r.Get, method.GET) + testShorthand(t, r, r.Head, method.HEAD) + testShorthand(t, r, r.Post, method.POST) + testShorthand(t, r, r.Put, method.PUT) + testShorthand(t, r, r.Delete, method.DELETE) + testShorthand(t, r, r.Connect, method.CONNECT) + testShorthand(t, r, r.Options, method.OPTIONS) + testShorthand(t, r, r.Trace, method.TRACE) + testShorthand(t, r, r.Patch, method.PATCH) + testShorthand(t, r, r.Mkcol, method.MKCOL) + testShorthand(t, r, r.Move, method.MOVE) + testShorthand(t, r, r.Copy, method.COPY) + testShorthand(t, r, r.Lock, method.LOCK) + testShorthand(t, r, r.Unlock, method.UNLOCK) + testShorthand(t, r, r.Propfind, 
method.PROPFIND) + testShorthand(t, r, r.Proppatch, method.PROPPATCH) +} - t.Run("anonymous section", func(t *testing.T) { - r := New(). - Get("/:", func(request *http.Request) *http.Response { - return http.String(request, "yay") - }). - Build() +type callstack struct { + chain []int +} - request := getRequest(method.GET, "/api") - resp := r.OnRequest(request) - require.Equal(t, "yay", readbody(t, resp.Expose().Stream)) - request = getRequest(method.GET, "/api/second-level") - resp = r.OnRequest(request) - require.Equal(t, int(status.NotFound), int(resp.Expose().Code)) - }) +func (c *callstack) Push(middleware int) { + c.chain = append(c.chain, middleware) } -func testMethodShorthand( - t *testing.T, router *Router, - route func(string, Handler, ...Middleware) *Router, - method method.Method, -) { - route("/", http.Respond) - require.Contains(t, router.registrar.endpoints, "/") - require.NotNil(t, router.registrar.endpoints["/"][method]) +func (c *callstack) Chain() []int { + return c.chain } -func TestMethodShorthands(t *testing.T) { - r := New() +func (c *callstack) Clear() { + c.chain = c.chain[:0] +} - t.Run("GET", func(t *testing.T) { - testMethodShorthand(t, r, r.Get, method.GET) - }) - t.Run("HEAD", func(t *testing.T) { - testMethodShorthand(t, r, r.Head, method.HEAD) - }) - t.Run("POST", func(t *testing.T) { - testMethodShorthand(t, r, r.Post, method.POST) - }) - t.Run("PUT", func(t *testing.T) { - testMethodShorthand(t, r, r.Put, method.PUT) - }) - t.Run("DELETE", func(t *testing.T) { - testMethodShorthand(t, r, r.Delete, method.DELETE) - }) - t.Run("CONNECT", func(t *testing.T) { - testMethodShorthand(t, r, r.Connect, method.CONNECT) - }) - t.Run("OPTIONS", func(t *testing.T) { - testMethodShorthand(t, r, r.Options, method.OPTIONS) +func getMiddleware(mware int, stack *callstack) Middleware { + return func(next Handler, request *http.Request) *http.Response { + stack.Push(mware) + + return next(request) + } +} + +func TestMiddlewares(t *testing.T) { + 
const ( + m1 int = iota + 1 + m2 + m3 + m4 + m5 + m6 + m7 + ) + + stack := new(callstack) + raw := New(). + Use(getMiddleware(m1, stack)). + Get("/", http.Respond, getMiddleware(m2, stack)) + + api := raw.Group("/api"). + Use(getMiddleware(m3, stack)) + + api.Group("/v1"). + Get("/hello", http.Respond, getMiddleware(m6, stack)). + Use(getMiddleware(m4, stack)) + + api.Group("/v2"). + Use(getMiddleware(m5, stack)). + Get("/world", http.Respond, getMiddleware(m7, stack)) + + r := raw.Build() + + t.Run("/", func(t *testing.T) { + request := getRequest(method.GET, "/") + response := r.OnRequest(request) + require.Equal(t, status.OK, response.Expose().Code) + require.Equal(t, []int{m1, m2}, stack.Chain()) + stack.Clear() }) - t.Run("TRACE", func(t *testing.T) { - testMethodShorthand(t, r, r.Trace, method.TRACE) + + t.Run("/api/v1/hello", func(t *testing.T) { + request := getRequest(method.GET, "/api/v1/hello") + response := r.OnRequest(request) + require.Equal(t, status.OK, response.Expose().Code) + require.Equal(t, []int{m1, m3, m4, m6}, stack.Chain()) + stack.Clear() }) - t.Run("PATCH", func(t *testing.T) { - testMethodShorthand(t, r, r.Patch, method.PATCH) + + t.Run("/api/v2/world", func(t *testing.T) { + request := getRequest(method.GET, "/api/v2/world") + response := r.OnRequest(request) + require.Equal(t, status.OK, response.Expose().Code) + require.Equal(t, []int{m1, m3, m5, m7}, stack.Chain()) + stack.Clear() }) } @@ -241,62 +344,52 @@ func TestResource(t *testing.T) { r := raw.Build() - t.Run("Root", func(t *testing.T) { + t.Run("root", func(t *testing.T) { require.Equal(t, status.OK, r.OnRequest(getRequest(method.GET, "/")).Expose().Code) require.Equal(t, status.OK, r.OnRequest(getRequest(method.POST, "/")).Expose().Code) }) - t.Run("Group", func(t *testing.T) { + t.Run("on group", func(t *testing.T) { require.Equal(t, status.OK, r.OnRequest(getRequest(method.GET, "/api/stat")).Expose().Code) require.Equal(t, status.OK, r.OnRequest(getRequest(method.POST, 
"/api/stat")).Expose().Code) }) -} -func TestResource_Methods(t *testing.T) { - echoMethod := func(req *http.Request) *http.Response { - return req.Respond().Status(req.Method.String()) - } + t.Run("shorthands", func(t *testing.T) { + echoMethod := func(req *http.Request) *http.Response { + return req.Respond().Status(req.Method.String()) + } - raw := New() - raw.Resource("/"). - Get(echoMethod). - Head(echoMethod). - Post(echoMethod). - Put(echoMethod). - Delete(echoMethod). - Connect(echoMethod). - Options(echoMethod). - Trace(echoMethod). - Patch(echoMethod). - Mkcol(echoMethod). - Move(echoMethod). - Copy(echoMethod). - Lock(echoMethod). - Unlock(echoMethod). - Propfind(echoMethod). - Proppatch(echoMethod) + raw := New() + raw.Resource("/"). + Get(echoMethod). + Head(echoMethod). + Post(echoMethod). + Put(echoMethod). + Delete(echoMethod). + Connect(echoMethod). + Options(echoMethod). + Trace(echoMethod). + Patch(echoMethod). + Mkcol(echoMethod). + Move(echoMethod). + Copy(echoMethod). + Lock(echoMethod). + Unlock(echoMethod). + Propfind(echoMethod). + Proppatch(echoMethod) - r := raw.Build() + r := raw.Build() - for _, m := range method.List { - resp := r.OnRequest(getRequest(m, "/")).Expose() - if assert.Equal(t, int(status.OK), int(resp.Code)) { - assert.Equal(t, m.String(), resp.Status) + for _, m := range method.List { + resp := r.OnRequest(getRequest(m, "/")).Expose() + if assert.Equal(t, int(status.OK), int(resp.Code)) { + assert.Equal(t, m.String(), resp.Status) + } } - } -} - -func TestRouter_MethodNotAllowed(t *testing.T) { - r := New(). - Get("/", http.Respond). - Build() - - request := getRequest(method.POST, "/") - response := r.OnRequest(request) - require.Equal(t, status.MethodNotAllowed, response.Expose().Code) + }) } -func TestRouter_RouteError(t *testing.T) { +func TestRouterErrors(t *testing.T) { r := New(). RouteError(func(req *http.Request) *http.Response { return req.Respond(). 
@@ -305,14 +398,24 @@ func TestRouter_RouteError(t *testing.T) { }, status.BadRequest). Build() - t.Run("status.ErrBadRequest", func(t *testing.T) { + t.Run("ErrMethodNotAllowed", func(t *testing.T) { + r := New(). + Get("/", http.Respond). + Build() + + request := getRequest(method.POST, "/") + response := r.OnRequest(request) + require.Equal(t, status.MethodNotAllowed, response.Expose().Code) + }) + + t.Run("ErrBadRequest", func(t *testing.T) { request := getRequest(method.GET, "/") resp := r.OnError(request, status.ErrBadRequest) require.Equal(t, status.Teapot, resp.Expose().Code) require.Equal(t, status.ErrBadRequest.Error(), readbody(t, resp.Expose().Stream)) }) - t.Run("status.ErrURIDecoding (also bad request)", func(t *testing.T) { + t.Run("ErrURIDecoding (also bad request)", func(t *testing.T) { request := getRequest(method.GET, "/") resp := r.OnError(request, status.ErrURLDecoding) require.Equal(t, status.Teapot, resp.Expose().Code) diff --git a/router/inbuilt/internal/radix/tree.go b/router/inbuilt/internal/radix/tree.go index 720a091b..11960b13 100644 --- a/router/inbuilt/internal/radix/tree.go +++ b/router/inbuilt/internal/radix/tree.go @@ -1,9 +1,14 @@ package radix import ( + "cmp" "errors" + "fmt" + "slices" + "strconv" "strings" + "github.com/indigo-web/indigo/internal/strutil" "github.com/indigo-web/indigo/kv" ) @@ -41,12 +46,11 @@ func (n *Node[T]) Lookup(key string, wildcards *kv.Storage) (value T, found bool loop: for len(key) > 0 { - for _, p := range node.predecessors { - if strings.HasPrefix(key, p.value) { - key = key[len(p.value):] - node = p - continue loop - } + p, found := node.findPredecessor(key) + if found { + key = key[len(p.value):] + node = p + continue loop + } if node.dyn == nil { @@ -94,6 +98,34 @@ loop: return node.payload, node.isLeaf } +func (n *Node[T]) findPredecessor(key string) (*Node[T], bool) { + // Proudly stolen from sort package. + // Inlining is faster than calling BinarySearchFunc with a lambda.
+ u := len(n.predecessors) + // Define x[-1] < target and x[n] >= target. + // Invariant: x[i-1] < target, x[j] >= target. + i, j := 0, u + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + pred := n.predecessors[h] + keyprefix := min(len(key), len(pred.value)) + if cmp.Less(pred.value, key[:keyprefix]) { + i = h + 1 // preserves x[i-1] < target + } else { + j = h // preserves x[j] >= target + } + } + + if i >= u { + return nil, false + } + + pred := n.predecessors[i] + // i == j, x[i-1] < target, and x[j] (= x[i]) >= target => answer is i. + return pred, i < u && strings.HasPrefix(key, pred.value) +} + func addWildcard(wildcard, value string, into *kv.Storage) { if len(wildcard) > 0 { into.Add(wildcard, value) @@ -101,7 +133,12 @@ func addWildcard(wildcard, value string, into *kv.Storage) { } func (n *Node[T]) Insert(key string, value T) error { - return n.insert(splitPath(key), value) + str, ok := strutil.URLDecode(key) + if !ok { + return fmt.Errorf("poorly encoded path: %s", strconv.Quote(key)) + } + + return n.insert(splitPath(str), value) } func (n *Node[T]) insert(segs []pathSegment, value T) error { @@ -164,13 +201,25 @@ func (n *Node[T]) insert(segs []pathSegment, value T) error { } newNode := &Node[T]{value: seg.Value} - n.predecessors = append(n.predecessors, newNode) + n.appendPredecessor(newNode) return newNode.insert(segs[1:], value) } -func IsDynamicTemplate(str string) bool { - return strings.IndexByte(str, ':') != -1 +func (n *Node[T]) appendPredecessor(node *Node[T]) { + for i, pred := range n.predecessors { + if node.value < pred.value { + n.predecessors = slices.Insert(n.predecessors, i, node) + return + } + } + + n.predecessors = append(n.predecessors, node) +} + +func IsDynamicTemplate(path string) bool { + segs := splitPath(path) + return len(segs) > 1 || segs[0].IsWildcard } func truncCommon(segs []pathSegment, length int) []pathSegment { @@ -197,15 +246,15 @@ type pathSegment struct { Value string } 
-func splitPath(str string) (result []pathSegment) { +func splitPath(str string) (path []pathSegment) { for len(str) > 0 { colon := strings.IndexByte(str, ':') if colon == -1 { - result = append(result, pathSegment{false, false, str}) + path = append(path, pathSegment{false, false, str}) break } - result = append(result, pathSegment{false, false, str[:colon]}) + path = append(path, pathSegment{false, false, str[:colon]}) str = str[colon+1:] boundary := strings.IndexByte(str, '/') @@ -216,11 +265,11 @@ func splitPath(str string) (result []pathSegment) { wildcard := str[:boundary] greedy := false if strings.HasSuffix(wildcard, "...") { - wildcard = wildcard[:len(wildcard)-3] + wildcard = wildcard[:len(wildcard)-len("...")] greedy = true } - result = append(result, pathSegment{true, greedy, wildcard}) + path = append(path, pathSegment{true, greedy, wildcard}) if boundary < len(str) { boundary++ } @@ -228,5 +277,5 @@ func splitPath(str string) (result []pathSegment) { str = str[boundary:] } - return result + return path } diff --git a/router/inbuilt/internal/radix/tree_test.go b/router/inbuilt/internal/radix/tree_test.go index 6daf4175..15f18630 100644 --- a/router/inbuilt/internal/radix/tree_test.go +++ b/router/inbuilt/internal/radix/tree_test.go @@ -2,6 +2,8 @@ package radix import ( "fmt" + "iter" + "slices" "strconv" "strings" "testing" @@ -10,25 +12,105 @@ import ( "github.com/stretchr/testify/require" ) -func BenchmarkTreeMatch(b *testing.B) { - tree := New[int]() - tree.Insert("/hello/world", 1) - tree.Insert("/hello/whopper", 2) - tree.Insert("/henry/world", 3) - tree.Insert("/hello/world/somewhere", 4) - b.SetBytes(int64(len("/hello/world/somewhere"))) - b.ResetTimer() - - for range b.N { - _, _ = tree.Lookup("/hello/world/somewhere", nil) +func nextstr(str []byte) { + for i := len(str); i > 0; i-- { + str[i-1]++ + if str[i-1] <= 'Z' { + return + } + + str[i-1] = 'A' + } +} + +func produceSegments(n, seglen int) []string { + segments := make([]string, n) + 
str := []byte(strings.Repeat("A", seglen)) + + for i := range n { + segments[i] = string(str) + nextstr(str) } - // - //b.Run("10 static", func(b *testing.B) { - // paths := make([]string, 0, 10) - // for i := range 10 { - // - // } - //}) + + return segments +} + +func produceStrings(width, depth, strlen int) (iter.Seq[string], string) { + allSegs := produceSegments(width, strlen) + segments, carry := allSegs[:width-1], allSegs[width-1] + + return func(yield func(string) bool) { + var base string + + for i := 0; i < depth; i++ { + for _, segment := range segments { + if !yield(base + segment) { + break + } + } + + base += carry + } + + yield(base) + }, strings.Repeat(carry, depth) +} + +func TestBench(t *testing.T) { + it, lastKey := produceStrings(3, 3, 2) + require.Equal(t, "ACACAC", lastKey) + require.Equal(t, + []string{"AA", "AB", "ACAA", "ACAB", "ACACAA", "ACACAB", "ACACAC"}, + slices.Collect(it), + ) +} + +func BenchmarkTree(b *testing.B) { + noError := func(err error) { + if err != nil { + panic(err.Error()) + } + } + + runBench := func(width, depth int) func(b *testing.B) { + return func(b *testing.B) { + const seglen = 8 + tree := New[int]() + it, key := produceStrings(width, depth, seglen) + for str := range it { + noError(tree.Insert(str, 1)) + } + + b.ResetTimer() + + for range b.N { + _, _ = tree.Lookup(key, nil) + } + } + } + + b.Run("deep", func(b *testing.B) { + b.Run("1x128", runBench(1, 128)) + b.Run("8x8", runBench(8, 8)) + b.Run("8x64", runBench(8, 64)) + b.Run("8x128", runBench(8, 128)) + b.Run("8x256", runBench(68, 128)) + }) + + b.Run("wide", func(b *testing.B) { + b.Run("128x1", runBench(128, 1)) + b.Run("32x8", runBench(32, 8)) + b.Run("64x8", runBench(64, 8)) + b.Run("128x8", runBench(128, 8)) + b.Run("256x8", runBench(256, 8)) + }) + + b.Run("quadratic", func(b *testing.B) { + b.Run("16x16", runBench(16, 16)) + b.Run("32x32", runBench(32, 32)) + b.Run("64x64", runBench(64, 64)) + b.Run("128x128", runBench(128, 64)) + }) } func 
TestTree(t *testing.T) { @@ -59,6 +141,14 @@ func TestTree(t *testing.T) { require.Equal(t, "wow", wildcards.Value("id")) }) + test := func(t *testing.T, tree *Node[int], path string, value int, wKey, wVal string) { + w := kv.New() + val, found := tree.Lookup(path, w) + require.True(t, found) + require.Equal(t, value, val) + require.Equal(t, wVal, w.Value(wKey)) + } + t.Run("dynamic in the middle", func(t *testing.T) { tree := New[int]() require.NoError(t, tree.Insert("/user/:id", 1)) @@ -140,14 +230,7 @@ func TestTree(t *testing.T) { }) } -func test(t *testing.T, tree *Node[int], path string, value int, wKey, wVal string) { - w := kv.New() - val, found := tree.Lookup(path, w) - require.True(t, found) - require.Equal(t, value, val) - require.Equal(t, wVal, w.Value(wKey)) -} - +// isn't used anymore. Left just in case the tree needs to be debugged. func printTree(node *Node[int], depth int) { for _, p := range node.predecessors { fmt.Print(strings.Repeat("-", depth)) diff --git a/router/inbuilt/middleware/autocompress.go b/router/inbuilt/middleware/autocompress.go new file mode 100644 index 00000000..8ac60a9a --- /dev/null +++ b/router/inbuilt/middleware/autocompress.go @@ -0,0 +1,11 @@ +package middleware + +import ( + "github.com/indigo-web/indigo/http" + "github.com/indigo-web/indigo/router/inbuilt" +) + +// Autocompress prepends automatic compression options to responses. +func Autocompress(next inbuilt.Handler, request *http.Request) *http.Response { + return next(request).Compress() +} diff --git a/router/inbuilt/middlewares.go b/router/inbuilt/middlewares.go deleted file mode 100644 index 16fda619..00000000 --- a/router/inbuilt/middlewares.go +++ /dev/null @@ -1,37 +0,0 @@ -package inbuilt - -import ( - "github.com/indigo-web/indigo/http" -) - -// Middleware works like a chain of nested calls, next may be even directly -// handler. 
But if we are not a closing middleware, we will call next -// middleware that is simply a partial middleware with already provided next -type Middleware func(next Handler, request *http.Request) *http.Response - -// Use adds middlewares into the global list of a group's middlewares. But they will -// be applied only after the server will be started -func (r *Router) Use(middlewares ...Middleware) *Router { - r.middlewares = append(r.middlewares, middlewares...) - - return r -} - -func (r *Router) applyMiddlewares() { - r.registrar.Apply(func(handler Handler) Handler { - return compose(handler, r.middlewares) - }) -} - -// compose produces an array of middlewares into the chain, represented by types.Handler -func compose(handler Handler, middlewares []Middleware) Handler { - for i := len(middlewares) - 1; i >= 0; i-- { - handler = func(handler Handler, middleware Middleware) Handler { - return func(request *http.Request) *http.Response { - return middleware(handler, request) - } - }(handler, middlewares[i]) - } - - return handler -} diff --git a/router/inbuilt/middlewares_test.go b/router/inbuilt/middlewares_test.go deleted file mode 100644 index 6d998175..00000000 --- a/router/inbuilt/middlewares_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package inbuilt - -import ( - "testing" - - "github.com/indigo-web/indigo/config" - "github.com/indigo-web/indigo/http/method" - "github.com/indigo-web/indigo/http/status" - "github.com/indigo-web/indigo/internal/construct" - "github.com/indigo-web/indigo/transport/dummy" - "github.com/stretchr/testify/require" - - "github.com/indigo-web/indigo/http" -) - -/* -This file is separated because it is a bit specific and contains a lot -of specific stuff for testing only middlewares. 
Decided it's better to -separate it from all the other tests -*/ - -type middleware uint8 - -const ( - m1 middleware = iota + 1 - m2 - m3 - m4 - m5 - m6 - m7 -) - -type callstack struct { - chain []middleware -} - -func (c *callstack) Push(ware middleware) { - c.chain = append(c.chain, ware) -} - -func (c *callstack) Chain() []middleware { - return c.chain -} - -func (c *callstack) Clear() { - c.chain = c.chain[:0] -} - -func getMiddleware(mware middleware, stack *callstack) Middleware { - return func(next Handler, request *http.Request) *http.Response { - stack.Push(mware) - - return next(request) - } -} - -func getRequest(m method.Method, path string) *http.Request { - request := construct.Request(config.Default(), dummy.NewNopClient()) - request.Method = m - request.Path = path - - return request -} - -func TestMiddlewares(t *testing.T) { - stack := new(callstack) - raw := New(). - Use(getMiddleware(m1, stack)). - Get("/", http.Respond, getMiddleware(m2, stack)) - - api := raw.Group("/api"). - Use(getMiddleware(m3, stack)) - - api.Group("/v1"). - Get("/hello", http.Respond, getMiddleware(m6, stack)). - Use(getMiddleware(m4, stack)) - - api.Group("/v2"). - Use(getMiddleware(m5, stack)). 
- Get("/world", http.Respond, getMiddleware(m7, stack)) - - r := raw.Build() - - t.Run("/", func(t *testing.T) { - request := getRequest(method.GET, "/") - response := r.OnRequest(request) - require.Equal(t, status.OK, response.Expose().Code) - require.Equal(t, []middleware{m1, m2}, stack.Chain()) - stack.Clear() - }) - - t.Run("/api/v1/hello", func(t *testing.T) { - request := getRequest(method.GET, "/api/v1/hello") - response := r.OnRequest(request) - require.Equal(t, status.OK, response.Expose().Code) - require.Equal(t, []middleware{m1, m3, m4, m6}, stack.Chain()) - stack.Clear() - }) - - t.Run("/api/v2/world", func(t *testing.T) { - request := getRequest(method.GET, "/api/v2/world") - response := r.OnRequest(request) - require.Equal(t, status.OK, response.Expose().Code) - require.Equal(t, []middleware{m1, m3, m5, m7}, stack.Chain()) - stack.Clear() - }) -} diff --git a/router/inbuilt/mutator.go b/router/inbuilt/mutator.go deleted file mode 100644 index 46acfb63..00000000 --- a/router/inbuilt/mutator.go +++ /dev/null @@ -1,15 +0,0 @@ -package inbuilt - -import ( - "github.com/indigo-web/indigo/router/inbuilt/internal" -) - -type Mutator = internal.Mutator - -// Mutator adds a new mutator. 
-// -// NOTE: registering them on groups will affect only the order of execution -func (r *Router) Mutator(mutator Mutator) *Router { - r.mutators = append(r.mutators, mutator) - return r -} diff --git a/router/inbuilt/registrar.go b/router/inbuilt/registrar.go index d483eae6..96e7eb6e 100644 --- a/router/inbuilt/registrar.go +++ b/router/inbuilt/registrar.go @@ -10,9 +10,8 @@ import ( ) type registrar struct { - endpoints map[string]map[method.Method]Handler - usedMethods [method.Count]bool - isDynamic bool + endpoints map[string]map[method.Method]Handler + isDynamic bool } func newRegistrar() *registrar { @@ -26,7 +25,6 @@ func (r *registrar) Add(path string, m method.Method, handler Handler) error { return fmt.Errorf("empty path") } - // TODO: support urlencoded characters in endpoints. path = uri.Normalize(path) methodsMap := r.endpoints[path] if methodsMap == nil { @@ -34,7 +32,7 @@ func (r *registrar) Add(path string, m method.Method, handler Handler) error { } if _, ok := methodsMap[m]; ok { - return fmt.Errorf("%s %s: already registered", m, path) + return fmt.Errorf("duplicate endpoint: %s %s", m, path) } methodsMap[m] = handler @@ -54,7 +52,7 @@ func (r *registrar) Merge(another *registrar) error { } if r.endpoints[path][method_] != nil { - return fmt.Errorf("route already registered: %s %s", method_.String(), path) + return fmt.Errorf("duplicate endpoint: %s %s", method_, path) } r.endpoints[path][method_] = handler @@ -92,10 +90,6 @@ func (r *registrar) AsRadixTree() radixTree { tree := radix.New[endpoint]() for path, e := range r.endpoints { - if len(path) == 0 { - panic("empty path") - } - var ( mlut methodLUT allow string @@ -116,3 +110,48 @@ func (r *registrar) AsRadixTree() radixTree { return tree } + +func (r *registrar) Options(includeTRACE bool) string { + var ( + totalEndpoints int + methodsStatistic [method.Count + 1]int + ) + + for _, ep := range r.endpoints { + totalEndpoints++ + + for m := range ep { + methodsStatistic[m]++ + } + } + + if 
totalEndpoints == 0 { + // a server with no endpoints at all. Must be rare enough, I guess. + return "" + } + + // As OPTIONS is supported, it must appear unconditionally + methodsStatistic[method.OPTIONS] = totalEndpoints + + if includeTRACE { + methodsStatistic[method.TRACE] = totalEndpoints + } + + if methodsStatistic[method.GET] == totalEndpoints { + // HEAD requests must also be unconditionally enabled, if GET + // are also supported, as HEAD are automatically redirected to + // GET handlers if they weren't explicitly redefined. + methodsStatistic[method.HEAD] = totalEndpoints + } + + options := make([]string, 0, method.Count) + + for m, usage := range methodsStatistic { + if usage == totalEndpoints { + // EACH endpoint supports this method + options = append(options, method.Method(m).String()) + } + } + + return strings.Join(options, ", ") +} diff --git a/router/inbuilt/resource.go b/router/inbuilt/resource.go index 412edeaa..a4aad0c0 100644 --- a/router/inbuilt/resource.go +++ b/router/inbuilt/resource.go @@ -13,13 +13,6 @@ type Resource struct { group *Router } -// Resource returns a new Resource object for a provided resource path -func (r *Router) Resource(path string) Resource { - return Resource{ - group: r.Group(path), - } -} - // Use applies middlewares to the resource, wrapping all the already registered // and registered in future handlers func (r Resource) Use(middlewares ...Middleware) Resource { diff --git a/router/inbuilt/route.go b/router/inbuilt/route.go deleted file mode 100644 index ecf3e560..00000000 --- a/router/inbuilt/route.go +++ /dev/null @@ -1,58 +0,0 @@ -package inbuilt - -import ( - "github.com/indigo-web/indigo/http/method" - "github.com/indigo-web/indigo/http/status" -) - -// AllErrors is used to be passed into Router.RouteError, indicating by that, -// that the handler must handle ALL errors (if concrete error's handler won't -// override it) -const AllErrors = status.Code(0) - -// Route is a base method for registering handlers -func 
(r *Router) Route( - method method.Method, path string, handlerFunc Handler, middlewares ...Middleware, -) *Router { - err := r.registrar.Add(r.prefix+path, method, compose(handlerFunc, middlewares)) - if err != nil { - panic(err) - } - - return r -} - -// TODO: update the error handling mechanism. It should be more modifications-prone - -// RouteError adds an error handler for a corresponding HTTP error code. -// -// The following error codes may be registered: -// - AllErrors (called only if no other error handlers found) -// - status.BadRequest -// - status.NotFound -// - status.MethodNotAllowed -// - status.RequestEntityTooLarge -// - status.CloseConnection -// - status.RequestURITooLong -// - status.HeaderFieldsTooLarge -// - status.HTTPVersionNotSupported -// - status.UnsupportedMediaType -// - status.NotImplemented -// - status.RequestTimeout -// -// Note: if handler returned one of error codes above, error handler WON'T be called. -// Also, global middlewares, applied to the root router, will also be used for error handlers. -// However, global middlewares defined on groups won't be used. -// -// WARNING: calling this method from groups will affect ALL routers, including root -func (r *Router) RouteError(handler Handler, codes ...status.Code) *Router { - if len(codes) == 0 { - codes = append(codes, AllErrors) - } - - for _, code := range codes { - r.errHandlers[code] = handler - } - - return r -} diff --git a/router/inbuilt/static.go b/router/inbuilt/static.go index 4c87f956..9c7f18a3 100644 --- a/router/inbuilt/static.go +++ b/router/inbuilt/static.go @@ -30,6 +30,7 @@ func (r *Router) Static(prefix, root string, mwares ...Middleware) *Router { return http.Error(request, status.ErrBadRequest) } + // TODO: cache descriptors file, err := fs.Open(path) if err != nil { return http. @@ -45,7 +46,7 @@ func (r *Router) Static(prefix, root string, mwares ...Middleware) *Router { } return http. - SizedStream(request, file, fstat.Size()). 
+ Stream(request, file, fstat.Size()). ContentType(mime.Guess(path)) }, mwares...) } diff --git a/router/inbuilt/trace.go b/router/inbuilt/trace.go index 2474a19a..97b0b579 100644 --- a/router/inbuilt/trace.go +++ b/router/inbuilt/trace.go @@ -1,67 +1,53 @@ package inbuilt import ( - "strings" - + "github.com/flrdv/uf" "github.com/indigo-web/indigo/http" + "github.com/indigo-web/indigo/http/mime" + "github.com/indigo-web/indigo/kv" ) -/* -This file is responsible for rendering http requests. Prime use case is rendering -http requests back as a response to a trace request -*/ - -func traceResponse(respond *http.Response, messageBody []byte) *http.Response { - return respond. - Header("Content-Type", "message/http"). - Bytes(messageBody) -} - -func renderHTTPRequest(request *http.Request, buff []byte) []byte { - buff = append(buff, request.Method.String()...) - buff = append(buff, ' ') - buff = requestURI(request, buff) - buff = append(buff, ' ') - buff = append(buff, strings.TrimSpace(request.Protocol.String())...) - buff = append(buff, "\r\n"...) - buff = requestHeaders(request.Headers, buff) - buff = append(buff, "Content-Length: 0\r\n\r\n"...) - - return buff -} - -func requestURI(request *http.Request, buff []byte) []byte { - buff = append(buff, request.Path...) - buff = requestURIParams(request.Params, buff) - - return buff -} - -func requestURIParams(params http.Params, buff []byte) []byte { - if params.Empty() { - return buff +func traceHandler(request *http.Request) *http.Response { + resp := request.Respond() + // exploit the fact that Write method never returns an error + _, _ = resp.Write(uf.S2B(request.Method.String())) + _, _ = resp.Write([]byte(" ")) + // we probably should've escaped the path back, as otherwise this might lead to + // some unwanted situations, however... Well, this doesn't seem to be effort-worthy. 
+ _, _ = resp.Write(uf.S2B(request.Path)) + if !request.Params.Empty() { + pairs := request.Params.Expose() + _, _ = resp.Write([]byte("?")) + writeParam(resp, pairs[0]) + + for _, pair := range pairs[1:] { + // can avoid if len(pair.Key) == 0 { continue } (to filter out deleted entries), + // because the TRACE handler is supposed to be executed if no other handler ran, + // thereby having a completely virgin request, which was never touched by dirty + // user's hands ever before. And won't after as well. Awesome. + _, _ = resp.Write([]byte("&")) + writeParam(resp, pair) + } } - buff = append(buff, '?') - - for key, val := range params.Pairs() { - buff = append(buff, key...) - if len(val) > 0 { - buff = append(buff, '=') - buff = append(buff, val...) - } + _, _ = resp.Write([]byte(" ")) + _, _ = resp.Write(uf.S2B(request.Protocol.String())) + _, _ = resp.Write([]byte("\r\n")) - buff = append(buff, '&') + for key, value := range request.Headers.Pairs() { + _, _ = resp.Write(uf.S2B(key)) + _, _ = resp.Write([]byte(": ")) + _, _ = resp.Write(uf.S2B(value)) + _, _ = resp.Write([]byte("\r\n")) } - return buff[:len(buff)-1] -} + _, _ = resp.Write([]byte("\r\n")) -func requestHeaders(hdrs http.Headers, buff []byte) []byte { - for _, pair := range hdrs.Expose() { - buff = append(append(buff, pair.Key...), ": "...) - buff = append(append(buff, pair.Value...), "\r\n"...) 
- } + return resp.ContentType(mime.HTTP) +} - return buff +func writeParam(resp *http.Response, param kv.Pair) { + _, _ = resp.Write(uf.S2B(param.Key)) + _, _ = resp.Write([]byte("=")) + _, _ = resp.Write(uf.S2B(param.Value)) } diff --git a/router/inbuilt/trace_test.go b/router/inbuilt/trace_test.go new file mode 100644 index 00000000..2d6bd09d --- /dev/null +++ b/router/inbuilt/trace_test.go @@ -0,0 +1,87 @@ +package inbuilt + +import ( + "io" + "testing" + + "github.com/indigo-web/indigo/config" + "github.com/indigo-web/indigo/http" + "github.com/indigo-web/indigo/http/method" + "github.com/indigo-web/indigo/http/status" + "github.com/indigo-web/indigo/internal/construct" + "github.com/indigo-web/indigo/kv" + "github.com/indigo-web/indigo/transport/dummy" + "github.com/stretchr/testify/require" +) + +func TestTrace(t *testing.T) { + newRequest := func(path string, params http.Params) *http.Request { + req := construct.Request(config.Default(), dummy.NewNopClient()) + req.Method = method.TRACE + req.Path = path + req.Params = params + req.Headers = kv.New(). + Add("Accept", "*/*"). + Add("Content-Length", "13") + + return req + } + + wantMirroring := func(path string) string { + return "TRACE " + path + " HTTP/1.1\r\nAccept: */*\r\nContent-Length: 13\r\n\r\n" + } + + r := New(). + EnableTRACE(true). + Get("/", http.Respond). 
+ Build() + + t.Run("TRACE on registered endpoint", func(t *testing.T) { + req := newRequest("/", kv.New()) + resp := r.OnRequest(req) + require.Equal(t, 200, int(resp.Expose().Code)) + b, err := io.ReadAll(resp.Expose().Stream) + require.NoError(t, err) + require.Equal(t, wantMirroring("/"), string(b)) + }) + + t.Run("TRACE on non-existing endpoint", func(t *testing.T) { + req := newRequest("/hello", kv.New()) + resp := r.OnRequest(req) + require.Equal(t, 200, int(resp.Expose().Code)) + b, err := io.ReadAll(resp.Expose().Stream) + require.NoError(t, err) + require.Equal(t, wantMirroring("/hello"), string(b)) + }) + + t.Run("params", func(t *testing.T) { + t.Run("single", func(t *testing.T) { + req := newRequest("/", kv.New().Add("hello", "world")) + resp := r.OnRequest(req) + require.Equal(t, 200, int(resp.Expose().Code)) + b, err := io.ReadAll(resp.Expose().Stream) + require.NoError(t, err) + require.Equal(t, wantMirroring("/?hello=world"), string(b)) + }) + + t.Run("multiple", func(t *testing.T) { + params := kv.New(). + Add("hello", "world"). + Add("hi", "hello") + req := newRequest("/", params) + resp := r.OnRequest(req) + require.Equal(t, 200, int(resp.Expose().Code)) + b, err := io.ReadAll(resp.Expose().Stream) + require.NoError(t, err) + require.Equal(t, wantMirroring("/?hello=world&hi=hello"), string(b)) + }) + }) + + t.Run("disabled", func(t *testing.T) { + r := New(). + Get("/", http.Respond). 
+ Build() + resp := r.OnRequest(newRequest("/", kv.New())) + require.Equal(t, int(status.MethodNotAllowed), int(resp.Expose().Code)) + }) +} diff --git a/router/inbuilt/types.go b/router/inbuilt/types.go index 607a7cf9..57e40e65 100644 --- a/router/inbuilt/types.go +++ b/router/inbuilt/types.go @@ -1,11 +1,14 @@ package inbuilt import ( + "fmt" + "strconv" "strings" "github.com/indigo-web/indigo/http" "github.com/indigo-web/indigo/http/method" "github.com/indigo-web/indigo/http/status" + "github.com/indigo-web/indigo/internal/strutil" "github.com/indigo-web/indigo/router/inbuilt/internal/radix" ) @@ -26,20 +29,34 @@ type ( ) func (r routesMap) Add(path string, m method.Method, handler Handler) { - entry := r[path] + p, ok := strutil.URLDecode(path) + if !ok { + panic(fmt.Errorf("poorly encoded path: %s", strconv.Quote(path))) + } + + entry := r[p] entry.methods[m] = handler entry.allow = getAllowString(entry.methods) - r[path] = entry + r[p] = entry } func getAllowString(methods methodLUT) (allowed string) { + definedMethods := make([]string, 0, method.Count) + for i, handler := range methods { if handler == nil { continue } - allowed += method.Method(i).String() + "," + definedMethods = append(definedMethods, method.Method(i).String()) + if method.Method(i) == method.GET && methods[method.HEAD] == nil { + // append HEAD automatically even if they aren't explicitly supported. + // Actually, we could do this after the loop, but then we'd have the HEAD + // in the end, probably away from the GET, which looks suboptimal. + // Aesthetics is always important. 
+ definedMethods = append(definedMethods, method.HEAD.String()) + } } - return strings.TrimRight(allowed, ",") + return strings.Join(definedMethods, ", ") } diff --git a/router/inbuilt/uri/normalize.go b/router/inbuilt/uri/normalize.go index 0aaf09e8..171d13a0 100644 --- a/router/inbuilt/uri/normalize.go +++ b/router/inbuilt/uri/normalize.go @@ -1,12 +1,9 @@ package uri -// Normalize removes trailing slashes, as all request paths are also trimmed, resulting -// in consensus between these two. +// Normalize eliminates a trailing slash if presented. func Normalize(path string) string { - for i := len(path) - 1; i > 1; i-- { - if path[i] != '/' { - return path[:i+1] - } + if len(path) > 1 && path[len(path)-1] == '/' { + return path[:len(path)-1] } return path diff --git a/router/inbuilt/uri/normalize_test.go b/router/inbuilt/uri/normalize_test.go index add476d5..501c094a 100644 --- a/router/inbuilt/uri/normalize_test.go +++ b/router/inbuilt/uri/normalize_test.go @@ -21,9 +21,4 @@ func TestNormalize(t *testing.T) { norm := Normalize("/api/") require.Equal(t, "/api", norm) }) - - t.Run("multiple trailing", func(t *testing.T) { - norm := Normalize("/api/////") - require.Equal(t, "/api", norm) - }) } diff --git a/transport.go b/transport.go index 47ed20aa..6fc6a247 100644 --- a/transport.go +++ b/transport.go @@ -33,8 +33,10 @@ func TCP() Transport { return Transport{ inner: transport.NewTCP(), spawnCallback: func(cfg *config.Config, r router.Router, c []codec.Codec) func(net.Conn) { + acceptString := codecutil.AcceptEncoding(c) + return func(conn net.Conn) { - serve.HTTP1(cfg, conn, 0, r, codecutil.NewCache(c)) + serve.HTTP1(cfg, conn, 0, r, codecutil.NewCache(c, acceptString)) } }, } @@ -105,9 +107,11 @@ func newTLSTransport(cfg *tls.Config) Transport { return Transport{ inner: transport.NewTLS(cfg), spawnCallback: func(cfg *config.Config, r router.Router, c []codec.Codec) func(net.Conn) { + acceptString := codecutil.AcceptEncoding(c) + return func(conn net.Conn) { ver 
:= conn.(*tls.Conn).ConnectionState().Version - serve.HTTP1(cfg, conn, ver, r, codecutil.NewCache(c)) + serve.HTTP1(cfg, conn, ver, r, codecutil.NewCache(c, acceptString)) } }, }