From 280d5399dc1b22938557bf7d27b73df13630fb7c Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 13 Sep 2024 21:00:23 -0300 Subject: [PATCH] refactor(middleware/session): Improve data pool handling and locking --- middleware/session/data.go | 20 ++++--- middleware/session/session_test.go | 85 +++++++++++++++++++----------- 2 files changed, 65 insertions(+), 40 deletions(-) diff --git a/middleware/session/data.go b/middleware/session/data.go index 7d278d8e21..93f7c06f57 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -29,7 +29,12 @@ var dataPool = sync.Pool{ // // d := acquireData() func acquireData() *data { - return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool + obj := dataPool.Get() + if d, ok := obj.(*data); ok { + return d + } + // Handle unexpected type in the pool + panic("unexpected type in data pool") } // Reset clears the data map and resets the data object. @@ -39,8 +44,8 @@ func acquireData() *data { // d.Reset() func (d *data) Reset() { d.Lock() + defer d.Unlock() d.Data = make(map[string]any) - d.Unlock() } // Get retrieves a value from the data map by key. @@ -56,9 +61,8 @@ func (d *data) Reset() { // value := d.Get("key") func (d *data) Get(key string) any { d.RLock() - v := d.Data[key] - d.RUnlock() - return v + defer d.RUnlock() + return d.Data[key] } // Set updates or creates a new key-value pair in the data map. @@ -72,8 +76,8 @@ func (d *data) Get(key string) any { // d.Set("key", "value") func (d *data) Set(key string, value any) { d.Lock() + defer d.Unlock() d.Data[key] = value - d.Unlock() } // Delete removes a key-value pair from the data map. @@ -86,8 +90,8 @@ func (d *data) Set(key string, value any) { // d.Delete("key") func (d *data) Delete(key string) { d.Lock() + defer d.Unlock() delete(d.Data, key) - d.Unlock() } // Keys retrieves all keys in the data map. @@ -100,11 +104,11 @@ func (d *data) Delete(key string) { // keys := d.Keys() func (d *data) Keys() []string { d.RLock() + defer d.RUnlock() keys := make([]string, 0, len(d.Data)) for k := range d.Data { keys = append(keys, k) } - d.RUnlock() return keys } diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index c27a535fb7..ef8d469ade 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -8,7 +8,6 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) @@ -25,7 +24,6 @@ func Test_Session(t *testing.T) { // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) // Get a new session sess, err := store.Get(ctx) @@ -34,6 +32,7 @@ func Test_Session(t *testing.T) { token := sess.ID() require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -82,6 +81,9 @@ func Test_Session(t *testing.T) { err = sess.Save() require.NoError(t, err) + // release the session + sess.Release() + // release the context app.ReleaseCtx(ctx) // requesting entirely new context to prevent falsy tests @@ -94,6 +96,8 @@ func Test_Session(t *testing.T) { // this id should be randomly generated as session key was deleted require.Len(t, sess.ID(), 36) + sess.Release() + // when we use the original session for the second time // the session be should be same if the session is not expired app.ReleaseCtx(ctx) @@ -103,6 +107,7 @@ func Test_Session(t *testing.T) { // request the server with the old session ctx.Request().Header.SetCookie(store.sessionName, id) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Equal(t, sess.id, id) @@ -187,6 +192,7 @@ func Test_Session_Types(t *testing.T) { err = sess.Save() require.NoError(t, err) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -278,6 +284,8 @@ func Test_Session_Types(t *testing.T) { require.True(t, ok) require.Equal(t, vcomplex128, vcomplex128Result) + sess.Release() + app.ReleaseCtx(ctx) } @@ -305,6 +313,7 @@ func Test_Session_Store_Reset(t *testing.T) { require.NoError(t, store.Reset()) id := sess.ID() + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -312,6 +321,7 @@ func Test_Session_Store_Reset(t *testing.T) { // make sure the session is recreated sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.True(t, sess.Fresh()) require.Nil(t, sess.Get("hello")) @@ -339,6 +349,7 @@ func Test_Session_Save(t *testing.T) { // save session err = sess.Save() require.NoError(t, err) + sess.Release() }) t.Run("save to header", func(t *testing.T) { @@ -364,6 +375,7 @@ func Test_Session_Save(t *testing.T) { require.NoError(t, err) require.Equal(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName))) require.Equal(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName))) + sess.Release() }) } @@ -398,6 +410,7 @@ func Test_Session_Save_Expiration(t *testing.T) { err = sess.Save() require.NoError(t, err) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -410,6 +423,8 @@ func Test_Session_Save_Expiration(t *testing.T) { // just to make sure the session has been expired time.Sleep(sessionDuration + (10 * time.Millisecond)) + sess.Release() + app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -417,6 +432,7 @@ func Test_Session_Save_Expiration(t *testing.T) { // here you should get a new session ctx.Request().Header.SetCookie(store.sessionName, token) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.Nil(t, sess.Get("name")) require.NotEqual(t, sess.ID(), token) @@ -439,6 +455,7 @@ func Test_Session_Destroy(t *testing.T) { // get session sess, err := store.Get(ctx) + defer sess.Release() require.NoError(t, err) sess.Set("name", "fenny") @@ -468,6 +485,7 @@ func Test_Session_Destroy(t *testing.T) { id := sess.ID() require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) @@ -476,6 +494,7 @@ func Test_Session_Destroy(t *testing.T) { ctx.Request().Header.Set(store.sessionName, id) sess, err = store.Get(ctx) require.NoError(t, err) + defer sess.Release() err = sess.Destroy() require.NoError(t, err) @@ -512,6 +531,8 @@ func Test_Session_Cookie(t *testing.T) { require.NoError(t, err) require.NoError(t, sess.Save()) + sess.Release() + // cookie should be set on Save ( even if empty data ) require.Len(t, ctx.Response().Header.PeekCookie(store.sessionName), 84) } @@ -535,8 +556,11 @@ func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { id := sess.ID() require.NoError(t, sess.Save()) + sess.Release() + sess, err = store.Get(ctx) require.NoError(t, err) + defer sess.Release() sess.Set("name", "john") require.True(t, sess.Fresh()) require.Equal(t, id, sess.ID()) // session id should be the same @@ -560,6 +584,7 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { sess.Set("id", "1") require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie(store.sessionName, id) @@ -569,11 +594,13 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { sess.Delete("id") require.NoError(t, sess.Save()) + sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie(store.sessionName, id) sess, err = store.Get(ctx) + defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Nil(t, sess.Get("id")) @@ -610,6 +637,7 @@ func Test_Session_Reset(t *testing.T) { err = freshSession.Save() require.NoError(t, err) + freshSession.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -642,6 +670,8 @@ func Test_Session_Reset(t *testing.T) { err = acquiredSession.Save() require.NoError(t, err) + acquiredSession.Release() + // Check that the session id is not in the header or cookie anymore require.Equal(t, "", string(ctx.Response().Header.Peek(store.sessionName))) require.Equal(t, "", string(ctx.Request().Header.Peek(store.sessionName))) @@ -675,6 +705,8 @@ func Test_Session_Regenerate(t *testing.T) { err = freshSession.Save() require.NoError(t, err) + freshSession.Release() + // release the context app.ReleaseCtx(ctx) @@ -687,6 +719,7 @@ func Test_Session_Regenerate(t *testing.T) { // as the session is in the storage, session.fresh should be false acquiredSession, err := store.Get(ctx) require.NoError(t, err) + defer acquiredSession.Release() require.False(t, acquiredSession.Fresh()) err = acquiredSession.Regenerate() @@ -716,6 +749,8 @@ func Benchmark_Session(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() } }) @@ -734,6 +769,8 @@ func Benchmark_Session(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() } }) } @@ -752,6 +789,9 @@ func Benchmark_Session_Parallel(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() + app.ReleaseCtx(c) } }) @@ -772,6 +812,9 @@ func Benchmark_Session_Parallel(b *testing.B) { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark + + sess.Release() + app.ReleaseCtx(c) } }) @@ -794,6 +837,7 @@ func Benchmark_Session_Asserted(b *testing.B) { sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) + sess.Release() } }) @@ -814,6 +858,7 @@ func Benchmark_Session_Asserted(b *testing.B) { sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) + sess.Release() } }) } @@ -833,6 +878,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) + sess.Release() app.ReleaseCtx(c) } }) @@ -854,6 +900,7 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) { require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) + sess.Release() app.ReleaseCtx(c) } }) @@ -902,6 +949,9 @@ func Test_Session_Concurrency(t *testing.T) { return } + // release the session + sess.Release() + // Release the context app.ReleaseCtx(localCtx) @@ -918,6 +968,7 @@ func Test_Session_Concurrency(t *testing.T) { errChan <- err return } + defer sess.Release() // Get the value name := sess.Get("name") @@ -964,33 +1015,3 @@ func Test_Session_Concurrency(t *testing.T) { require.NoError(t, err) } } - -// go test -v race -run Test_Session_Release -count 4 -func Test_Session_Release(t *testing.T) { - t.Parallel() - - // session store - store := newStore() - // fiber instance - app := fiber.New() - // fiber context - ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) - - // acquire a new session - sess := acquireSession() - sess.ctx = ctx - sess.config = store - rid, _ := uuid.NewRandom() - sess.id = rid.String() - - // release the session - sess.Release() - - // assertions - require.Empty(t, sess.id) - require.Nil(t, sess.ctx) - require.Nil(t, sess.config) - require.Empty(t, sess.Keys()) - require.Zero(t, sess.byteBuffer.Len()) -}