diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index e998bbbbe3..e198d7a121 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -700,7 +700,60 @@ func Test_CustomCacheHeader(t *testing.T) { func Test_CacheInvalidation(t *testing.T) { t.Parallel() - t.Run("Invalidation by requests", func(t *testing.T) { + app := fiber.New() + app.Use(New(Config{ + CacheControl: true, + CacheInvalidator: func(c fiber.Ctx) bool { + return fiber.Query[bool](c, "invalidate") + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(time.Now().String()) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + bodyCached, err := io.ReadAll(respCached.Body) + require.NoError(t, err) + require.True(t, bytes.Equal(body, bodyCached)) + require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl)) + + respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil)) + require.NoError(t, err) + bodyInvalidate, err := io.ReadAll(respInvalidate.Body) + require.NoError(t, err) + require.NotEqual(t, body, bodyInvalidate) +} + +func Test_CacheInvalidation_noCacheEntry(t *testing.T) { + t.Parallel() + t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) { + t.Parallel() + app := fiber.New() + cacheInvalidatorExecuted := false + app.Use(New(Config{ + CacheControl: true, + CacheInvalidator: func(c fiber.Ctx) bool { + cacheInvalidatorExecuted = true + return fiber.Query[bool](c, "invalidate") + }, + MaxBytes: 10 * 1024 * 1024, + })) + _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil)) + require.NoError(t, err) + require.False(t, cacheInvalidatorExecuted) + }) +} + +func Test_CacheInvalidation_removeFromHeap(t *testing.T) { + t.Parallel() + t.Run("Invalidate and remove from the heap", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ @@ -708,6 +761,7 @@ func Test_CacheInvalidation(t *testing.T) { CacheInvalidator: func(c fiber.Ctx) bool { return fiber.Query[bool](c, "invalidate") }, + MaxBytes: 10 * 1024 * 1024, })) app.Get("/", func(c fiber.Ctx) error { @@ -732,23 +786,34 @@ func Test_CacheInvalidation(t *testing.T) { require.NoError(t, err) require.NotEqual(t, body, bodyInvalidate) }) +} - t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) { - t.Parallel() - app := fiber.New() - cacheInvalidatorExecuted := false - app.Use(New(Config{ - CacheControl: true, - CacheInvalidator: func(c fiber.Ctx) bool { - cacheInvalidatorExecuted = true - return fiber.Query[bool](c, "invalidate") - }, - MaxBytes: 10 * 1024 * 1024, - })) - _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil)) - require.NoError(t, err) - require.False(t, cacheInvalidatorExecuted) +func Test_CacheStorage_CustomHeaders(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New(Config{ + CacheControl: true, + Storage: memory.New(), + MaxBytes: 10 * 1024 * 1024, + })) + + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Set("Content-Type", "text/xml") + c.Response().Header.Set("Content-Encoding", "utf8") + return c.Send([]byte("Test")) }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + bodyCached, err := io.ReadAll(respCached.Body) + require.NoError(t, err) + require.True(t, bytes.Equal(body, bodyCached)) + require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl)) } // Because time points are updated once every X milliseconds, entries in tests can often have