diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index c107b212ed..69c3fd5c1d 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -117,46 +117,49 @@ func New(config ...Config) fiber.Handler { // Get timestamp ts := atomic.LoadUint64(×tamp) - // Invalidate cache if requested - if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) && e != nil { - e.exp = ts - 1 - } - - // Check if entry is expired - if e.exp != 0 && ts >= e.exp { - deleteKey(key) - if cfg.MaxBytes > 0 { - _, size := heap.remove(e.heapidx) - storedBytes -= size - } - } else if e.exp != 0 && !hasRequestDirective(c, noCache) { - // Separate body value to avoid msgp serialization - // We can store raw bytes with Storage 👍 - if cfg.Storage != nil { - e.body = manager.getRaw(key + "_body") - } - // Set response headers from cache - c.Response().SetBodyRaw(e.body) - c.Response().SetStatusCode(e.status) - c.Response().Header.SetContentTypeBytes(e.ctype) - if len(e.cencoding) > 0 { - c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding) - } - for k, v := range e.headers { - c.Response().Header.SetBytesV(k, v) + // Cache Entry not found + if e != nil { + // Invalidate cache if requested + if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) { + e.exp = ts - 1 } - // Set Cache-Control header if enabled - if cfg.CacheControl { - maxAge := strconv.FormatUint(e.exp-ts, 10) - c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge) - } - - c.Set(cfg.CacheHeader, cacheHit) - - mux.Unlock() - // Return response - return nil + // Check if entry is expired + if e.exp != 0 && ts >= e.exp { + deleteKey(key) + if cfg.MaxBytes > 0 { + _, size := heap.remove(e.heapidx) + storedBytes -= size + } + } else if e.exp != 0 && !hasRequestDirective(c, noCache) { + // Separate body value to avoid msgp serialization + // We can store raw bytes with Storage 👍 + if cfg.Storage != nil { + e.body = manager.getRaw(key + "_body") + } + // Set response headers from cache + c.Response().SetBodyRaw(e.body) + c.Response().SetStatusCode(e.status) + c.Response().Header.SetContentTypeBytes(e.ctype) + if len(e.cencoding) > 0 { + c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding) + } + for k, v := range e.headers { + c.Response().Header.SetBytesV(k, v) + } + // Set Cache-Control header if enabled + if cfg.CacheControl { + maxAge := strconv.FormatUint(e.exp-ts, 10) + c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge) + } + + c.Set(cfg.CacheHeader, cacheHit) + + mux.Unlock() + + // Return response + return nil + } } // make sure we're not blocking concurrent requests - do unlock @@ -193,6 +196,7 @@ func New(config ...Config) fiber.Handler { } } + e = manager.acquire() // Cache response e.body = utils.CopyBytes(c.Response().Body()) e.status = c.Response().StatusCode() diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index d529ccd9f5..1cc3b7375b 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -47,9 +47,10 @@ func Test_Cache_Expired(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 2 * time.Second})) - + count := 0 app.Get("/", func(c fiber.Ctx) error { - return c.SendString(strconv.FormatInt(time.Now().UnixNano(), 10)) + count++ + return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) @@ -86,9 +87,10 @@ func Test_Cache(t *testing.T) { app := fiber.New() app.Use(New()) + count := 0 app.Get("/", func(c fiber.Ctx) error { - now := strconv.FormatInt(time.Now().UnixNano(), 10) - return c.SendString(now) + count++ + return c.SendString(strconv.Itoa(count)) }) req := httptest.NewRequest(fiber.MethodGet, "/", nil) @@ -305,9 +307,10 @@ func Test_Cache_Invalid_Expiration(t *testing.T) { cache := New(Config{Expiration: 0 * time.Second}) app.Use(cache) + count := 0 app.Get("/", func(c fiber.Ctx) error { - now := strconv.FormatInt(time.Now().UnixNano(), 10) - return c.SendString(now) + count++ + return c.SendString(strconv.Itoa(count)) }) req := httptest.NewRequest(fiber.MethodGet, "/", nil) @@ -414,8 +417,10 @@ func Test_Cache_NothingToCache(t *testing.T) { app.Use(New(Config{Expiration: -(time.Second * 1)})) + count := 0 app.Get("/", func(c fiber.Ctx) error { - return c.SendString(time.Now().String()) + count++ + return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) @@ -447,12 +452,16 @@ func Test_Cache_CustomNext(t *testing.T) { CacheControl: true, })) + count := 0 app.Get("/", func(c fiber.Ctx) error { - return c.SendString(time.Now().String()) + count++ + return c.SendString(strconv.Itoa(count)) }) + errorCount := 0 app.Get("/error", func(c fiber.Ctx) error { - return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String()) + errorCount++ + return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(errorCount)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) @@ -508,9 +517,11 @@ func Test_CustomExpiration(t *testing.T) { return time.Second * time.Duration(newCacheTime) }})) + count := 0 app.Get("/", func(c fiber.Ctx) error { + count++ c.Response().Header.Add("Cache-Time", "1") - return c.SendString(strconv.FormatInt(time.Now().UnixNano(), 10)) + return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) @@ -588,8 +599,11 @@ func Test_CacheHeader(t *testing.T) { return c.SendString(fiber.Query[string](c, "cache")) }) + count := 0 app.Get("/error", func(c fiber.Ctx) error { - return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String()) + count++ + c.Response().Header.Add("Cache-Time", "1") + return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) @@ -615,10 +629,13 @@ func Test_Cache_WithHead(t *testing.T) { app := fiber.New() app.Use(New()) + count := 0 handler := func(c fiber.Ctx) error { - now := strconv.FormatInt(time.Now().UnixNano(), 10) - return c.SendString(now) + count++ + c.Response().Header.Add("Cache-Time", "1") + return c.SendString(strconv.Itoa(count)) } + app.Route("/").Get(handler).Head(handler) req := httptest.NewRequest(fiber.MethodHead, "/", nil) @@ -708,8 +725,10 @@ func Test_CacheInvalidation(t *testing.T) { }, })) + count := 0 app.Get("/", func(c fiber.Ctx) error { - return c.SendString(time.Now().String()) + count++ + return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) @@ -731,6 +750,93 @@ func Test_CacheInvalidation(t *testing.T) { require.NotEqual(t, body, bodyInvalidate) } +func Test_CacheInvalidation_noCacheEntry(t *testing.T) { + t.Parallel() + t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) { + t.Parallel() + app := fiber.New() + cacheInvalidatorExecuted := false + app.Use(New(Config{ + CacheControl: true, + CacheInvalidator: func(c fiber.Ctx) bool { + cacheInvalidatorExecuted = true + return fiber.Query[bool](c, "invalidate") + }, + MaxBytes: 10 * 1024 * 1024, + })) + _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil)) + require.NoError(t, err) + require.False(t, cacheInvalidatorExecuted) + }) +} + +func Test_CacheInvalidation_removeFromHeap(t *testing.T) { + t.Parallel() + t.Run("Invalidate and remove from the heap", func(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New(Config{ + CacheControl: true, + CacheInvalidator: func(c fiber.Ctx) bool { + return fiber.Query[bool](c, "invalidate") + }, + MaxBytes: 10 * 1024 * 1024, + })) + + count := 0 + app.Get("/", func(c fiber.Ctx) error { + count++ + return c.SendString(strconv.Itoa(count)) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + bodyCached, err := io.ReadAll(respCached.Body) + require.NoError(t, err) + require.True(t, bytes.Equal(body, bodyCached)) + require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl)) + + respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil)) + require.NoError(t, err) + bodyInvalidate, err := io.ReadAll(respInvalidate.Body) + require.NoError(t, err) + require.NotEqual(t, body, bodyInvalidate) + }) +} + +func Test_CacheStorage_CustomHeaders(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New(Config{ + CacheControl: true, + Storage: memory.New(), + MaxBytes: 10 * 1024 * 1024, + })) + + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Set("Content-Type", "text/xml") + c.Response().Header.Set("Content-Encoding", "utf8") + return c.Send([]byte("Test")) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + bodyCached, err := io.ReadAll(respCached.Body) + require.NoError(t, err) + require.True(t, bytes.Equal(body, bodyCached)) + require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl)) +} + // Because time points are updated once every X milliseconds, entries in tests can often have // equal expiration times and thus be in an random order. This closure hands out increasing // time intervals to maintain strong ascending order of expiration diff --git a/middleware/cache/heap.go b/middleware/cache/heap.go index fa97871595..c5715392ef 100644 --- a/middleware/cache/heap.go +++ b/middleware/cache/heap.go @@ -15,7 +15,7 @@ type heapEntry struct { // elements in constant time. It does so by handing out special indices // and tracking entry movement. // -// indexdedHeap is used for quickly finding entries with the lowest +// indexedHeap is used for quickly finding entries with the lowest // expiration timestamp and deleting arbitrary entries. type indexedHeap struct { // Slice the heap is built on diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index c6ae542805..0c99d014eb 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -83,8 +83,7 @@ func (m *manager) get(key string) *item { return it } if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool - it = m.acquire() - return it + return nil } return it } diff --git a/middleware/cache/manager_test.go b/middleware/cache/manager_test.go new file mode 100644 index 0000000000..d792cc3444 --- /dev/null +++ b/middleware/cache/manager_test.go @@ -0,0 +1,23 @@ +package cache + +import ( + "testing" + "time" + + "github.com/gofiber/utils/v2" + "github.com/stretchr/testify/assert" +) + +func Test_manager_get(t *testing.T) { + cacheManager := newManager(nil) + t.Run("Item not found in cache", func(t *testing.T) { + assert.Nil(t, cacheManager.get(utils.UUID())) + }) + t.Run("Item found in cache", func(t *testing.T) { + id := utils.UUID() + cacheItem := cacheManager.acquire() + cacheItem.body = []byte("test-body") + cacheManager.set(id, cacheItem, 10*time.Second) + assert.NotNil(t, cacheManager.get(id)) + }) +}