Skip to content

Commit 74de6a3

Browse files
committed
fix: proper concurrent access to field map cache
1 parent 78ef2c1 commit 74de6a3

File tree

5 files changed

+76
-17
lines changed

5 files changed

+76
-17
lines changed

internal_api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,5 +493,5 @@ func (e *Enforcer) GetFieldIndex(ptype string, field string) (int, error) {
493493

494494
func (e *Enforcer) SetFieldIndex(ptype string, field string, index int) {
495495
assertion := e.model["p"][ptype]
496-
assertion.FieldIndexMap[field] = index
496+
assertion.FieldIndexMap.Store(field, index)
497497
}

model/assertion.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package model
1717
import (
1818
"errors"
1919
"strings"
20+
"sync"
2021

2122
"github.com/casbin/casbin/v2/log"
2223
"github.com/casbin/casbin/v2/rbac"
@@ -33,11 +34,25 @@ type Assertion struct {
3334
PolicyMap map[string]int
3435
RM rbac.RoleManager
3536
CondRM rbac.ConditionalRoleManager
36-
FieldIndexMap map[string]int
37+
FieldIndexMap sync.Map
3738

3839
logger log.Logger
3940
}
4041

42+
func (ast *Assertion) GetFieldIndex(field string) (idx int, ok bool) {
43+
value, found := ast.FieldIndexMap.Load(field)
44+
if found {
45+
ok = true
46+
idx = value.(int)
47+
}
48+
return idx, ok
49+
}
50+
51+
func (ast *Assertion) GetFieldIndexOrZero(field string) (idx int) {
52+
idx, _ = ast.GetFieldIndex(field)
53+
return idx
54+
}
55+
4156
func (ast *Assertion) buildIncrementalRoleLinks(rm rbac.RoleManager, op PolicyOp, rules [][]string) error {
4257
ast.RM = rm
4358
count := strings.Count(ast.Value, "_")
@@ -182,13 +197,17 @@ func (ast *Assertion) copy() *Assertion {
182197
}
183198

184199
newAst := &Assertion{
185-
Key: ast.Key,
186-
Value: ast.Value,
187-
PolicyMap: policyMap,
188-
Tokens: tokens,
189-
Policy: policy,
190-
FieldIndexMap: ast.FieldIndexMap,
200+
Key: ast.Key,
201+
Value: ast.Value,
202+
PolicyMap: policyMap,
203+
Tokens: tokens,
204+
Policy: policy,
191205
}
192206

207+
ast.FieldIndexMap.Range(func(key, value interface{}) bool {
208+
newAst.FieldIndexMap.Store(key, value)
209+
return true
210+
})
211+
193212
return newAst
194213
}

model/model.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ func (model Model) AddDef(sec string, key string, value string) bool {
7676
ast.Key = key
7777
ast.Value = value
7878
ast.PolicyMap = make(map[string]int)
79-
ast.FieldIndexMap = make(map[string]int)
8079
ast.setLogger(model.GetLogger())
8180

8281
if sec == "r" || sec == "p" {
@@ -419,7 +418,7 @@ func (model Model) Copy() Model {
419418

420419
func (model Model) GetFieldIndex(ptype string, field string) (int, error) {
421420
assertion := model["p"][ptype]
422-
if index, ok := assertion.FieldIndexMap[field]; ok {
421+
if index, ok := assertion.GetFieldIndex(field); ok {
423422
return index, nil
424423
}
425424
pattern := fmt.Sprintf("%s_"+field, ptype)
@@ -433,6 +432,6 @@ func (model Model) GetFieldIndex(ptype string, field string) (int, error) {
433432
if index == -1 {
434433
return index, fmt.Errorf(field + " index is not set, please use enforcer.SetFieldIndex() to set index")
435434
}
436-
assertion.FieldIndexMap[field] = index
435+
assertion.FieldIndexMap.Store(field, index)
437436
return index, nil
438437
}

model/model_test.go

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
package model
1616

1717
import (
18-
"io/ioutil"
18+
"os"
1919
"path/filepath"
2020
"strings"
2121
"testing"
@@ -63,7 +63,7 @@ func TestNewModelFromFile(t *testing.T) {
6363
}
6464

6565
func TestNewModelFromString(t *testing.T) {
66-
modelBytes, _ := ioutil.ReadFile(basicExample)
66+
modelBytes, _ := os.ReadFile(basicExample)
6767
modelString := string(modelBytes)
6868
m, err := NewModelFromString(modelString)
6969
if err != nil {
@@ -125,7 +125,47 @@ func TestModel_AddDef(t *testing.T) {
125125
}
126126
}
127127

128-
func TestModelToTest(t *testing.T) {
128+
func TestModel_Copy(t *testing.T) {
129+
m, err := NewModelFromFile(basicExample)
130+
if err != nil {
131+
t.Errorf("model failed to load from file: %s", err)
132+
}
133+
134+
newModel := m.Copy()
135+
if newModel.ToText() != m.ToText() {
136+
t.Errorf("new model is not equal to original")
137+
}
138+
}
139+
140+
func TestModel_Copy_includesFieldMapInCopy(t *testing.T) {
141+
m, err := NewModelFromFile(basicExample)
142+
if err != nil {
143+
t.Errorf("model failed to load from file: %s", err)
144+
}
145+
146+
idx, _ := m.GetFieldIndex("p", "act")
147+
if idx != 2 {
148+
t.Errorf("unexpected field index: %d", idx)
149+
}
150+
151+
newModel := m.Copy()
152+
if newModel.ToText() != m.ToText() {
153+
t.Error("new model is not equal to original")
154+
}
155+
156+
assertion, err := newModel.GetAssertion("p", "p")
157+
if err != nil {
158+
t.Errorf("model failed to get assertion: %s", err)
159+
}
160+
if _, ok := assertion.GetFieldIndex("act"); !ok {
161+
t.Errorf("model does not have the field index in cache")
162+
}
163+
if idx, err := newModel.GetFieldIndex("p", "act"); err != nil || idx != 2 {
164+
t.Errorf("unexpected field index: %s - %d", err, idx)
165+
}
166+
}
167+
168+
func TestModel_ToText(t *testing.T) {
129169
testModelToText(t, "r.sub == p.sub && r.obj == p.obj && r_func(r.act, p.act) && testr_func(r.act, p.act)", "r_sub == p_sub && r_obj == p_obj && r_func(r_act, p_act) && testr_func(r_act, p_act)")
130170
testModelToText(t, "r.sub == p.sub && r.obj == p.obj && p_func(r.act, p.act) && testp_func(r.act, p.act)", "r_sub == p_sub && r_obj == p_obj && p_func(r_act, p_act) && testp_func(r_act, p_act)")
131171
}

model/policy.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,15 @@ func (model Model) AddPolicy(sec string, ptype string, rule []string) error {
229229
assertion.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1
230230

231231
hasPriority := false
232-
if _, ok := assertion.FieldIndexMap[constant.PriorityIndex]; ok {
232+
if _, ok := assertion.FieldIndexMap.Load(constant.PriorityIndex); ok {
233233
hasPriority = true
234234
}
235235
if sec == "p" && hasPriority {
236-
if idxInsert, err := strconv.Atoi(rule[assertion.FieldIndexMap[constant.PriorityIndex]]); err == nil {
236+
priorityIndex := assertion.GetFieldIndexOrZero(constant.PriorityIndex)
237+
if idxInsert, err := strconv.Atoi(rule[priorityIndex]); err == nil {
237238
i := len(assertion.Policy) - 1
238239
for ; i > 0; i-- {
239-
idx, err := strconv.Atoi(assertion.Policy[i-1][assertion.FieldIndexMap[constant.PriorityIndex]])
240+
idx, err := strconv.Atoi(assertion.Policy[i-1][priorityIndex])
240241
if err != nil || idx <= idxInsert {
241242
break
242243
}

0 commit comments

Comments
 (0)