Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"github.com/dxta-dev/app/internal/otel"
)

type TenantDB struct {
type DB struct {
DB *sql.DB
}

func NewTenantDB(dbUrl string, ctx context.Context) (TenantDB, error) {
func NewDB(dbUrl string, ctx context.Context) (DB, error) {
driverName := otel.GetDriverName()
devToken := os.Getenv("DXTA_DEV_GROUP_TOKEN")

Expand All @@ -23,14 +23,14 @@ func NewTenantDB(dbUrl string, ctx context.Context) (TenantDB, error) {
)

if err != nil {
return TenantDB{}, errors.New("failed to open tenant db connection " + err.Error())
return DB{}, errors.New("failed to open tenant db connection " + err.Error())
}

if err := tenantDB.PingContext(ctx); err != nil {
return TenantDB{}, errors.New("failed to verify tenant db connection " + err.Error())
return DB{}, errors.New("failed to verify tenant db connection " + err.Error())
}

return TenantDB{
return DB{
DB: tenantDB,
}, nil
}
4 changes: 2 additions & 2 deletions internal/internal-api/data/members.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

type CreateMemberResponse struct{ Id int64 }

func (d TenantDB) CreateMember(
func (d DB) CreateMember(
name string,
email *string,
ctx context.Context,
Expand Down Expand Up @@ -36,7 +36,7 @@ func (d TenantDB) CreateMember(

}

func (d TenantDB) AddMemberToTeam(teamId int64, memberId int64, ctx context.Context) error {
func (d DB) AddMemberToTeam(teamId int64, memberId int64, ctx context.Context) error {
query := `
INSERT INTO teams__members
(team_id, member_id)
Expand Down
2 changes: 1 addition & 1 deletion internal/internal-api/data/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"fmt"
)

func (d TenantDB) GetOrganizationIdByAuthId(authId string, ctx context.Context) (int64, error) {
func (d DB) GetOrganizationIdByAuthId(authId string, ctx context.Context) (int64, error) {
query := `
SELECT id
FROM organizations
Expand Down
2 changes: 1 addition & 1 deletion internal/internal-api/data/teams.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

type CreateTeamResponse struct{ Id int64 }

func (d TenantDB) CreateTeam(
func (d DB) CreateTeam(
teamName string,
organizationId int64,
ctx context.Context,
Expand Down
9 changes: 8 additions & 1 deletion internal/internal-api/handler/add_member_to_team.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ type AddMemberToTeamResponse struct {
func AddMemberToTeam(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

apiState := ctx.Value(util.ApiStateCtxKey).(api.State)
authId := ctx.Value(util.AuthIdCtxKey).(string)

apiState, err := api.InternalApiState(authId, ctx)

if err != nil {
util.JSONError(w, util.ErrorParam{Error: "Internal Server Error"}, http.StatusInternalServerError)
return
}

teamId, err := strconv.ParseInt(chi.URLParam(r, "team_id"), 10, 64)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion internal/internal-api/handler/create_member.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ func CreateMember(w http.ResponseWriter, r *http.Request) {
util.JSONError(w, util.ErrorParam{Error: "Bad Request"}, http.StatusBadRequest)
}

apiState := ctx.Value(util.ApiStateCtxKey).(api.State)
authId := ctx.Value(util.AuthIdCtxKey).(string)

apiState, err := api.InternalApiState(authId, ctx)

if err != nil {
util.JSONError(w, util.ErrorParam{Error: "Internal Server Error"}, http.StatusInternalServerError)
return
}

newMemberRes, err := apiState.DB.CreateMember(body.Name, body.Email, ctx)

Expand Down
26 changes: 17 additions & 9 deletions internal/internal-api/handler/create_team.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,26 @@ func CreateTeam(w http.ResponseWriter, r *http.Request) {
return
}

organizationId := ctx.Value(util.OrganizationIdCtxKey).(int64)

if organizationId == 0 || body.TeamName == "" {
fmt.Printf(
"No organization id or team name provided. Organization id: %d Team name: %s",
organizationId,
body.TeamName,
)
if body.TeamName == "" {
fmt.Printf("No team name provided. Team name: %s", body.TeamName)
util.JSONError(w, util.ErrorParam{Error: "Bad Request"}, http.StatusBadRequest)
}

apiState := ctx.Value(util.ApiStateCtxKey).(api.State)
authId := ctx.Value(util.AuthIdCtxKey).(string)

apiState, err := api.InternalApiState(authId, ctx)

if err != nil {
util.JSONError(w, util.ErrorParam{Error: "Internal Server Error"}, http.StatusInternalServerError)
return
}

organizationId, err := apiState.DB.GetOrganizationIdByAuthId(authId, ctx)

if err != nil {
util.JSONError(w, util.ErrorParam{Error: "Bad request"}, http.StatusBadRequest)
return
}

newTeamRes, err := apiState.DB.CreateTeam(body.TeamName, organizationId, ctx)

Expand Down
28 changes: 21 additions & 7 deletions internal/internal-api/internal-api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,29 @@ import (
"context"
"database/sql"
"fmt"
"net/http"
"os"
"sync"

"github.com/dxta-dev/app/internal/internal-api/data"
"github.com/dxta-dev/app/internal/otel"
_ "github.com/libsql/libsql-client-go/libsql"
)

type State struct {
DB data.TenantDB
DB data.DB
}

type TenantDBData struct {
DBUrl string
}

func GetTenantDBUrlByAuthId(ctx context.Context, authId string) (TenantDBData, error) {
var tenantDBURLcache sync.Map

func GetTenantDBUrlByAuthId(ctx context.Context, authID string) (TenantDBData, error) {
if cached, ok := tenantDBURLcache.Load(authID); ok {
return TenantDBData{DBUrl: cached.(string)}, nil
}

driverName := otel.GetDriverName()
tenantOrganizationMapDBUrl := os.Getenv("TENANT_ORG_MAPPING_URL")
devToken := os.Getenv("DXTA_DEV_GROUP_TOKEN")
Expand All @@ -47,20 +53,28 @@ func GetTenantDBUrlByAuthId(ctx context.Context, authId string) (TenantDBData, e

var tenantData TenantDBData

if err = tenantOrganizationMapDB.QueryRowContext(ctx, query, authId).Scan(&tenantData.DBUrl); err != nil {
if err = tenantOrganizationMapDB.QueryRowContext(ctx, query, authID).Scan(&tenantData.DBUrl); err != nil {
fmt.Printf(
"Could not retrieve tenant db url for organization with id: %s. Error: %s",
authId,
authID,
err.Error(),
)
return TenantDBData{}, err
}

tenantDBURLcache.Store(authID, tenantData.DBUrl)

return tenantData, nil
}

func InternalApiState(ctx context.Context, dbUrl string, r *http.Request) (State, error) {
tenantDB, err := data.NewTenantDB(dbUrl, ctx)
func InternalApiState(authId string, ctx context.Context) (State, error) {
tenantData, err := GetTenantDBUrlByAuthId(ctx, authId)

if err != nil {
return State{}, err
}

tenantDB, err := data.NewDB(tenantData.DBUrl, ctx)

if err != nil {
return State{}, err
Expand Down
2 changes: 1 addition & 1 deletion internal/onboarding/tenant.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func GetCachedTenantDB(store *sync.Map, dbUrl string, ctx context.Context) (*sql
db, ok := store.Load(dbUrl)

if !ok {
tenantDB, err := internal_api_data.NewTenantDB(dbUrl, ctx)
tenantDB, err := internal_api_data.NewDB(dbUrl, ctx)

if err != nil {
return nil, errors.New("failed to create tenant db connection: " + err.Error())
Expand Down
38 changes: 8 additions & 30 deletions internal/util/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"os"

api "github.com/dxta-dev/app/internal/internal-api"
"github.com/go-chi/jwtauth/v5"
)

Expand Down Expand Up @@ -92,8 +91,7 @@ type contextKey struct {
}

var (
OrganizationIdCtxKey = contextKey{"organizationId"}
ApiStateCtxKey = contextKey{"apiState"}
AuthIdCtxKey = contextKey{"authId"}
)

func Authenticator() func(http.Handler) http.Handler {
Expand All @@ -102,7 +100,7 @@ func Authenticator() func(http.Handler) http.Handler {
token, claims, err := jwtauth.FromContext(r.Context())

if err != nil {
fmt.Println("Error extracting token and claims from context")
fmt.Printf("Error extracting token and claims from context. Error: %s", err.Error())
JSONError(w, ErrorParam{Error: "Internal Server Error"}, http.StatusInternalServerError)
return
}
Expand All @@ -113,39 +111,19 @@ func Authenticator() func(http.Handler) http.Handler {
return
}

authId := claims["organizationId"].(string)
authId := claims["organizationId"]

if authId == "" {
fmt.Println("No organization id found in JWT payload")
if authId == nil {
fmt.Println("No auth id found in JWT payload")
JSONError(w, ErrorParam{Error: "Bad request"}, http.StatusBadRequest)
return
}

ctx := r.Context()

tenantData, err := api.GetTenantDBUrlByAuthId(ctx, authId)

if err != nil {
JSONError(w, ErrorParam{Error: "Internal Server Error"}, http.StatusInternalServerError)
return
}

apiState, err := api.InternalApiState(ctx, tenantData.DBUrl, r)
authId = authId.(string)

if err != nil {
JSONError(w, ErrorParam{Error: "Internal Server Error"}, http.StatusInternalServerError)
return
}

organizationId, err := apiState.DB.GetOrganizationIdByAuthId(authId, ctx)

if err != nil {
JSONError(w, ErrorParam{Error: "Bad request"}, http.StatusBadRequest)
return
}
ctx := r.Context()

ctx = context.WithValue(ctx, OrganizationIdCtxKey, organizationId)
ctx = context.WithValue(ctx, ApiStateCtxKey, apiState)
ctx = context.WithValue(ctx, AuthIdCtxKey, authId)

next.ServeHTTP(w, r.WithContext(ctx))
}
Expand Down
Loading