Skip to content

Commit ee2a90b

Browse files
feat: Add RWMutex to Assertion to ensure safe concurrent map access(#1495)
1 parent 78ef2c1 commit ee2a90b

File tree

3 files changed

+62
-35
lines changed

3 files changed

+62
-35
lines changed

model/assertion.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ package model
1616

1717
import (
1818
"errors"
19+
"github.com/casbin/casbin/v2/constant"
20+
"strconv"
1921
"strings"
22+
"sync"
2023

2124
"github.com/casbin/casbin/v2/log"
2225
"github.com/casbin/casbin/v2/rbac"
@@ -29,6 +32,7 @@ type Assertion struct {
2932
Value string
3033
Tokens []string
3134
ParamsTokens []string
35+
policyMu sync.RWMutex
3236
Policy [][]string
3337
PolicyMap map[string]int
3438
RM rbac.RoleManager
@@ -192,3 +196,28 @@ func (ast *Assertion) copy() *Assertion {
192196

193197
return newAst
194198
}
199+
200+
func (ast *Assertion) addPolicy(sec string, rule []string) {
201+
ast.Policy = append(ast.Policy, rule)
202+
ast.PolicyMap[strings.Join(rule, DefaultSep)] = len(ast.Policy) - 1
203+
204+
hasPriority := false
205+
if _, ok := ast.FieldIndexMap[constant.PriorityIndex]; ok {
206+
hasPriority = true
207+
}
208+
if sec == "p" && hasPriority {
209+
if idxInsert, err := strconv.Atoi(rule[ast.FieldIndexMap[constant.PriorityIndex]]); err == nil {
210+
i := len(ast.Policy) - 1
211+
for ; i > 0; i-- {
212+
idx, err := strconv.Atoi(ast.Policy[i-1][ast.FieldIndexMap[constant.PriorityIndex]])
213+
if err != nil || idx <= idxInsert {
214+
break
215+
}
216+
ast.Policy[i] = ast.Policy[i-1]
217+
ast.PolicyMap[strings.Join(ast.Policy[i-1], DefaultSep)]++
218+
}
219+
ast.Policy[i] = rule
220+
ast.PolicyMap[strings.Join(rule, DefaultSep)] = i
221+
}
222+
}
223+
}

model/model.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,8 @@ func (model Model) Copy() Model {
419419

420420
func (model Model) GetFieldIndex(ptype string, field string) (int, error) {
421421
assertion := model["p"][ptype]
422+
assertion.policyMu.RLock()
423+
defer assertion.policyMu.RUnlock()
422424
if index, ok := assertion.FieldIndexMap[field]; ok {
423425
return index, nil
424426
}

model/policy.go

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@ package model
1616

1717
import (
1818
"fmt"
19-
"strconv"
2019
"strings"
2120

22-
"github.com/casbin/casbin/v2/constant"
2321
"github.com/casbin/casbin/v2/rbac"
2422
"github.com/casbin/casbin/v2/util"
2523
)
@@ -123,31 +121,39 @@ func (model Model) PrintPolicy() {
123121
// ClearPolicy clears all current policy.
124122
func (model Model) ClearPolicy() {
125123
for _, ast := range model["p"] {
124+
ast.policyMu.Lock()
126125
ast.Policy = nil
127126
ast.PolicyMap = map[string]int{}
127+
ast.policyMu.Unlock()
128128
}
129129

130130
for _, ast := range model["g"] {
131+
ast.policyMu.Lock()
131132
ast.Policy = nil
132133
ast.PolicyMap = map[string]int{}
134+
ast.policyMu.Unlock()
133135
}
134136
}
135137

136138
// GetPolicy gets all rules in a policy.
137139
func (model Model) GetPolicy(sec string, ptype string) ([][]string, error) {
138-
_, err := model.GetAssertion(sec, ptype)
140+
ast, err := model.GetAssertion(sec, ptype)
139141
if err != nil {
140142
return nil, err
141143
}
144+
ast.policyMu.RLock()
145+
defer ast.policyMu.RUnlock()
142146
return model[sec][ptype].Policy, nil
143147
}
144148

145149
// GetFilteredPolicy gets rules based on field filters from a policy.
146150
func (model Model) GetFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) ([][]string, error) {
147-
_, err := model.GetAssertion(sec, ptype)
151+
ast, err := model.GetAssertion(sec, ptype)
148152
if err != nil {
149153
return nil, err
150154
}
155+
ast.policyMu.RLock()
156+
defer ast.policyMu.RUnlock()
151157
res := [][]string{}
152158

153159
for _, rule := range model[sec][ptype].Policy {
@@ -196,10 +202,12 @@ func (model Model) HasPolicyEx(sec string, ptype string, rule []string) (bool, e
196202

197203
// HasPolicy determines whether a model has the specified policy rule.
198204
func (model Model) HasPolicy(sec string, ptype string, rule []string) (bool, error) {
199-
_, err := model.GetAssertion(sec, ptype)
205+
ast, err := model.GetAssertion(sec, ptype)
200206
if err != nil {
201207
return false, err
202208
}
209+
ast.policyMu.RLock()
210+
defer ast.policyMu.RUnlock()
203211
_, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)]
204212
return ok, nil
205213
}
@@ -225,28 +233,9 @@ func (model Model) AddPolicy(sec string, ptype string, rule []string) error {
225233
if err != nil {
226234
return err
227235
}
228-
assertion.Policy = append(assertion.Policy, rule)
229-
assertion.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1
230-
231-
hasPriority := false
232-
if _, ok := assertion.FieldIndexMap[constant.PriorityIndex]; ok {
233-
hasPriority = true
234-
}
235-
if sec == "p" && hasPriority {
236-
if idxInsert, err := strconv.Atoi(rule[assertion.FieldIndexMap[constant.PriorityIndex]]); err == nil {
237-
i := len(assertion.Policy) - 1
238-
for ; i > 0; i-- {
239-
idx, err := strconv.Atoi(assertion.Policy[i-1][assertion.FieldIndexMap[constant.PriorityIndex]])
240-
if err != nil || idx <= idxInsert {
241-
break
242-
}
243-
assertion.Policy[i] = assertion.Policy[i-1]
244-
assertion.PolicyMap[strings.Join(assertion.Policy[i-1], DefaultSep)]++
245-
}
246-
assertion.Policy[i] = rule
247-
assertion.PolicyMap[strings.Join(rule, DefaultSep)] = i
248-
}
249-
}
236+
assertion.policyMu.Lock()
237+
defer assertion.policyMu.Unlock()
238+
assertion.addPolicy(sec, rule)
250239
return nil
251240
}
252241

@@ -258,10 +247,12 @@ func (model Model) AddPolicies(sec string, ptype string, rules [][]string) error
258247

259248
// AddPoliciesWithAffected adds policy rules to the model, and returns affected rules.
260249
func (model Model) AddPoliciesWithAffected(sec string, ptype string, rules [][]string) ([][]string, error) {
261-
_, err := model.GetAssertion(sec, ptype)
250+
assertion, err := model.GetAssertion(sec, ptype)
262251
if err != nil {
263252
return nil, err
264253
}
254+
assertion.policyMu.Lock()
255+
defer assertion.policyMu.Unlock()
265256
var affected [][]string
266257
for _, rule := range rules {
267258
hashKey := strings.Join(rule, DefaultSep)
@@ -270,10 +261,7 @@ func (model Model) AddPoliciesWithAffected(sec string, ptype string, rules [][]s
270261
continue
271262
}
272263
affected = append(affected, rule)
273-
err = model.AddPolicy(sec, ptype, rule)
274-
if err != nil {
275-
return affected, err
276-
}
264+
assertion.addPolicy(sec, rule)
277265
}
278266
return affected, err
279267
}
@@ -285,6 +273,8 @@ func (model Model) RemovePolicy(sec string, ptype string, rule []string) (bool,
285273
if err != nil {
286274
return false, err
287275
}
276+
ast.policyMu.Lock()
277+
defer ast.policyMu.Unlock()
288278
key := strings.Join(rule, DefaultSep)
289279
index, ok := ast.PolicyMap[key]
290280
if !ok {
@@ -304,10 +294,12 @@ func (model Model) RemovePolicy(sec string, ptype string, rule []string) (bool,
304294

305295
// UpdatePolicy updates a policy rule from the model.
306296
func (model Model) UpdatePolicy(sec string, ptype string, oldRule []string, newRule []string) (bool, error) {
307-
_, err := model.GetAssertion(sec, ptype)
297+
ast, err := model.GetAssertion(sec, ptype)
308298
if err != nil {
309299
return false, err
310300
}
301+
ast.policyMu.Lock()
302+
defer ast.policyMu.Unlock()
311303
oldPolicy := strings.Join(oldRule, DefaultSep)
312304
index, ok := model[sec][ptype].PolicyMap[oldPolicy]
313305
if !ok {
@@ -323,10 +315,12 @@ func (model Model) UpdatePolicy(sec string, ptype string, oldRule []string, newR
323315

324316
// UpdatePolicies updates a policy rule from the model.
325317
func (model Model) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (bool, error) {
326-
_, err := model.GetAssertion(sec, ptype)
318+
ast, err := model.GetAssertion(sec, ptype)
327319
if err != nil {
328320
return false, err
329321
}
322+
ast.policyMu.Lock()
323+
defer ast.policyMu.Unlock()
330324
rollbackFlag := false
331325
// index -> []{oldIndex, newIndex}
332326
modifiedRuleIndex := make(map[int][]int)
@@ -370,10 +364,12 @@ func (model Model) RemovePolicies(sec string, ptype string, rules [][]string) (b
370364

371365
// RemovePoliciesWithAffected removes policy rules from the model, and returns affected rules.
372366
func (model Model) RemovePoliciesWithAffected(sec string, ptype string, rules [][]string) ([][]string, error) {
373-
_, err := model.GetAssertion(sec, ptype)
367+
ast, err := model.GetAssertion(sec, ptype)
374368
if err != nil {
375369
return nil, err
376370
}
371+
ast.policyMu.Lock()
372+
defer ast.policyMu.Unlock()
377373
var affected [][]string
378374
for _, rule := range rules {
379375
index, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)]

0 commit comments

Comments
 (0)