Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions internal/container/jump_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,58 @@ func DeleteJumpServerAccount(username string, verbose bool) error {
return nil
}

// EnsureJumpServerAccount creates a host-level user with containarium-shell
// as the login shell, enabling SSH access through sshpiper into the user's
// Incus container. This is called automatically when a container is created.
// It is idempotent — if the account already exists, it just ensures the shell
// and permissions are correct.
func EnsureJumpServerAccount(username string) error {
if !isValidUsername(username) {
return fmt.Errorf("invalid username: %s", username)
}

shellPath := "/usr/local/bin/containarium-shell"

if userExists(username) {
// Ensure shell is containarium-shell
// #nosec G204 -- username validated by isValidUsername above (alphanumeric, dash, underscore only)
_ = exec.Command("usermod", "-s", shellPath, username).Run()
return nil
}

// Create user with containarium-shell
// #nosec G204 -- username validated by isValidUsername above
if err := exec.Command("useradd", "-m", "-s", shellPath,
"-c", fmt.Sprintf("Containarium user - %s", username),
username).Run(); err != nil {
return fmt.Errorf("useradd failed: %w", err)
}

// Unlock account (useradd creates locked accounts, sshd rejects them)
// #nosec G204 -- username validated by isValidUsername above
_ = exec.Command("passwd", "-d", username).Run()

// Set home dir permissions (sshd requires 755 or stricter)
_ = os.Chmod(fmt.Sprintf("/home/%s", username), 0755) // #nosec G302 -- sshd requires home dir to be world-readable

// Create .ssh dir
sshDir := fmt.Sprintf("/home/%s/.ssh", username)
if err := os.MkdirAll(sshDir, 0700); err != nil {
return fmt.Errorf("failed to create .ssh dir: %w", err)
}
// #nosec G204 -- username validated by isValidUsername above
_ = exec.Command("chown", "-R", username+":"+username, sshDir).Run()

// Sudoers entry for incus access (containarium-shell needs it)
sudoersEntry := fmt.Sprintf("%s ALL=(root) NOPASSWD: /usr/bin/incus\n", username)
sudoersPath := fmt.Sprintf("/etc/sudoers.d/containarium-%s", username)
if err := os.WriteFile(sudoersPath, []byte(sudoersEntry), 0440); err != nil { // #nosec G306 -- sudoers requires 0440
return fmt.Errorf("failed to write sudoers: %w", err)
}

return nil
}

// isValidUsername checks if username contains only allowed characters
func isValidUsername(username string) bool {
if len(username) == 0 || len(username) > 32 {
Expand Down
18 changes: 18 additions & 0 deletions internal/sentinel/binaryserver.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package sentinel

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
Expand Down Expand Up @@ -43,6 +46,21 @@ func StartBinaryServer(port int, manager *Manager) (stop func(), err error) {
w.Header().Set("Content-Disposition", "attachment; filename=containarium")
http.ServeFile(w, r, binaryPath)
})
mux.HandleFunc("/containarium/checksum", func(w http.ResponseWriter, r *http.Request) {
f, err := os.Open(binaryPath)
if err != nil {
http.Error(w, "binary not found", http.StatusInternalServerError)
return
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
http.Error(w, "checksum error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/plain")
fmt.Fprint(w, hex.EncodeToString(h.Sum(nil)))
})
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
Expand Down
4 changes: 2 additions & 2 deletions internal/sentinel/keysync.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ func (ks *KeyStore) Apply() error {
}

// Ensure directories exist
if err := os.MkdirAll(sshpiperUsersDir, 0755); err != nil {
if err := os.MkdirAll(sshpiperUsersDir, 0755); err != nil { // #nosec G301 -- sshpiper needs world-readable dirs for authorized_keys lookup
return fmt.Errorf("failed to create sshpiper users dir: %w", err)
}

// Write per-user authorized_keys
for _, r := range routes {
userDir := filepath.Join(sshpiperUsersDir, r.username)
if err := os.MkdirAll(userDir, 0755); err != nil {
if err := os.MkdirAll(userDir, 0755); err != nil { // #nosec G301 -- sshpiper requires world-readable user dirs
log.Printf("[keysync] failed to create dir for %s: %v", r.username, err)
continue
}
Expand Down
201 changes: 201 additions & 0 deletions internal/server/autoupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package server

import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"time"
)

// AutoUpdater periodically checks the sentinel for a newer binary and
// self-updates if a new version is available.
type AutoUpdater struct {
sentinelURL string // e.g. "http://10.130.0.13:8888"
binaryPath string // e.g. "/usr/local/bin/containarium"
interval time.Duration
}

// NewAutoUpdater creates a new auto-updater.
func NewAutoUpdater(sentinelURL, binaryPath string, interval time.Duration) *AutoUpdater {
return &AutoUpdater{
sentinelURL: sentinelURL,
binaryPath: binaryPath,
interval: interval,
}
}

// Run starts the auto-update loop. Blocks until ctx is cancelled.
func (u *AutoUpdater) Run(ctx context.Context) {
log.Printf("[auto-update] started (check interval: %s, sentinel: %s)", u.interval, u.sentinelURL)

// Wait before first check to let the daemon fully start
select {
case <-time.After(2 * time.Minute):
case <-ctx.Done():
return
}

ticker := time.NewTicker(u.interval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
log.Printf("[auto-update] stopped")
return
case <-ticker.C:
if err := u.checkAndUpdate(ctx); err != nil {
log.Printf("[auto-update] check failed: %v", err)
}
}
}
}

func (u *AutoUpdater) checkAndUpdate(ctx context.Context) error {
// 1. Get remote checksum
remoteChecksum, err := u.getRemoteChecksum(ctx)
if err != nil {
return fmt.Errorf("get remote checksum: %w", err)
}

// 2. Get local checksum
localChecksum, err := u.getLocalChecksum()
if err != nil {
return fmt.Errorf("get local checksum: %w", err)
}

// 3. Compare
if remoteChecksum == localChecksum {
return nil // up to date
}

log.Printf("[auto-update] new version detected (local=%s..., remote=%s...)", localChecksum[:12], remoteChecksum[:12])

// 4. Download new binary
tmpPath := u.binaryPath + ".new"
if err := u.downloadBinary(ctx, tmpPath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("download: %w", err)
}

// 5. Verify downloaded binary checksum
dlChecksum, err := checksumFile(tmpPath)
if err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("verify download: %w", err)
}
if dlChecksum != remoteChecksum {
_ = os.Remove(tmpPath)
return fmt.Errorf("checksum mismatch after download (got %s, want %s)", dlChecksum[:12], remoteChecksum[:12])
}

// 6. Make executable
if err := os.Chmod(tmpPath, 0755); err != nil { // #nosec G302 -- executable binary needs 0755
_ = os.Remove(tmpPath)
return fmt.Errorf("chmod: %w", err)
}

// 7. Replace: rename running binary to .old, move new one in place
oldPath := u.binaryPath + ".old"
_ = os.Remove(oldPath)
if err := os.Rename(u.binaryPath, oldPath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("rename old binary: %w", err)
}
if err := os.Rename(tmpPath, u.binaryPath); err != nil {
// Try to restore old binary
_ = os.Rename(oldPath, u.binaryPath)
return fmt.Errorf("rename new binary: %w", err)
}

log.Printf("[auto-update] binary replaced successfully, restarting...")

// 8. Restart services via systemd (async — we'll be killed)
// Restart tunnel first (it also uses the same binary), then the daemon.
go func() {
time.Sleep(1 * time.Second)
// Restart tunnel if it exists (peers only)
if exec.Command("systemctl", "is-active", "containarium-tunnel").Run() == nil { // #nosec G204
log.Printf("[auto-update] restarting containarium-tunnel...")
_ = exec.Command("systemctl", "restart", "containarium-tunnel").Run() // #nosec G204
}
// Restart daemon (this kills us)
log.Printf("[auto-update] restarting containarium...")
if err := exec.Command("systemctl", "restart", "containarium").Run(); err != nil { // #nosec G204
_ = exec.Command("systemctl", "restart", "containarium-daemon").Run() // #nosec G204
}
}()

return nil
}

func (u *AutoUpdater) getRemoteChecksum(ctx context.Context) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", u.sentinelURL+"/containarium/checksum", nil)
if err != nil {
return "", err
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(body), nil
}

func (u *AutoUpdater) getLocalChecksum() (string, error) {
return checksumFile(u.binaryPath)
}

func (u *AutoUpdater) downloadBinary(ctx context.Context, destPath string) error {
req, err := http.NewRequestWithContext(ctx, "GET", u.sentinelURL+"/containarium", nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 5 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status %d", resp.StatusCode)
}

f, err := os.Create(destPath) // #nosec G304 -- destPath is a temp file derived from trusted binaryPath config
if err != nil {
return err
}
defer f.Close()

if _, err := io.Copy(f, resp.Body); err != nil {
return err
}
return f.Close()
}

func checksumFile(path string) (string, error) {
f, err := os.Open(path) // #nosec G304 -- path is the binary path from trusted config
if err != nil {
return "", err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
10 changes: 10 additions & 0 deletions internal/server/container_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,16 @@ func (s *ContainerServer) CreateContainer(ctx context.Context, req *pb.CreateCon
// Emit container created event
s.emitter.EmitContainerCreated(protoContainer)

// Create host-level jump server account so SSH via sshpiper works.
// This is idempotent — skips if the account already exists.
go func() {
if err := container.EnsureJumpServerAccount(req.Username); err != nil {
log.Printf("Warning: failed to create jump server account for %s: %v", req.Username, err)
} else {
log.Printf("Jump server account ensured for %s", req.Username)
}
}()

return &pb.CreateContainerResponse{
Container: protoContainer,
Message: fmt.Sprintf("Container %s created successfully", info.Name),
Expand Down
8 changes: 7 additions & 1 deletion internal/server/dual_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func NewDualServer(config *DualServerConfig) (*DualServer, error) {
// Add DNS override so containers resolve *.baseDomain to Caddy
// internally instead of going through the external IP (hairpin NAT).
dnsOverride := fmt.Sprintf("address=/%s/%s", config.BaseDomain, caddyIP)
if out, err := exec.Command("incus", "network", "set", "incusbr0", "raw.dnsmasq", dnsOverride).CombinedOutput(); err != nil {
if out, err := exec.Command("incus", "network", "set", "incusbr0", "raw.dnsmasq", dnsOverride).CombinedOutput(); err != nil { // #nosec G204 -- dnsOverride is constructed from trusted BaseDomain and CaddyIP config values
log.Printf("Warning: failed to set DNS override for %s: %v (%s)", config.BaseDomain, err, string(out))
} else {
log.Printf("DNS override: *.%s -> %s (internal hairpin)", config.BaseDomain, caddyIP)
Expand Down Expand Up @@ -1163,6 +1163,12 @@ func (ds *DualServer) Start(ctx context.Context) error {
}
}()

// Start auto-updater if sentinel URL is configured
if ds.config.SentinelURL != "" {
updater := NewAutoUpdater(ds.config.SentinelURL, "/usr/local/bin/containarium", 5*time.Minute)
go updater.Run(ctx)
}

// Start gRPC server
grpcAddr := fmt.Sprintf("%s:%d", ds.config.GRPCAddress, ds.config.GRPCPort)
lis, err := net.Listen("tcp", grpcAddr)
Expand Down
8 changes: 4 additions & 4 deletions internal/server/security_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (s *SecurityServer) ListClamavReports(ctx context.Context, req *pb.ListClam
authToken := extractAuthToken(ctx)
peerReports := s.fetchPeerReports(authToken, req)
reports = append(reports, peerReports...)
totalCount += int32(len(peerReports))
totalCount += int32(len(peerReports)) // #nosec G115 -- value bounded by container/scan count
}

return &pb.ListClamavReportsResponse{
Expand Down Expand Up @@ -205,7 +205,7 @@ func (s *SecurityServer) GetClamavSummary(ctx context.Context, req *pb.GetClamav

return &pb.GetClamavSummaryResponse{
Containers: summaries,
TotalContainers: int32(len(summaries)),
TotalContainers: int32(len(summaries)), // #nosec G115 -- value bounded by container count
CleanContainers: cleanCount,
InfectedContainers: infectedCount,
NeverScannedContainers: neverScanned,
Expand Down Expand Up @@ -313,7 +313,7 @@ func (s *SecurityServer) TriggerClamavScan(ctx context.Context, req *pb.TriggerC
peerCount = s.triggerPeerScans(authToken)
}

totalCount := int32(count) + peerCount
totalCount := int32(count) + peerCount // #nosec G115 -- value bounded by container count
return &pb.TriggerClamavScanResponse{
Message: fmt.Sprintf("%d scan jobs queued (%d local, %d on peers)", totalCount, count, peerCount),
ScannedCount: totalCount,
Expand Down Expand Up @@ -371,7 +371,7 @@ func (s *SecurityServer) GetScanStatus(ctx context.Context, req *pb.GetScanStatu
ContainerName: j.ContainerName,
Username: j.Username,
Status: j.Status,
RetryCount: int32(j.RetryCount),
RetryCount: int32(j.RetryCount), // #nosec G115 -- retry count is a small integer
ErrorMessage: j.ErrorMessage,
CreatedAt: j.CreatedAt.Format(time.RFC3339),
BackendId: s.localBackendID,
Expand Down
Loading
Loading