-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoauth.go
More file actions
117 lines (95 loc) · 2.59 KB
/
oauth.go
File metadata and controls
117 lines (95 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
//go:build !wasm
package user
import (
"net/http"
"time"
"github.com/tinywasm/orm"
"github.com/tinywasm/unixid"
)
// BeginOAuth starts an OAuth login with the named provider.
//
// It generates a state value with unixid and persists it as an OAuthState
// record. The record expires 600 seconds (10 minutes) after its creation time.
// BeginOAuth then returns the provider's authorization URL for that state;
// CompleteOAuth later consumes the state.
//
// It returns ErrProviderNotFound when getProvider has no provider for
// providerName. Otherwise it returns any error from ID generation or from
// storing the state record.
//
// NOTE(review): the state is this flow's CSRF token and should be unguessable
// (RFC 6749 §10.12). Confirm that unixid IDs carry enough entropy. If they
// are timestamp-derived, add a crypto/rand component to the state.
func (m *Module) BeginOAuth(providerName string) (string, error) {
	const stateTTLSeconds = 600 // 10 minutes

	provider := m.getProvider(providerName)
	if provider == nil {
		return "", ErrProviderNotFound
	}

	idGen, err := unixid.NewUnixID()
	if err != nil {
		return "", err
	}
	state := idGen.GetNewID()

	issuedAt := time.Now().Unix()
	record := &OAuthState{
		State:     state,
		Provider:  providerName,
		CreatedAt: issuedAt,
		ExpiresAt: issuedAt + stateTTLSeconds,
	}
	if err = m.db.Create(record); err != nil {
		return "", err
	}

	return provider.AuthCodeURL(state), nil
}
// CompleteOAuth finishes an OAuth login that BeginOAuth started.
//
// It first consumes the "state" query parameter of r for providerName. It then
// exchanges the "code" query parameter with the provider and fetches the
// provider's user info. The local user is resolved in this order:
//  1. the user already linked to (providerName, info.ID);
//  2. an existing user with the same non-empty email, which then gets this
//     provider identity linked to it;
//  3. a newly created user, which gets this provider identity linked to it.
//
// The boolean result is true only when a new user was created. ip and ua are
// not used by this function.
//
// Errors:
//   - ErrInvalidOAuthState when consumeState fails for any reason.
//   - ErrProviderNotFound when getProvider has no provider for providerName.
//   - Otherwise, the error returned by the provider or the user store.
//
// On error the returned User is the zero value.
func (m *Module) CompleteOAuth(providerName string, r *http.Request, ip, ua string) (User, bool, error) {
	// Parse the query string once; it is needed for both "state" and "code".
	query := r.URL.Query()

	if err := consumeState(m.db, query.Get("state"), providerName); err != nil {
		return User{}, false, ErrInvalidOAuthState
	}

	p := m.getProvider(providerName)
	if p == nil {
		return User{}, false, ErrProviderNotFound
	}

	token, err := p.ExchangeCode(r.Context(), query.Get("code"))
	if err != nil {
		return User{}, false, err
	}

	info, err := p.GetUserInfo(r.Context(), token)
	if err != nil {
		return User{}, false, err
	}

	// 1. Known provider identity: return the user it is linked to.
	// NOTE(review): any lookup error, not only "not found", falls through to
	// the email/create paths below. Confirm that getIdentityByProvider reports
	// only absence here, or a transient DB error could create a duplicate user.
	if identity, err := getIdentityByProvider(m.db, providerName, info.ID); err == nil {
		u, err := getUser(m.db, m.ucache, identity.UserID)
		if err != nil {
			return User{}, false, err
		}
		return u, false, nil
	}

	// 2. Link by email, but only when the provider actually returned one.
	// Otherwise an empty email could match a stored user whose email is also
	// empty (for example, an earlier OAuth user created without an email).
	// That would attach this identity to an unrelated account.
	// NOTE(review): linking by email assumes the provider has verified it.
	// Confirm this for every configured provider.
	if info.Email != "" {
		if u, err := getUserByEmail(m.db, m.ucache, info.Email); err == nil {
			// Best-effort link: the login still succeeds if it fails. The
			// next login with a non-empty email will retry the link here.
			_ = createIdentity(m.db, u.ID, providerName, info.ID, info.Email)
			return u, false, nil
		}
	}

	// 3. No existing account: create one and link the identity.
	u, err := createUser(m.db, info.Email, info.Name, "")
	if err != nil {
		return User{}, false, err
	}
	// Best-effort link, as above. If it fails and the email is empty, the
	// next login cannot find this user again.
	_ = createIdentity(m.db, u.ID, providerName, info.ID, info.Email)
	return u, true, nil
}
// consumeState validates and consumes a single-use OAuth state value.
//
// It reads the OAuthState row whose State equals state. It succeeds only when
// that row exists, was issued for provider, and has not expired.
//
// A row that matches provider is deleted before the expiry check, so even an
// expired state cannot be presented a second time. A row issued for a
// different provider is left in place.
//
// It returns ErrInvalidOAuthState for an empty, unknown, provider-mismatched,
// or expired state. It returns the underlying error if reading or deleting
// the row fails.
//
// NOTE(review): the read and the delete are separate calls. Two concurrent
// callbacks carrying the same state could therefore both read the row before
// either deletes it. Closing that window needs an atomic delete-and-check
// (e.g. rows-affected) from the orm layer, which is not visible here.
func consumeState(db *orm.DB, state, provider string) error {
	// An empty state is rejected outright, without a DB round-trip.
	if state == "" {
		return ErrInvalidOAuthState
	}

	qb := db.Query(&OAuthState{}).Where(OAuthState_.State).Eq(state)
	results, err := ReadAllOAuthState(qb)
	if err != nil {
		return err
	}
	if len(results) == 0 {
		return ErrInvalidOAuthState
	}

	stateObj := results[0]
	if stateObj.Provider != provider {
		return ErrInvalidOAuthState
	}

	// Delete first (single use), regardless of expiration, to prevent reuse.
	if err := db.Delete(stateObj, orm.Eq(OAuthState_.State, stateObj.State)); err != nil {
		return err
	}

	if stateObj.ExpiresAt < time.Now().Unix() {
		return ErrInvalidOAuthState
	}
	return nil
}
// PurgeExpiredOAuthStates deletes every stored OAuthState whose ExpiresAt is
// earlier than the current Unix time.
//
// A failure to read the expired states is returned immediately. A delete
// failure does not stop the sweep: every expired state is still attempted,
// and the first delete error encountered is returned. It returns nil only
// when the read and every delete succeed.
func (m *Module) PurgeExpiredOAuthStates() error {
	qb := m.db.Query(&OAuthState{}).Where(OAuthState_.ExpiresAt).Lt(time.Now().Unix())
	states, err := ReadAllOAuthState(qb)
	if err != nil {
		return err
	}

	var firstErr error
	for _, s := range states {
		if err := m.db.Delete(s, orm.Eq(OAuthState_.State, s.State)); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}