Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion internal/adapters/handlers/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"go.openfort.xyz/shield/internal/adapters/handlers/rest/requestmdw"
"go.openfort.xyz/shield/internal/adapters/handlers/rest/responsemdw"
"go.openfort.xyz/shield/internal/adapters/handlers/rest/sharehdl"
"go.openfort.xyz/shield/internal/adapters/handlers/rest/usrhdl"
"go.openfort.xyz/shield/internal/applications/projectapp"
"go.openfort.xyz/shield/internal/applications/shareapp"
"go.openfort.xyz/shield/internal/core/ports/factories"
Expand Down Expand Up @@ -61,7 +62,7 @@ func New(cfg *Config,
authenticationFactory: authenticationFactory,
identityFactory: identityFactory,
userService: userService,
projectService: projectService,
projectService: projectService,
}
}

Expand All @@ -70,6 +71,7 @@ func (s *Server) Start(ctx context.Context) error {
healthzHdl := healthzhdl.New(s.healthzApp)
projectHdl := projecthdl.New(s.projectApp)
shareHdl := sharehdl.New(s.shareApp)
userHdl := usrhdl.New(s.userService)
authMdw := authmdw.New(s.authenticationFactory, s.identityFactory, s.userService, s.projectService)
rateLimiterMdw := ratelimitermdw.New(s.config.RPS)

Expand Down Expand Up @@ -99,6 +101,10 @@ func (s *Server) Start(ctx context.Context) error {
p.HandleFunc("/encryption-key", projectHdl.RegisterEncryptionKey).Methods(http.MethodPost)
p.HandleFunc("/enable-2fa", projectHdl.Enable2FA).Methods(http.MethodPost)

usr := r.PathPrefix("/user").Subrouter()
usr.Use(authMdw.AuthenticateAPISecret)
usr.HandleFunc("", userHdl.CreateUser).Methods(http.MethodPost)

u := r.PathPrefix("/shares").Subrouter()
u.Use(authMdw.AuthenticateUser)
u.HandleFunc("", shareHdl.GetShare).Methods(http.MethodGet)
Expand All @@ -118,6 +124,11 @@ func (s *Server) Start(ctx context.Context) error {
e.HandleFunc("/reference/bulk", shareHdl.GetSharesEncryptionForReferences).Methods(http.MethodPost)
e.HandleFunc("/user/bulk", shareHdl.GetSharesEncryptionForUsers).Methods(http.MethodPost)

m := r.PathPrefix("/shares/migration").Subrouter()
m.Use(authMdw.AuthenticateAPISecret)
m.HandleFunc("/export/{reference}", shareHdl.ExportShare).Methods(http.MethodGet)
m.HandleFunc("/import", shareHdl.ImportShare).Methods(http.MethodPost)

a := r.PathPrefix("/admin").Subrouter()
a.Use(authMdw.AuthenticateAPISecret)
a.Use(authMdw.PreRegisterUser)
Expand Down
83 changes: 83 additions & 0 deletions internal/adapters/handlers/rest/sharehdl/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,89 @@ func (h *Handler) GetSharesEncryptionForUsers(w http.ResponseWriter, r *http.Req
_, _ = w.Write(resp)
}

// ExportShare exports a share as-is from the database
// @Summary Export share
// @Description Export a share by reference without decryption
// @Tags Share Migration
// @Produce json
// @Param X-API-Key header string true "API Key"
// @Param X-API-Secret header string true "API Secret"
// @Param reference path string true "Share Reference"
// @Success 200 {object} ExportShareResponse "Successful response"
// @Failure 404 "Description: Not Found"
// @Failure 500 "Description: Internal Server Error"
// @Router /shares/migration/export/{reference} [get]
func (h *Handler) ExportShare(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.InfoContext(ctx, "exporting share")

reference := mux.Vars(r)["reference"]
if reference == "" {
api.RespondWithError(w, api.ErrBadRequestWithMessage("missing reference"))
return
}

shr, err := h.app.ExportShare(ctx, reference)
if err != nil {
api.RespondWithError(w, fromApplicationError(err))
return
}

resp, err := json.Marshal(h.parser.fromDomainExport(shr))
if err != nil {
api.RespondWithError(w, api.ErrInternal)
return
}

w.WriteHeader(http.StatusOK)
_, _ = w.Write(resp)
}

// ImportShare imports a share directly into the database
// @Summary Import share
// @Description Import a previously exported share
// @Tags Share Migration
// @Accept json
// @Param X-API-Key header string true "API Key"
// @Param X-API-Secret header string true "API Secret"
// @Param importShareRequest body ImportShareRequest true "Import Share Request"
// @Success 201 "Description: Share imported successfully"
// @Failure 400 {object} api.Error "Bad Request"
// @Failure 409 {object} api.Error "Conflict"
// @Failure 500 {object} api.Error "Internal Server Error"
// @Router /shares/migration/import [post]
func (h *Handler) ImportShare(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.InfoContext(ctx, "importing share")

body, err := io.ReadAll(r.Body)
if err != nil {
api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to read request body"))
return
}

var req ImportShareRequest
err = json.Unmarshal(body, &req)
if err != nil {
api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to parse request body"))
return
}

if req.Secret == "" {
api.RespondWithError(w, api.ErrBadRequestWithMessage("secret is required"))
return
}

shr := h.parser.toImportDomain(&req)
err = h.app.ImportShare(ctx, shr)
if err != nil {
api.RespondWithError(w, fromApplicationError(err))
return
}

w.WriteHeader(http.StatusCreated)
}

// GetShareStorageMethods list the available share storage methods
// @Summary Get share storage methods
// @Description Get the available share storage methods
Expand Down
72 changes: 72 additions & 0 deletions internal/adapters/handlers/rest/sharehdl/parser.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sharehdl

import (
"github.com/google/uuid"
"go.openfort.xyz/shield/internal/core/domain/share"
)

Expand Down Expand Up @@ -169,6 +170,77 @@ func (p *parser) fromDomain(s *share.Share) *Share {
return shr
}

func (p *parser) fromDomainExport(s *share.Share) *ExportShareResponse {
resp := &ExportShareResponse{
Secret: s.Secret,
Entropy: p.mapDomainEntropy[s.Entropy],
ShareStorageMethodID: p.mapDomainStorageMethod[s.ShareStorageMethodID],
}

if s.Reference != nil {
resp.Reference = *s.Reference
}

if s.EncryptionParameters != nil {
resp.Salt = s.EncryptionParameters.Salt
resp.Iterations = s.EncryptionParameters.Iterations
resp.Length = s.EncryptionParameters.Length
resp.Digest = s.EncryptionParameters.Digest
}

if s.PasskeyReference != nil {
resp.PasskeyReference = &PasskeyReference{
PasskeyId: &s.PasskeyReference.PasskeyID,
}
if s.PasskeyReference.PasskeyEnv != nil {
resp.PasskeyReference.PasskeyEnv = &PasskeyEnv{
Name: s.PasskeyReference.PasskeyEnv.Name,
OS: s.PasskeyReference.PasskeyEnv.OS,
OSVersion: s.PasskeyReference.PasskeyEnv.OSVersion,
Device: s.PasskeyReference.PasskeyEnv.Device,
}
}
}

return resp
}

func (p *parser) toImportDomain(s *ImportShareRequest) *share.Share {
shr := &share.Share{
UserID: s.UserId,
Secret: s.Secret,
Entropy: p.mapEntropyDomain[s.Entropy],
ShareStorageMethodID: p.mapStorageMethodDomain[s.ShareStorageMethodID],
}

if s.Reference != "" {
shr.Reference = &s.Reference
}

if s.Salt != "" || s.Iterations != 0 || s.Length != 0 || s.Digest != "" {
shr.EncryptionParameters = &share.EncryptionParameters{
Salt: s.Salt,
Iterations: s.Iterations,
Length: s.Length,
Digest: s.Digest,
}
}

if s.Entropy == EntropyPasskey && s.PasskeyReference != nil {
shr.PasskeyReference = &share.PasskeyReference{
PasskeyID: uuid.NewString(),
}
shr.PasskeyReference.PasskeyEnv = &share.PasskeyEnv{
Name: s.PasskeyReference.Name,
OS: s.PasskeyReference.OS,
OSVersion: s.PasskeyReference.OSVersion,
Device: s.PasskeyReference.Device,
}
}

return shr
}

func (p *parser) fromDomainShareStorageMethod(s *share.StorageMethod) *ShareStorageMethod {
return &ShareStorageMethod{
ID: s.ID,
Expand Down
25 changes: 25 additions & 0 deletions internal/adapters/handlers/rest/sharehdl/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,31 @@ type PasskeyReference struct {
PasskeyEnv *PasskeyEnv
}

type ExportShareResponse struct {
Secret string `json:"secret"`
Entropy Entropy `json:"entropy"`
Salt string `json:"salt,omitempty"`
Iterations int `json:"iterations,omitempty"`
Length int `json:"length,omitempty"`
Digest string `json:"digest,omitempty"`
Reference string `json:"reference,omitempty"`
ShareStorageMethodID ShareStorageMethodID `json:"storage_method_id"`
PasskeyReference *PasskeyReference `json:"passkey_reference,omitempty"`
}

type ImportShareRequest struct {
UserId string `json:"user_id"`
Secret string `json:"secret"`
Entropy Entropy `json:"entropy"`
Salt string `json:"salt,omitempty"`
Iterations int `json:"iterations,omitempty"`
Length int `json:"length,omitempty"`
Digest string `json:"digest,omitempty"`
Reference string `json:"reference,omitempty"`
ShareStorageMethodID ShareStorageMethodID `json:"storage_method_id"`
PasskeyReference *PasskeyEnv `json:"passkey_reference,omitempty"`
}

type GetShareEncryptionResponse struct {
Entropy Entropy `json:"entropy"`
Salt *string `json:"salt,omitempty"`
Expand Down
79 changes: 79 additions & 0 deletions internal/adapters/handlers/rest/usrhdl/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package usrhdl

import (
"encoding/json"
"io"
"log/slog"
"net/http"

"go.openfort.xyz/shield/internal/adapters/handlers/rest/api"
"go.openfort.xyz/shield/internal/core/ports/services"
"go.openfort.xyz/shield/pkg/contexter"
"go.openfort.xyz/shield/pkg/logger"
)

type Handler struct {
userService services.UserService
logger *slog.Logger
}

func New(userService services.UserService) *Handler {
return &Handler{
userService: userService,
logger: logger.New("user_handler"),
}
}

type CreateUserRequest struct {
ExternalUserID string `json:"external_user_id"`
ProviderID string `json:"provider_id"`
}

type CreateUserResponse struct {
UserID string `json:"user_id"`
}

func (h *Handler) CreateUser(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h.logger.InfoContext(ctx, "creating user")

body, err := io.ReadAll(r.Body)
if err != nil {
api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to read request body"))
return
}

var req CreateUserRequest
if err := json.Unmarshal(body, &req); err != nil {
api.RespondWithError(w, api.ErrBadRequestWithMessage("failed to parse request body"))
return
}

if req.ExternalUserID == "" {
api.RespondWithError(w, api.ErrBadRequestWithMessage("external_user_id is required"))
return
}

if req.ProviderID == "" {
api.RespondWithError(w, api.ErrBadRequestWithMessage("provider_id is required"))
return
}

projectID := contexter.GetProjectID(ctx)

usr, err := h.userService.GetOrCreate(ctx, projectID, req.ExternalUserID, req.ProviderID)
if err != nil {
h.logger.ErrorContext(ctx, "failed to create user", slog.String("error", err.Error()))
api.RespondWithError(w, api.ErrInternal)
return
}

resp, err := json.Marshal(CreateUserResponse{UserID: usr.ID})
if err != nil {
api.RespondWithError(w, api.ErrInternal)
return
}

w.WriteHeader(http.StatusCreated)
_, _ = w.Write(resp)
}
8 changes: 8 additions & 0 deletions internal/adapters/repositories/mocks/sharemockrepo/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func (m *MockShareRepository) GetByReferenceAndKeychain(ctx context.Context, ref
return args.Get(0).(*share.Share), args.Error(1)
}

func (m *MockShareRepository) GetByReferenceAndProjectID(ctx context.Context, reference, projectID string) (*share.Share, error) {
args := m.Mock.Called(ctx, reference, projectID)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*share.Share), args.Error(1)
}

func (m *MockShareRepository) ListByKeychainID(ctx context.Context, keychainID string) ([]*share.Share, error) {
args := m.Mock.Called(ctx, keychainID)
if args.Get(0) == nil {
Expand Down
20 changes: 20 additions & 0 deletions internal/adapters/repositories/sql/sharerepo/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,26 @@ func (r *repository) GetByReferenceAndKeychain(ctx context.Context, reference, k
return r.parser.toDomain(dbShr), nil
}

func (r *repository) GetByReferenceAndProjectID(ctx context.Context, reference, projectID string) (*share.Share, error) {
r.logger.InfoContext(ctx, "getting share by reference and project", slog.String("reference", reference), slog.String("project_id", projectID))

dbShr := &Share{}
err := r.db.Preload("PasskeyReference").
Joins("JOIN shld_users ON shld_shares.user_id = shld_users.id").
Where("shld_shares.reference = ? AND shld_users.project_id = ?", reference, projectID).
Where("shld_shares.deleted_at IS NULL").
First(dbShr).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, domainErrors.ErrShareNotFound
}
r.logger.ErrorContext(ctx, "error getting share by reference and project", logger.Error(err))
return nil, err
}

return r.parser.toDomain(dbShr), nil
}

func (r *repository) GetByUserID(ctx context.Context, userID string) (*share.Share, error) {
r.logger.InfoContext(ctx, "getting share", slog.String("user_id", userID))

Expand Down
Loading
Loading