diff --git a/middles/oauth/cookies.go b/middles/oauth/cookies.go index e8b77b0..f1afe21 100644 --- a/middles/oauth/cookies.go +++ b/middles/oauth/cookies.go @@ -14,31 +14,31 @@ import ( // // Each cookie minted is of the same name; i.e. the name associated with the // cookie in the requester's cookie jar (web browser / http client). -type CookieFactory struct { +type CookieFactory[U Unique] struct { Name string Secure bool Clock func() time.Time } // CookieContent is the data stored per session. -type CookieContent struct { - UserToken string `json:"token"` - UserID Identity `json:"user_id"` +type CookieContent[U Unique] struct { + UserToken string `json:"token"` + UserID U `json:"user_id"` } // Token returns the secret token associated with the cookie. -func (cc *CookieContent) Token() *conceal.Text { +func (cc *CookieContent[U]) Token() *conceal.Text { return conceal.New(cc.UserToken) } // Identity returns the identity associated with the cookie. -func (cc *CookieContent) Identity() Identity { +func (cc *CookieContent[U]) Identity() U { return cc.UserID } // Create the cookie. -func (cf *CookieFactory) Create( - id Identity, +func (cf *CookieFactory[U]) Create( + id U, token *conceal.Text, ttl time.Duration, ) *http.Cookie { @@ -46,7 +46,7 @@ func (cf *CookieFactory) Create( expiration := cf.Clock().Add(ttl) // encode the cookie payload as base64 json - b, _ := json.Marshal(&CookieContent{ + b, _ := json.Marshal(&CookieContent[U]{ UserToken: token.Unveil(), UserID: id, }) diff --git a/middles/oauth/cookies_test.go b/middles/oauth/cookies_test.go index 98fdfc4..f552508 100644 --- a/middles/oauth/cookies_test.go +++ b/middles/oauth/cookies_test.go @@ -16,10 +16,12 @@ func testNow() time.Time { const testUser = 12345 +type rowid uint // example of Unique type + func TestCookieFactory_Create_NameAndPath(t *testing.T) { t.Parallel() - cf := &CookieFactory{ + cf := &CookieFactory[rowid]{ Name: "session-id", Clock: testNow, } @@ -34,7 +36,7 @@ func TestCookieFactory_Create_NameAndPath(t *testing.T) { func TestCookieFactory_Create_Expiration(t *testing.T) { t.Parallel() - cf := &CookieFactory{ + cf := &CookieFactory[rowid]{ Clock: testNow, } @@ -48,7 +50,7 @@ func TestCookieFactory_Create_Expiration(t *testing.T) { func TestCookieFactory_Create_SecureFlag(t *testing.T) { t.Parallel() - cf := &CookieFactory{ + cf := &CookieFactory[rowid]{ Clock: testNow, Secure: true, } @@ -61,7 +63,7 @@ func TestCookieFactory_Create_SecureFlag(t *testing.T) { func TestCookieFactory_Create_ValueEncoding(t *testing.T) { t.Parallel() - cf := &CookieFactory{ + cf := &CookieFactory[rowid]{ Clock: testNow, } @@ -71,7 +73,7 @@ func TestCookieFactory_Create_ValueEncoding(t *testing.T) { decoded, err := base64.StdEncoding.DecodeString(cookie.Value) must.NoError(t, err) - var content CookieContent + var content CookieContent[rowid] jerr := json.Unmarshal(decoded, &content) must.NoError(t, jerr) diff --git a/middles/oauth/sessions.go b/middles/oauth/sessions.go index 8e33dfa..6ec73e9 100644 --- a/middles/oauth/sessions.go +++ b/middles/oauth/sessions.go @@ -25,30 +25,34 @@ type Cache[K, T any] interface { Put(K, T, time.Duration) } -type Identity int64 +// Unique is a unique number assigned to each user that can be associated +// with any number of sessions. Typically a ROWID number from a database. +type Unique interface { + ~int | ~int64 | ~uint | ~uint64 +} -type Sessions struct { - Cache Cache[*conceal.Text, Identity] - CookieFactory *CookieFactory +type Sessions[U Unique] struct { + Cache Cache[*conceal.Text, U] + CookieFactory *CookieFactory[U] } // NewSessions creates a new Sessions for managing sessions and cookies // associated with those sessions. -func NewSessions(cookies *CookieFactory, cache Cache[*conceal.Text, Identity]) *Sessions { - return &Sessions{ +func NewSessions[U Unique](cookies *CookieFactory[U], cache Cache[*conceal.Text, U]) *Sessions[U] { + return &Sessions[U]{ Cache: cache, CookieFactory: cookies, } } -func (s *Sessions) Create(id Identity, ttl time.Duration) *http.Cookie { +func (s *Sessions[U]) Create(id U, ttl time.Duration) *http.Cookie { token := conceal.UUIDv4() cookie := s.CookieFactory.Create(id, token, ttl) s.Cache.Put(token, id, ttl) return cookie } -func (s *Sessions) Match(id Identity, token *conceal.Text) error { +func (s *Sessions[U]) Match(id U, token *conceal.Text) error { actual, exists := s.Cache.Get(token) switch { diff --git a/middles/oauth/sessions_test.go b/middles/oauth/sessions_test.go index 22eefe3..ec27a72 100644 --- a/middles/oauth/sessions_test.go +++ b/middles/oauth/sessions_test.go @@ -12,31 +12,31 @@ import ( // mockCache provides an in-memory implementation of the Cache interface type mockCache struct { - storage map[string]Identity + storage map[string]rowid } -func (m *mockCache) Get(k *conceal.Text) (Identity, bool) { +func (m *mockCache) Get(k *conceal.Text) (rowid, bool) { id, ok := m.storage[k.Unveil()] return id, ok } -func (m *mockCache) Put(k *conceal.Text, v Identity, _ time.Duration) { +func (m *mockCache) Put(k *conceal.Text, v rowid, _ time.Duration) { m.storage[k.Unveil()] = v } func TestSessions_Create(t *testing.T) { t.Parallel() - cache := &mockCache{storage: make(map[string]Identity)} + cache := &mockCache{storage: make(map[string]rowid)} // initialize the sessions manager with the cache and a cookie factory - sessions := NewSessions(&CookieFactory{ + sessions := NewSessions(&CookieFactory[rowid]{ Name: "session-token", Secure: true, Clock: testNow, }, cache) - id := Identity(12345) + id := rowid(12345) ttl := 1 * time.Hour cookie := sessions.Create(id, ttl) @@ -50,7 +50,7 @@ func TestSessions_Create(t *testing.T) { must.NoError(t, berr) // unamrshal the json value content - cc := new(CookieContent) + cc := new(CookieContent[rowid]) jerr := json.Unmarshal(b, cc) must.NoError(t, jerr) @@ -63,11 +63,11 @@ func TestSessions_Create(t *testing.T) { func TestSessions_Match(t *testing.T) { t.Parallel() - cookies := (*CookieFactory)(nil) - cache := &mockCache{storage: make(map[string]Identity)} + cookies := (*CookieFactory[rowid])(nil) + cache := &mockCache{storage: make(map[string]rowid)} sessions := NewSessions(cookies, cache) - id := Identity(12345) + id := rowid(12345) token := conceal.UUIDv4() // seed the cache with a known session @@ -85,7 +85,7 @@ func TestSessions_Match(t *testing.T) { }) t.Run("match not a match", func(t *testing.T) { - wrongID := Identity(99999) + wrongID := rowid(99999) err := sessions.Match(wrongID, token) must.ErrorIs(t, err, ErrNotMatch) }) diff --git a/middles/sessions.go b/middles/sessions.go index b2a964c..5e49b18 100644 --- a/middles/sessions.go +++ b/middles/sessions.go @@ -21,10 +21,10 @@ type Sessions[I identity.UserIdentity] interface { // // If on session is found, an implementation of identity.UserSession where // .Active() always returns false is returned, indicating there is no session. -func GetSession(r *http.Request) identity.UserSession[identity.UserIdentity] { - value, ok := r.Context().Value(sessionContextKey).(identity.UserSession[identity.UserIdentity]) +func GetSession[I identity.UserIdentity](r *http.Request) identity.UserSession[I] { + value, ok := r.Context().Value(sessionContextKey).(identity.UserSession[I]) if !ok { - return &session[identity.UserIdentity]{ + return &session[I]{ active: false, } } @@ -79,6 +79,7 @@ func (ss *SetSession[D, I]) ServeHTTP(w http.ResponseWriter, r *http.Request) { live := &session[I]{id: data.Identity(), active: true} ctx2 := context.WithValue(r.Context(), sessionContextKey, live) r2 := r.WithContext(ctx2) + ss.Next.ServeHTTP(w, r2) }