Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions middles/oauth/cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,39 @@ import (
//
// Each cookie minted is of the same name; i.e. the name associated with the
// cookie in the requester's cookie jar (web browser / http client).
type CookieFactory struct {
type CookieFactory[U Unique] struct {
Name string
Secure bool
Clock func() time.Time
}

// CookieContent is the data stored per session.
type CookieContent struct {
UserToken string `json:"token"`
UserID Identity `json:"user_id"`
type CookieContent[U Unique] struct {
UserToken string `json:"token"`
UserID U `json:"user_id"`
}

// Token returns the secret token associated with the cookie.
func (cc *CookieContent) Token() *conceal.Text {
func (cc *CookieContent[U]) Token() *conceal.Text {
return conceal.New(cc.UserToken)
}

// Identity returns the identity associated with the cookie.
func (cc *CookieContent) Identity() Identity {
func (cc *CookieContent[U]) Identity() U {
return cc.UserID
}

// Create the cookie.
func (cf *CookieFactory) Create(
id Identity,
func (cf *CookieFactory[U]) Create(
id U,
token *conceal.Text,
ttl time.Duration,
) *http.Cookie {
// compute the future time cookie expires
expiration := cf.Clock().Add(ttl)

// encode the cookie payload as base64 json
b, _ := json.Marshal(&CookieContent{
b, _ := json.Marshal(&CookieContent[U]{
UserToken: token.Unveil(),
UserID: id,
})
Expand Down
12 changes: 7 additions & 5 deletions middles/oauth/cookies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ func testNow() time.Time {

const testUser = 12345

type rowid uint // example of Unique type

func TestCookieFactory_Create_NameAndPath(t *testing.T) {
t.Parallel()

cf := &CookieFactory{
cf := &CookieFactory[rowid]{
Name: "session-id",
Clock: testNow,
}
Expand All @@ -34,7 +36,7 @@ func TestCookieFactory_Create_NameAndPath(t *testing.T) {
func TestCookieFactory_Create_Expiration(t *testing.T) {
t.Parallel()

cf := &CookieFactory{
cf := &CookieFactory[rowid]{
Clock: testNow,
}

Expand All @@ -48,7 +50,7 @@ func TestCookieFactory_Create_Expiration(t *testing.T) {
func TestCookieFactory_Create_SecureFlag(t *testing.T) {
t.Parallel()

cf := &CookieFactory{
cf := &CookieFactory[rowid]{
Clock: testNow,
Secure: true,
}
Expand All @@ -61,7 +63,7 @@ func TestCookieFactory_Create_SecureFlag(t *testing.T) {
func TestCookieFactory_Create_ValueEncoding(t *testing.T) {
t.Parallel()

cf := &CookieFactory{
cf := &CookieFactory[rowid]{
Clock: testNow,
}

Expand All @@ -71,7 +73,7 @@ func TestCookieFactory_Create_ValueEncoding(t *testing.T) {
decoded, err := base64.StdEncoding.DecodeString(cookie.Value)
must.NoError(t, err)

var content CookieContent
var content CookieContent[rowid]
jerr := json.Unmarshal(decoded, &content)
must.NoError(t, jerr)

Expand Down
20 changes: 12 additions & 8 deletions middles/oauth/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,34 @@ type Cache[K, T any] interface {
Put(K, T, time.Duration)
}

type Identity int64
// Unique is a unique number assigned to each user that can be associated
// with any number of sessions. Typically a ROWID number from a database.
type Unique interface {
~int | ~int64 | ~uint | ~uint64
}

type Sessions struct {
Cache Cache[*conceal.Text, Identity]
CookieFactory *CookieFactory
type Sessions[U Unique] struct {
Cache Cache[*conceal.Text, U]
CookieFactory *CookieFactory[U]
}

// NewSessions creates a new Sessions for managing sessions and cookies
// associated with those sessions.
func NewSessions(cookies *CookieFactory, cache Cache[*conceal.Text, Identity]) *Sessions {
return &Sessions{
func NewSessions[U Unique](cookies *CookieFactory[U], cache Cache[*conceal.Text, U]) *Sessions[U] {
return &Sessions[U]{
Cache: cache,
CookieFactory: cookies,
}
}

func (s *Sessions) Create(id Identity, ttl time.Duration) *http.Cookie {
func (s *Sessions[U]) Create(id U, ttl time.Duration) *http.Cookie {
token := conceal.UUIDv4()
cookie := s.CookieFactory.Create(id, token, ttl)
s.Cache.Put(token, id, ttl)
return cookie
}

func (s *Sessions) Match(id Identity, token *conceal.Text) error {
func (s *Sessions[U]) Match(id U, token *conceal.Text) error {
actual, exists := s.Cache.Get(token)

switch {
Expand Down
22 changes: 11 additions & 11 deletions middles/oauth/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@ import (

// mockCache provides an in-memory implementation of the Cache interface
type mockCache struct {
storage map[string]Identity
storage map[string]rowid
}

func (m *mockCache) Get(k *conceal.Text) (Identity, bool) {
func (m *mockCache) Get(k *conceal.Text) (rowid, bool) {
id, ok := m.storage[k.Unveil()]
return id, ok
}

func (m *mockCache) Put(k *conceal.Text, v Identity, _ time.Duration) {
func (m *mockCache) Put(k *conceal.Text, v rowid, _ time.Duration) {
m.storage[k.Unveil()] = v
}

func TestSessions_Create(t *testing.T) {
t.Parallel()

cache := &mockCache{storage: make(map[string]Identity)}
cache := &mockCache{storage: make(map[string]rowid)}

// initialize the sessions manager with the cache and a cookie factory
sessions := NewSessions(&CookieFactory{
sessions := NewSessions(&CookieFactory[rowid]{
Name: "session-token",
Secure: true,
Clock: testNow,
}, cache)

id := Identity(12345)
id := rowid(12345)
ttl := 1 * time.Hour

cookie := sessions.Create(id, ttl)
Expand All @@ -50,7 +50,7 @@ func TestSessions_Create(t *testing.T) {
must.NoError(t, berr)

// unamrshal the json value content
cc := new(CookieContent)
cc := new(CookieContent[rowid])
jerr := json.Unmarshal(b, cc)
must.NoError(t, jerr)

Expand All @@ -63,11 +63,11 @@ func TestSessions_Create(t *testing.T) {
func TestSessions_Match(t *testing.T) {
t.Parallel()

cookies := (*CookieFactory)(nil)
cache := &mockCache{storage: make(map[string]Identity)}
cookies := (*CookieFactory[rowid])(nil)
cache := &mockCache{storage: make(map[string]rowid)}
sessions := NewSessions(cookies, cache)

id := Identity(12345)
id := rowid(12345)
token := conceal.UUIDv4()

// seed the cache with a known session
Expand All @@ -85,7 +85,7 @@ func TestSessions_Match(t *testing.T) {
})

t.Run("match not a match", func(t *testing.T) {
wrongID := Identity(99999)
wrongID := rowid(99999)
err := sessions.Match(wrongID, token)
must.ErrorIs(t, err, ErrNotMatch)
})
Expand Down
7 changes: 4 additions & 3 deletions middles/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ type Sessions[I identity.UserIdentity] interface {
//
// If on session is found, an implementation of identity.UserSession where
// .Active() always returns false is returned, indicating there is no session.
func GetSession(r *http.Request) identity.UserSession[identity.UserIdentity] {
value, ok := r.Context().Value(sessionContextKey).(identity.UserSession[identity.UserIdentity])
func GetSession[I identity.UserIdentity](r *http.Request) identity.UserSession[I] {
value, ok := r.Context().Value(sessionContextKey).(identity.UserSession[I])
if !ok {
return &session[identity.UserIdentity]{
return &session[I]{
active: false,
}
}
Expand Down Expand Up @@ -79,6 +79,7 @@ func (ss *SetSession[D, I]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
live := &session[I]{id: data.Identity(), active: true}
ctx2 := context.WithValue(r.Context(), sessionContextKey, live)
r2 := r.WithContext(ctx2)

ss.Next.ServeHTTP(w, r2)
}

Expand Down