Skip to content

Commit 2145a69

Browse files
authored
Merge pull request #7 from nodece/master
refactor: remove createDatabase, add return errors
2 parents f1b4b38 + 525124c commit 2145a69

File tree

4 files changed

+108
-193
lines changed

4 files changed

+108
-193
lines changed

.travis.yml

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
language: go
2-
3-
sudo: false
4-
5-
go:
6-
- tip
7-
8-
before_install:
9-
- go get github.com/mattn/goveralls
10-
11-
script:
12-
- $HOME/gopath/bin/goveralls -service=travis-ci
13-
14-
services:
15-
- mysql
16-
- postgresql
1+
language: go
2+
3+
sudo: false
4+
5+
go:
6+
- tip
7+
8+
before_install:
9+
- mysql -e 'CREATE DATABASE casbin_test;'
10+
- psql -c 'create database casbin_test;' -U postgres
11+
- go get github.com/mattn/goveralls
12+
13+
script:
14+
- $HOME/gopath/bin/goveralls -service=travis-ci
15+
16+
services:
17+
- mysql
18+
- postgresql

README.md

Lines changed: 16 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,33 @@ Based on [Beego ORM Support](https://beego.me/docs/mvc/model/overview.md), The c
1010
- Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
1111

1212
## Installation
13+
```bash
14+
go get github.com/casbin/beego-orm-adapter/v2
15+
```
1316

14-
go get github.com/casbin/beego-orm-adapter
15-
16-
## Simple MySQL Example
17+
## Simple Example
1718

1819
```go
1920
package main
2021

2122
import (
22-
"github.com/casbin/beego-orm-adapter"
23-
"github.com/casbin/casbin"
23+
"log"
24+
beegoormadapter "github.com/casbin/beego-orm-adapter"
25+
"github.com/casbin/casbin/v2"
2426
_ "github.com/go-sql-driver/mysql"
2527
)
2628

2729
func main() {
2830
// Initialize a Beego ORM adapter and use it in a Casbin enforcer:
29-
// The adapter will use the MySQL database named "casbin".
30-
// If it doesn't exist, the adapter will create it automatically.
31-
a := beegoormadapter.NewAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/") // Your driver and data source.
32-
33-
// Or you can use an existing DB "abc" like this:
34-
// The adapter will use the table named "casbin_rule".
35-
// If it doesn't exist, the adapter will create it automatically.
36-
// a := beegoormadapter.NewAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/abc", true)
37-
38-
e := casbin.NewEnforcer("examples/rbac_model.conf", a)
39-
40-
// Load the policy from DB.
41-
e.LoadPolicy()
42-
31+
a, err := beegoormadapter.NewAdapter("default", "mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/dbname") // Your driver and data source.
32+
if err != nil {
33+
log.Fatalln(err)
34+
}
35+
36+
e, err := casbin.NewEnforcer("examples/rbac_model.conf", a)
37+
if err != nil {
38+
log.Fatalln(err)
39+
}
4340
// Check the permission.
4441
e.Enforce("alice", "data1", "read")
4542

@@ -52,45 +49,6 @@ func main() {
5249
}
5350
```
5451

55-
## Simple Postgres Example
56-
57-
```go
58-
package main
59-
60-
import (
61-
"github.com/casbin/beego-orm-adapter"
62-
"github.com/casbin/casbin"
63-
_ "github.com/lib/pq"
64-
)
65-
66-
func main() {
67-
// Initialize a Beego ORM adapter and use it in a Casbin enforcer:
68-
// The adapter will use the Postgres database named "casbin".
69-
// If it doesn't exist, the adapter will create it automatically.
70-
a := beegoormadapter.NewAdapter("postgres", "user=postgres_username password=postgres_password host=127.0.0.1 port=5432 sslmode=disable") // Your driver and data source.
71-
72-
// Or you can use an existing DB "abc" like this:
73-
// The adapter will use the table named "casbin_rule".
74-
// If it doesn't exist, the adapter will create it automatically.
75-
// a := beegoormadapter.NewAdapter("postgres", "dbname=abc user=postgres_username password=postgres_password host=127.0.0.1 port=5432 sslmode=disable", true)
76-
77-
e := casbin.NewEnforcer("../examples/rbac_model.conf", a)
78-
79-
// Load the policy from DB.
80-
e.LoadPolicy()
81-
82-
// Check the permission.
83-
e.Enforce("alice", "data1", "read")
84-
85-
// Modify the policy.
86-
// e.AddPolicy(...)
87-
// e.RemovePolicy(...)
88-
89-
// Save the policy back to DB.
90-
e.SavePolicy()
91-
}
92-
```
93-
9452
## Getting Help
9553

9654
- [Casbin](https://github.com/casbin/casbin)

adapter.go

Lines changed: 52 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,10 @@
1515
package beegoormadapter
1616

1717
import (
18-
"errors"
19-
"runtime"
20-
"strings"
21-
2218
"github.com/astaxie/beego/orm"
2319
"github.com/casbin/casbin/v2/model"
2420
"github.com/casbin/casbin/v2/persist"
25-
"github.com/lib/pq"
21+
"runtime"
2622
)
2723

2824
type CasbinRule struct {
@@ -38,133 +34,80 @@ type CasbinRule struct {
3834

3935
func init() {
4036
orm.RegisterModel(new(CasbinRule))
41-
42-
orm.RegisterDriver("mysql", orm.DRMySQL)
4337
}
4438

4539
// Adapter represents the Xorm adapter for policy storage.
4640
type Adapter struct {
47-
driverName string
48-
dataSourceName string
49-
dbSpecified bool
50-
o orm.Ormer
41+
driverName string
42+
dataSourceName string
43+
dataSourceAlias string
44+
dbSpecified bool
45+
o orm.Ormer
5146
}
5247

5348
// finalizer is the destructor for Adapter.
5449
func finalizer(a *Adapter) {
5550
}
5651

5752
// NewAdapter is the constructor for Adapter.
58-
// dbSpecified is an optional bool parameter. The default value is false.
59-
// It's up to whether you have specified an existing DB in dataSourceName.
60-
// If dbSpecified == true, you need to make sure the DB in dataSourceName exists.
61-
// If dbSpecified == false, the adapter will automatically create a DB named "casbin".
62-
func NewAdapter(driverName string, dataSourceName string, dbSpecified ...bool) *Adapter {
53+
// dataSourceAlias: Database alias. ORM will use it to switch database.
54+
// driverName: database driverName.
55+
// dataSourceName: connection string
56+
func NewAdapter(dataSourceAlias, driverName, dataSourceName string) (*Adapter, error) {
6357
a := &Adapter{}
6458
a.driverName = driverName
6559
a.dataSourceName = dataSourceName
60+
a.dataSourceAlias = dataSourceAlias
6661

67-
if len(dbSpecified) == 0 {
68-
a.dbSpecified = false
69-
} else if len(dbSpecified) == 1 {
70-
a.dbSpecified = dbSpecified[0]
71-
} else {
72-
panic(errors.New("invalid parameter: dbSpecified"))
73-
}
62+
err := a.open()
7463

75-
// Open the DB, create it if not existed.
76-
a.open()
64+
if err != nil {
65+
return nil, err
66+
}
7767

7868
// Call the destructor when the object is released.
7969
runtime.SetFinalizer(a, finalizer)
8070

81-
return a
71+
return a, nil
8272
}
8373

8474
func (a *Adapter) registerDataBase(aliasName, driverName, dataSource string, params ...int) error {
8575
err := orm.RegisterDataBase(aliasName, driverName, dataSource, params...)
86-
if err != nil && strings.HasSuffix(err.Error(), "already registered, cannot reuse") {
87-
return nil
88-
}
8976
return err
9077
}
9178

92-
func (a *Adapter) createDatabase() error {
79+
func (a *Adapter) open() error {
9380
var err error
94-
var o orm.Ormer
95-
if a.driverName == "postgres" {
96-
err = a.registerDataBase("create_casbin", a.driverName, a.dataSourceName + " dbname=postgres")
97-
} else {
98-
err = a.registerDataBase("create_casbin", a.driverName, a.dataSourceName)
99-
}
81+
82+
err = a.registerDataBase(a.dataSourceAlias, a.driverName, a.dataSourceName)
10083
if err != nil {
10184
return err
10285
}
103-
o = orm.NewOrm()
104-
105-
if a.driverName == "postgres" {
106-
if _, err = o.Raw("CREATE DATABASE casbin").Exec(); err != nil {
107-
// 42P04 is duplicate_database
108-
if err.(*pq.Error).Code == "42P04" {
109-
return nil
110-
}
111-
}
112-
} else {
113-
_, err = o.Raw("CREATE DATABASE IF NOT EXISTS casbin").Exec()
114-
}
115-
return err
116-
}
117-
118-
func (a *Adapter) open() {
119-
var err error
12086

121-
err = a.registerDataBase("default", a.driverName, a.dataSourceName)
87+
a.o = orm.NewOrm()
88+
err = a.o.Using(a.dataSourceAlias)
12289
if err != nil {
123-
panic(err)
90+
return err
12491
}
12592

126-
if a.dbSpecified {
127-
err = a.registerDataBase("casbin", a.driverName, a.dataSourceName)
128-
if err != nil {
129-
panic(err)
130-
}
131-
} else {
132-
if err = a.createDatabase(); err != nil {
133-
panic(err)
134-
}
135-
136-
if a.driverName == "postgres" {
137-
err = a.registerDataBase("casbin", a.driverName, a.dataSourceName + " dbname=casbin")
138-
} else {
139-
err = a.registerDataBase("casbin", a.driverName, a.dataSourceName + "casbin")
140-
}
141-
if err != nil {
142-
panic(err)
143-
}
93+
err = a.createTable()
94+
if err != nil {
95+
return err
14496
}
14597

146-
a.o = orm.NewOrm()
147-
a.o.Using("casbin")
148-
149-
a.createTable()
98+
return nil
15099
}
151100

152101
func (a *Adapter) close() {
153102
a.o = nil
154103
}
155104

156-
func (a *Adapter) createTable() {
157-
err := orm.RunSyncdb("casbin", false, true)
158-
if err != nil {
159-
panic(err)
160-
}
105+
func (a *Adapter) createTable() error {
106+
return orm.RunSyncdb(a.dataSourceAlias, false, true)
161107
}
162108

163-
func (a *Adapter) dropTable() {
164-
err := orm.RunSyncdb("casbin", true, true)
165-
if err != nil {
166-
panic(err)
167-
}
109+
func (a *Adapter) dropTable() error {
110+
return orm.RunSyncdb(a.dataSourceAlias, true, true)
168111
}
169112

170113
func loadPolicyLine(line CasbinRule, model model.Model) {
@@ -234,8 +177,15 @@ func savePolicyLine(ptype string, rule []string) CasbinRule {
234177

235178
// SavePolicy saves policy to database.
236179
func (a *Adapter) SavePolicy(model model.Model) error {
237-
a.dropTable()
238-
a.createTable()
180+
err := a.dropTable()
181+
if err != nil {
182+
return err
183+
}
184+
185+
err = a.createTable()
186+
if err != nil {
187+
return err
188+
}
239189

240190
var lines []CasbinRule
241191

@@ -253,7 +203,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
253203
}
254204
}
255205

256-
_, err := a.o.InsertMulti(len(lines), lines)
206+
_, err = a.o.InsertMulti(len(lines), lines)
257207
return err
258208
}
259209

@@ -278,28 +228,28 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
278228
line.PType = ptype
279229
filter := []string{}
280230
filter = append(filter, "p_type")
281-
if fieldIndex <= 0 && 0 < fieldIndex + len(fieldValues) {
282-
line.V0 = fieldValues[0 - fieldIndex]
231+
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
232+
line.V0 = fieldValues[0-fieldIndex]
283233
filter = append(filter, "v0")
284234
}
285-
if fieldIndex <= 1 && 1 < fieldIndex + len(fieldValues) {
286-
line.V1 = fieldValues[1 - fieldIndex]
235+
if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
236+
line.V1 = fieldValues[1-fieldIndex]
287237
filter = append(filter, "v1")
288238
}
289-
if fieldIndex <= 2 && 2 < fieldIndex + len(fieldValues) {
290-
line.V2 = fieldValues[2 - fieldIndex]
239+
if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
240+
line.V2 = fieldValues[2-fieldIndex]
291241
filter = append(filter, "v2")
292242
}
293-
if fieldIndex <= 3 && 3 < fieldIndex + len(fieldValues) {
294-
line.V3 = fieldValues[3 - fieldIndex]
243+
if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
244+
line.V3 = fieldValues[3-fieldIndex]
295245
filter = append(filter, "v3")
296246
}
297-
if fieldIndex <= 4 && 4 < fieldIndex + len(fieldValues) {
298-
line.V4 = fieldValues[4 - fieldIndex]
247+
if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
248+
line.V4 = fieldValues[4-fieldIndex]
299249
filter = append(filter, "v4")
300250
}
301-
if fieldIndex <= 5 && 5 < fieldIndex + len(fieldValues) {
302-
line.V5 = fieldValues[5 - fieldIndex]
251+
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
252+
line.V5 = fieldValues[5-fieldIndex]
303253
filter = append(filter, "v5")
304254
}
305255

0 commit comments

Comments
 (0)