Skip to content

Commit a9fcad3

Browse files
Copilothsluoyz
andcommitted
Add AI policy ("a" type) support to model and enforcer
Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com>
1 parent 0c5cebe commit a9fcad3

File tree

5 files changed

+211
-7
lines changed

5 files changed

+211
-7
lines changed

ai_api.go

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,23 @@ func (e *Enforcer) buildExplainContext(rvals []interface{}, result bool, matched
145145

146146
// callAIAPI calls the configured AI API to get an explanation.
147147
func (e *Enforcer) callAIAPI(explainContext string) (string, error) {
148+
return e.callAIAPIWithSystemPrompt(explainContext, "You are an expert in access control and authorization systems. "+
149+
"Explain why an authorization request was allowed or denied based on the "+
150+
"provided access control model, policies, and enforcement result. "+
151+
"Be clear, concise, and educational.")
152+
}
153+
154+
// callAIAPIWithSystemPrompt calls the configured AI API with a custom system prompt.
155+
func (e *Enforcer) callAIAPIWithSystemPrompt(userContent, systemPrompt string) (string, error) {
148156
// Prepare the request
149157
messages := []aiMessage{
150158
{
151-
Role: "system",
152-
Content: "You are an expert in access control and authorization systems. " +
153-
"Explain why an authorization request was allowed or denied based on the " +
154-
"provided access control model, policies, and enforcement result. " +
155-
"Be clear, concise, and educational.",
159+
Role: "system",
160+
Content: systemPrompt,
156161
},
157162
{
158163
Role: "user",
159-
Content: fmt.Sprintf("Please explain the following authorization decision:\n\n%s", explainContext),
164+
Content: userContent,
160165
},
161166
}
162167

@@ -219,3 +224,46 @@ func (e *Enforcer) callAIAPI(explainContext string) (string, error) {
219224

220225
return chatResp.Choices[0].Message.Content, nil
221226
}
227+
228+
// evaluateAIPolicy evaluates an AI policy by calling the configured LLM API.
229+
// It returns true if the AI policy allows the request, false otherwise.
230+
func (e *Enforcer) evaluateAIPolicy(policyDescription string, rvals []interface{}) (bool, error) {
231+
if e.aiConfig.Endpoint == "" {
232+
return false, errors.New("AI config not set, use SetAIConfig first")
233+
}
234+
235+
// Build context for AI
236+
var sb strings.Builder
237+
sb.WriteString("Authorization Request:\n")
238+
if len(rvals) > 0 {
239+
sb.WriteString(fmt.Sprintf("Subject: %v\n", rvals[0]))
240+
}
241+
if len(rvals) > 1 {
242+
sb.WriteString(fmt.Sprintf("Object: %v\n", rvals[1]))
243+
}
244+
if len(rvals) > 2 {
245+
sb.WriteString(fmt.Sprintf("Action: %v\n", rvals[2]))
246+
}
247+
248+
sb.WriteString(fmt.Sprintf("\nAI Policy Rule: %s\n", policyDescription))
249+
sb.WriteString("\nQuestion: Does this request satisfy the AI policy rule? Answer with 'ALLOW' if yes, 'DENY' if no.")
250+
251+
// Call AI API
252+
systemPrompt := "You are an AI security policy evaluator. " +
253+
"Your task is to determine if an authorization request satisfies the given AI policy rule. " +
254+
"Respond with ONLY the word 'ALLOW' or 'DENY' based on your evaluation."
255+
256+
response, err := e.callAIAPIWithSystemPrompt(sb.String(), systemPrompt)
257+
if err != nil {
258+
return false, fmt.Errorf("failed to evaluate AI policy: %w", err)
259+
}
260+
261+
// Parse response
262+
response = strings.TrimSpace(strings.ToUpper(response))
263+
if strings.Contains(response, "ALLOW") {
264+
return true, nil
265+
}
266+
267+
return false, nil
268+
}
269+

ai_policy_api.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright 2026 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package casbin
16+
17+
// GetAIPolicy gets all the AI policy rules in the policy.
18+
func (e *Enforcer) GetAIPolicy() ([][]string, error) {
19+
return e.GetNamedAIPolicy("a")
20+
}
21+
22+
// GetFilteredAIPolicy gets all the AI policy rules in the policy, field filters can be specified.
23+
func (e *Enforcer) GetFilteredAIPolicy(fieldIndex int, fieldValues ...string) ([][]string, error) {
24+
return e.GetFilteredNamedAIPolicy("a", fieldIndex, fieldValues...)
25+
}
26+
27+
// GetNamedAIPolicy gets all the AI policy rules in the named policy.
28+
func (e *Enforcer) GetNamedAIPolicy(ptype string) ([][]string, error) {
29+
return e.model.GetPolicy("a", ptype)
30+
}
31+
32+
// GetFilteredNamedAIPolicy gets all the AI policy rules in the named policy, field filters can be specified.
33+
func (e *Enforcer) GetFilteredNamedAIPolicy(ptype string, fieldIndex int, fieldValues ...string) ([][]string, error) {
34+
return e.model.GetFilteredPolicy("a", ptype, fieldIndex, fieldValues...)
35+
}
36+
37+
// HasAIPolicy determines whether an AI policy rule exists.
38+
func (e *Enforcer) HasAIPolicy(params ...string) (bool, error) {
39+
return e.HasNamedAIPolicy("a", params...)
40+
}
41+
42+
// HasNamedAIPolicy determines whether a named AI policy rule exists.
43+
func (e *Enforcer) HasNamedAIPolicy(ptype string, params ...string) (bool, error) {
44+
return e.model.HasPolicy("a", ptype, params)
45+
}
46+
47+
// AddAIPolicy adds an AI policy rule to the current policy.
48+
// If the rule already exists, the function returns false and the rule will not be added.
49+
// Otherwise the function returns true by adding the new rule.
50+
func (e *Enforcer) AddAIPolicy(params ...string) (bool, error) {
51+
return e.AddNamedAIPolicy("a", params...)
52+
}
53+
54+
// AddAIPolicies adds AI policy rules to the current policy.
55+
// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
56+
// Otherwise the function returns true for the corresponding rule by adding the new rule.
57+
func (e *Enforcer) AddAIPolicies(rules [][]string) (bool, error) {
58+
return e.AddNamedAIPolicies("a", rules)
59+
}
60+
61+
// AddNamedAIPolicy adds an AI policy rule to the current named policy.
62+
// If the rule already exists, the function returns false and the rule will not be added.
63+
// Otherwise the function returns true by adding the new rule.
64+
func (e *Enforcer) AddNamedAIPolicy(ptype string, params ...string) (bool, error) {
65+
return e.addPolicyInternal("a", ptype, params)
66+
}
67+
68+
// AddNamedAIPolicies adds AI policy rules to the current named policy.
69+
// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
70+
// Otherwise the function returns true for the corresponding policy rule by adding the new rule.
71+
func (e *Enforcer) AddNamedAIPolicies(ptype string, rules [][]string) (bool, error) {
72+
return e.addPoliciesInternal("a", ptype, rules)
73+
}
74+
75+
// RemoveAIPolicy removes an AI policy rule from the current policy.
76+
func (e *Enforcer) RemoveAIPolicy(params ...string) (bool, error) {
77+
return e.RemoveNamedAIPolicy("a", params...)
78+
}
79+
80+
// RemoveAIPolicies removes AI policy rules from the current policy.
81+
func (e *Enforcer) RemoveAIPolicies(rules [][]string) (bool, error) {
82+
return e.RemoveNamedAIPolicies("a", rules)
83+
}
84+
85+
// RemoveFilteredAIPolicy removes an AI policy rule from the current policy, field filters can be specified.
86+
func (e *Enforcer) RemoveFilteredAIPolicy(fieldIndex int, fieldValues ...string) (bool, error) {
87+
return e.RemoveFilteredNamedAIPolicy("a", fieldIndex, fieldValues...)
88+
}
89+
90+
// RemoveNamedAIPolicy removes an AI policy rule from the current named policy.
91+
func (e *Enforcer) RemoveNamedAIPolicy(ptype string, params ...string) (bool, error) {
92+
return e.removePolicyInternal("a", ptype, params)
93+
}
94+
95+
// RemoveNamedAIPolicies removes AI policy rules from the current named policy.
96+
func (e *Enforcer) RemoveNamedAIPolicies(ptype string, rules [][]string) (bool, error) {
97+
return e.removePoliciesInternal("a", ptype, rules)
98+
}
99+
100+
// RemoveFilteredNamedAIPolicy removes an AI policy rule from the current named policy, field filters can be specified.
101+
func (e *Enforcer) RemoveFilteredNamedAIPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error) {
102+
return e.removeFilteredPolicyInternal("a", ptype, fieldIndex, fieldValues...)
103+
}
104+
105+
// UpdateAIPolicy updates an AI policy rule from the current policy.
106+
func (e *Enforcer) UpdateAIPolicy(oldPolicy []string, newPolicy []string) (bool, error) {
107+
return e.UpdateNamedAIPolicy("a", oldPolicy, newPolicy)
108+
}
109+
110+
// UpdateAIPolicies updates AI policy rules from the current policy.
111+
func (e *Enforcer) UpdateAIPolicies(oldPolicies [][]string, newPolicies [][]string) (bool, error) {
112+
return e.UpdateNamedAIPolicies("a", oldPolicies, newPolicies)
113+
}
114+
115+
// UpdateNamedAIPolicy updates an AI policy rule from the current named policy.
116+
func (e *Enforcer) UpdateNamedAIPolicy(ptype string, oldPolicy []string, newPolicy []string) (bool, error) {
117+
return e.updatePolicyInternal("a", ptype, oldPolicy, newPolicy)
118+
}
119+
120+
// UpdateNamedAIPolicies updates AI policy rules from the current named policy.
121+
func (e *Enforcer) UpdateNamedAIPolicies(ptype string, oldPolicies [][]string, newPolicies [][]string) (bool, error) {
122+
return e.updatePoliciesInternal("a", ptype, oldPolicies, newPolicies)
123+
}

enforcer.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,33 @@ func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interfac
799799
var effect effector.Effect
800800
var explainIndex int
801801

802+
// Check AI policies first if they exist
803+
aType := "a"
804+
if _, ok := e.model["a"]; ok {
805+
if aPolicies, ok := e.model["a"][aType]; ok && len(aPolicies.Policy) > 0 {
806+
// Evaluate AI policies
807+
for _, aPolicy := range aPolicies.Policy {
808+
if len(aPolicy) > 0 {
809+
// The AI policy description is the first (and typically only) field
810+
policyDescription := aPolicy[0]
811+
allowed, err := e.evaluateAIPolicy(policyDescription, rvals)
812+
if err != nil {
813+
// If AI evaluation fails, log but continue with regular policies
814+
// This allows the system to fall back to traditional policies
815+
continue
816+
}
817+
if allowed {
818+
// AI policy allows the request
819+
return true, nil
820+
}
821+
}
822+
}
823+
// If we have AI policies but none allowed the request, deny
824+
// This implements a deny-by-default behavior for AI policies
825+
return false, nil
826+
}
827+
}
828+
802829
if policyLen := len(e.model["p"][pType].Policy); policyLen != 0 && strings.Contains(expString, pType+"_") { //nolint:nestif // TODO: reduce function complexity
803830
policyEffects = make([]effector.Effect, policyLen)
804831
matcherResults = make([]float64, policyLen)

model/model.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ var sectionNameMap = map[string]string{
4444
"e": "policy_effect",
4545
"m": "matchers",
4646
"c": "constraint_definition",
47+
"a": "ai_policy_definition",
4748
}
4849

4950
// Minimal required sections for a model to be valid.
@@ -78,7 +79,7 @@ func (model Model) AddDef(sec string, key string, value string) bool {
7879
ast.PolicyMap = make(map[string]int)
7980
ast.FieldIndexMap = make(map[string]int)
8081

81-
if sec == "r" || sec == "p" {
82+
if sec == "r" || sec == "p" || sec == "a" {
8283
ast.Tokens = strings.Split(ast.Value, ",")
8384
for i := range ast.Tokens {
8485
ast.Tokens[i] = key + "_" + strings.TrimSpace(ast.Tokens[i])

model/policy.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ func (model Model) ClearPolicy() {
105105
ast.Policy = nil
106106
ast.PolicyMap = map[string]int{}
107107
}
108+
109+
for _, ast := range model["a"] {
110+
ast.Policy = nil
111+
ast.PolicyMap = map[string]int{}
112+
}
108113
}
109114

110115
// GetPolicy gets all rules in a policy.

0 commit comments

Comments
 (0)