Skip to content

Commit

Permalink
refactor(middleware/session): Update session middleware initializatio…
Browse files Browse the repository at this point in the history
…n and saving
  • Loading branch information
sixcolors committed Sep 20, 2024
1 parent 13a1eb4 commit 9d3b032
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 75 deletions.
111 changes: 52 additions & 59 deletions middleware/session/middleware.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Package session provides session management middleware for Fiber.
// This middleware allows you to manage user sessions, including storing session data in the store.
// This middleware handles user sessions, including storing session data in the store.
package session

import (
Expand All @@ -10,7 +10,7 @@ import (
"github.com/gofiber/fiber/v3/log"
)

// Middleware defines the session middleware configuration
// Middleware holds session data and configuration.
type Middleware struct {
Session *Session
ctx *fiber.Ctx
Expand All @@ -19,21 +19,22 @@ type Middleware struct {
destroyed bool
}

// key for looking up session middleware in request context
// Context key for session middleware lookup.
const key = 0

var (
// ErrTypeAssertionFailed is returned when the type assertion failed
// ErrTypeAssertionFailed occurs when a type assertion fails.
ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware")

// Pool for reusing middleware instances.
middlewarePool = &sync.Pool{
New: func() any {
return &Middleware{}
},
}
)

// New creates a new session middleware with the given configuration.
// New initializes session middleware with optional configuration.
//
// Parameters:
// - config: Variadic parameter to override default config.
Expand All @@ -44,18 +45,20 @@ var (
// Usage:
//
// app.Use(session.New())
//
// Usage:
//
// app.Use(session.New())
func New(config ...Config) fiber.Handler {
var handler fiber.Handler
if len(config) > 0 {
handler, _ = NewWithStore(config[0])
} else {
handler, _ = NewWithStore()
handler, _ := NewWithStore(config[0])
return handler
}

handler, _ := NewWithStore()
return handler
}

// NewWithStore returns a new session middleware with the given store.
// NewWithStore creates session middleware with an optional custom store.
//
// Parameters:
// - config: Variadic parameter to override default config.
Expand All @@ -75,80 +78,71 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) {
}

handler := func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}

// Get the session
session, err := cfg.Store.getSession(c)
if err != nil {
return err
}

// get a middleware from the pool
// Acquire session middleware
m := acquireMiddleware()
m.mu.Lock()
m.config = cfg
m.Session = session
m.ctx = &c

// Store the middleware in the context
c.Locals(key, m)
m.mu.Unlock()
m.initialize(c, cfg)

// Continue stack
stackErr := c.Next()

m.mu.RLock()
destroyed := m.destroyed
m.mu.RUnlock()

if !destroyed {
// Save the session
// This is done after the response is sent to the client
// It allows us to modify the session data during the request
// without having to worry about calling Save() on the session.
//
// It will also extend the session idle timeout automatically.
if err := session.saveSession(); err != nil {
if cfg.ErrorHandler != nil {
cfg.ErrorHandler(&c, err)
} else {
DefaultErrorHandler(&c, err)
}
}

// Release the session back to the pool
releaseSession(session)
m.saveSession()
}

// release the middleware back to the pool
releaseMiddleware(m)

return stackErr
}

return handler, cfg.Store
}

// acquireMiddleware returns a new Middleware from the pool.
//
// Returns:
// - *Middleware: The middleware object.
//
// Usage:
//
// m := acquireMiddleware()
// initialize sets up middleware for the request.
func (m *Middleware) initialize(c fiber.Ctx, cfg Config) {
m.mu.Lock()
defer m.mu.Unlock()

session, err := cfg.Store.getSession(c)
if err != nil {
panic(err) // handle or log this error appropriately in production
}

m.config = cfg
m.Session = session
m.ctx = &c

c.Locals(key, m)
}

// saveSession handles session saving and error management after the response.
func (m *Middleware) saveSession() {
if err := m.Session.saveSession(); err != nil {
if m.config.ErrorHandler != nil {
m.config.ErrorHandler(m.ctx, err)
} else {
DefaultErrorHandler(m.ctx, err)
}
}

releaseSession(m.Session)
}

// acquireMiddleware retrieves a middleware instance from the pool.
func acquireMiddleware() *Middleware {
middleware, ok := middlewarePool.Get().(*Middleware)
m, ok := middlewarePool.Get().(*Middleware)
if !ok {
panic(ErrTypeAssertionFailed.Error())
}
return middleware
return m
}

// releaseMiddleware returns a Middleware to the pool.
// releaseMiddleware resets and returns middleware to the pool.
//
// Parameters:
// - m: The middleware object to release.
Expand Down Expand Up @@ -289,8 +283,7 @@ func (m *Middleware) Reset() error {
m.mu.Lock()
defer m.mu.Unlock()

err := m.Session.Reset()
return err
return m.Session.Reset()
}

// Store returns the session store.
Expand Down
32 changes: 16 additions & 16 deletions middleware/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Session struct {
ctx fiber.Ctx // fiber context
config *Store // store configuration
data *data // key value data
byteBuffer *bytes.Buffer // byte buffer for the en- and decode
byteBuffer *bytes.Buffer // byte buffer for encoding/decoding
id string // session id
idleTimeout time.Duration // idleTimeout of this session
mu sync.RWMutex // Mutex to protect non-data fields
Expand All @@ -26,7 +26,9 @@ type Session struct {

var sessionPool = sync.Pool{
New: func() any {
return new(Session)
return &Session{
byteBuffer: new(bytes.Buffer),
}
},
}

Expand All @@ -43,9 +45,6 @@ func acquireSession() *Session {
if s.data == nil {
s.data = acquireData()
}
if s.byteBuffer == nil {
s.byteBuffer = new(bytes.Buffer)
}
s.fresh = true
return s
}
Expand Down Expand Up @@ -90,7 +89,7 @@ func releaseSession(s *Session) {
sessionPool.Put(s)
}

// Fresh returns true if the current session is new.
// Fresh returns whether the session is new
//
// Returns:
// - bool: True if the session is fresh, otherwise false.
Expand All @@ -104,7 +103,7 @@ func (s *Session) Fresh() bool {
return s.fresh
}

// ID returns the session id.
// ID returns the session ID
//
// Returns:
// - string: The session ID.
Expand Down Expand Up @@ -263,12 +262,10 @@ func (s *Session) refresh() {
s.fresh = true
}

// Save updates the storage and client cookie.
//
// sess.Save() will save the session data to the storage and update the
// client cookie.
// Save saves the session data and updates the cookie
//
// Checks if the session is being used in the handler, if so, it will not save the session.
// Note: If the session is being used in the handler, calling Save will have
// no effect and the session will automatically be saved when the handler returns.
//
// Returns:
// - error: An error if the save operation fails.
Expand All @@ -288,6 +285,7 @@ func (s *Session) Save() error {
return s.saveSession()
}

// saveSession encodes session data to saves it to storage.
func (s *Session) saveSession() error {
if s.data == nil {
return nil
Expand All @@ -296,17 +294,19 @@ func (s *Session) saveSession() error {
s.mu.Lock()
defer s.mu.Unlock()

// Check is the session has an idle timeout
// Set idleTimeout if not already set
if s.idleTimeout <= 0 {
s.idleTimeout = s.config.IdleTimeout
}

// Update client cookie
s.setSession()

// Convert data to bytes
// Encode session data
encCache := gob.NewEncoder(s.byteBuffer)
s.data.RLock()
err := encCache.Encode(&s.data.Data)
s.data.RUnlock()
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}
Expand Down Expand Up @@ -334,7 +334,7 @@ func (s *Session) Keys() []string {
return s.data.Keys()
}

// SetIdleTimeout sets a specific idle timeout for the session.
// SetIdleTimeout used when saving the session on the next call to `Save()`.
//
// Parameters:
// - idleTimeout: The duration for the idle timeout.
Expand Down Expand Up @@ -411,7 +411,7 @@ func (s *Session) delSession() {
}
}

// decodeSessionData decodes the session data from raw bytes.
// decodeSessionData decodes session data from raw bytes
//
// Parameters:
// - rawData: The raw byte data to decode.
Expand Down

0 comments on commit 9d3b032

Please sign in to comment.