Skip to content
76 changes: 71 additions & 5 deletions crates/tokscale-cli/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ struct DeviceCodeResponse {
user_code: String,
#[serde(rename = "verificationUrl")]
verification_url: String,
#[serde(rename = "verificationUrlComplete")]
verification_url_complete: Option<String>,
#[serde(rename = "expiresIn")]
#[allow(dead_code)]
expires_in: u64,
Expand All @@ -49,6 +51,14 @@ struct UserInfo {
avatar_url: Option<String>,
}

impl DeviceCodeResponse {
fn preferred_verification_url(&self) -> &str {
self.verification_url_complete
.as_deref()
.unwrap_or(&self.verification_url)
}
}

fn get_credentials_path() -> Result<PathBuf> {
Ok(home_dir()?.join(".config/tokscale/credentials.json"))
}
Expand Down Expand Up @@ -125,6 +135,11 @@ fn get_device_name() -> String {
format!("CLI on {}", hostname)
}

#[cfg(any(test, target_os = "windows"))]
fn windows_start_arg(url: &str) -> String {
format!("\"{}\"", url.replace('%', "%%").replace('"', "%22"))
}

#[cfg(target_os = "linux")]
fn has_non_empty_env_var(name: &str) -> bool {
std::env::var_os(name).is_some_and(|value| !value.is_empty())
Expand All @@ -144,7 +159,6 @@ fn open_browser(url: &str) -> bool {
if !should_auto_open_browser() {
return false;
}

#[cfg(target_os = "macos")]
{
return std::process::Command::new("open").arg(url).spawn().is_ok();
Expand All @@ -153,7 +167,7 @@ fn open_browser(url: &str) -> bool {
#[cfg(target_os = "windows")]
{
return std::process::Command::new("cmd")
.args(["/C", "start", "", url])
.args(["/C", "start", "", &windows_start_arg(url)])
.spawn()
.is_ok();
}
Expand Down Expand Up @@ -210,13 +224,14 @@ pub async fn login() -> Result<()> {

println!();
println!("{}", " Open this URL in your browser:".white());
let verification_url = device_data.preferred_verification_url();
let url_display = if std::io::stdout().is_terminal() {
format!(
"\x1b]8;;{}\x1b\\{}\x1b]8;;\x1b\\",
device_data.verification_url, device_data.verification_url
verification_url, verification_url
)
} else {
device_data.verification_url.clone()
verification_url.to_string()
};
println!("{}", format!(" {}\n", url_display).cyan());
println!("{}", " Enter this code:".white());
Expand All @@ -225,7 +240,7 @@ pub async fn login() -> Result<()> {
format!(" {}", device_data.user_code).green().bold()
);

if !open_browser(&device_data.verification_url) {
if !open_browser(verification_url) {
println!(
"{}",
" Browser auto-open unavailable in this environment. Continue with the URL above.\n"
Expand Down Expand Up @@ -401,6 +416,57 @@ mod tests {
}
}

#[test]
fn preferred_verification_url_uses_complete_url_when_present() {
let response: DeviceCodeResponse = serde_json::from_value(serde_json::json!({
"deviceCode": "device-code",
"userCode": "ABCD-EFGH",
"verificationUrl": "https://tokscale.ai/device",
"verificationUrlComplete": "https://tokscale.ai/device?code=ABCD-EFGH",
"expiresIn": 900,
"interval": 5
}))
.expect("device response should deserialize");

assert_eq!(
response.preferred_verification_url(),
"https://tokscale.ai/device?code=ABCD-EFGH"
);
}

#[test]
fn preferred_verification_url_falls_back_to_plain_url() {
let response: DeviceCodeResponse = serde_json::from_value(serde_json::json!({
"deviceCode": "device-code",
"userCode": "ABCD-EFGH",
"verificationUrl": "https://tokscale.ai/device",
"expiresIn": 900,
"interval": 5
}))
.expect("device response should deserialize");

assert_eq!(
response.preferred_verification_url(),
"https://tokscale.ai/device"
);
}

#[test]
fn windows_start_arg_quotes_prefilled_urls() {
assert_eq!(
windows_start_arg("https://tokscale.ai/device?code=ABCD-EFGH&foo=bar"),
"\"https://tokscale.ai/device?code=ABCD-EFGH&foo=bar\""
);
}

#[test]
fn windows_start_arg_escapes_percent_for_cmd() {
assert_eq!(
windows_start_arg("https://tokscale.ai/device?code=AB%25CD-EFGH"),
"\"https://tokscale.ai/device?code=AB%%25CD-EFGH\""
);
}

#[test]
#[serial]
fn test_get_api_base_url_default() {
Expand Down
106 changes: 106 additions & 0 deletions packages/frontend/__tests__/api/device.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";

const mockState = vi.hoisted(() => {
const values = vi.fn(async () => undefined);
const insert = vi.fn(() => ({ values }));
const generateDeviceCode = vi.fn(() => "device-code-123");
const generateUserCode = vi.fn(() => "ABCD-EFGH");

return {
values,
insert,
generateDeviceCode,
generateUserCode,
reset() {
values.mockClear();
insert.mockClear();
generateDeviceCode.mockClear();
generateUserCode.mockClear();
},
};
});

vi.mock("@/lib/db", () => ({
db: {
insert: mockState.insert,
},
deviceCodes: {},
}));

vi.mock("@/lib/auth/utils", () => ({
generateDeviceCode: mockState.generateDeviceCode,
generateUserCode: mockState.generateUserCode,
}));

vi.mock("@/lib/auth/device", async () => {
return await import("../../src/lib/auth/device");
});

type ModuleExports = typeof import("../../src/app/api/auth/device/route");

let POST: ModuleExports["POST"];
let originalBaseUrl: string | undefined;

beforeAll(async () => {
({ POST } = await import("../../src/app/api/auth/device/route"));
});

beforeEach(() => {
mockState.reset();
originalBaseUrl = process.env.NEXT_PUBLIC_URL;
process.env.NEXT_PUBLIC_URL = "https://tokscale.ai";
});

afterEach(() => {
if (originalBaseUrl === undefined) {
delete process.env.NEXT_PUBLIC_URL;
} else {
process.env.NEXT_PUBLIC_URL = originalBaseUrl;
}
});

describe("POST /api/auth/device", () => {
it("returns a complete verification URL with the user code prefilled", async () => {
const response = await POST(
new Request("http://localhost/api/auth/device", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ deviceName: "CLI on test-host" }),
})
);

expect(response.status).toBe(200);
expect(mockState.values).toHaveBeenCalledWith(
expect.objectContaining({
deviceCode: "device-code-123",
userCode: "ABCD-EFGH",
deviceName: "CLI on test-host",
})
);

await expect(response.json()).resolves.toMatchObject({
deviceCode: "device-code-123",
userCode: "ABCD-EFGH",
verificationUrl: "https://tokscale.ai/device",
verificationUrlComplete: "https://tokscale.ai/device?code=ABCD-EFGH",
expiresIn: 900,
interval: 5,
});
});

it("uses the default device name when none is provided", async () => {
await POST(
new Request("http://localhost/api/auth/device", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({}),
})
);

expect(mockState.values).toHaveBeenCalledWith(
expect.objectContaining({
deviceName: "Unknown Device",
})
);
});
});
33 changes: 33 additions & 0 deletions packages/frontend/__tests__/lib/device.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { describe, expect, it } from "vitest";
import {
buildDeviceReturnPath,
buildDeviceVerificationUrl,
formatDeviceCode,
normalizeDeviceCode,
} from "../../src/lib/auth/device";

describe("device auth helpers", () => {
it("normalizes device codes to uppercase alphanumeric", () => {
expect(normalizeDeviceCode("ab-cd 1234!!")).toBe("ABCD1234");
});

it("formats normalized device codes with a dash", () => {
expect(formatDeviceCode("abcd1234")).toBe("ABCD-1234");
expect(formatDeviceCode("abcd")).toBe("ABCD");
});

it("truncates device codes to eight characters", () => {
expect(formatDeviceCode("abcd1234wxyz")).toBe("ABCD-1234");
});

it("builds a return path with a formatted code", () => {
expect(buildDeviceReturnPath("abcd1234")).toBe("/device?code=ABCD-1234");
expect(buildDeviceReturnPath("")).toBe("/device");
});

it("builds a complete verification URL with the code prefilled", () => {
expect(buildDeviceVerificationUrl("https://tokscale.ai", "abcd1234")).toBe(
"https://tokscale.ai/device?code=ABCD-1234"
);
});
});
6 changes: 5 additions & 1 deletion packages/frontend/src/app/api/auth/device/route.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { NextResponse } from "next/server";
import { db, deviceCodes } from "@/lib/db";
import { generateDeviceCode, generateUserCode } from "@/lib/auth/utils";
import { buildDeviceVerificationUrl } from "@/lib/auth/device";

const DEVICE_CODE_EXPIRY_SECONDS = 900; // 15 minutes
const POLL_INTERVAL_SECONDS = 5;
Expand All @@ -25,10 +26,13 @@ export async function POST(request: Request) {

const baseUrl = process.env.NEXT_PUBLIC_URL || "http://localhost:3000";

const verificationUrl = `${baseUrl}/device`;

return NextResponse.json({
deviceCode,
userCode,
verificationUrl: `${baseUrl}/device`,
verificationUrl,
verificationUrlComplete: buildDeviceVerificationUrl(baseUrl, userCode),
expiresIn: DEVICE_CODE_EXPIRY_SECONDS,
interval: POLL_INTERVAL_SECONDS,
});
Expand Down
24 changes: 14 additions & 10 deletions packages/frontend/src/app/device/DeviceClient.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"use client";

import { useState, useEffect } from "react";
import { useSearchParams } from "next/navigation";
import styled from "styled-components";
import { buildDeviceReturnPath, formatDeviceCode } from "@/lib/auth/device";

interface User {
id: string;
Expand Down Expand Up @@ -222,9 +224,11 @@ const Username = styled.span`
`;

export default function DeviceClient() {
const searchParams = useSearchParams();
const prefilledCode = formatDeviceCode(searchParams.get("code") || "");
const [user, setUser] = useState<User | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [code, setCode] = useState("");
const [code, setCode] = useState(prefilledCode);
const [status, setStatus] = useState<"idle" | "loading" | "success" | "error">("idle");
const [error, setError] = useState("");

Expand All @@ -241,15 +245,13 @@ export default function DeviceClient() {
}, []);

const handleCodeChange = (e: React.ChangeEvent<HTMLInputElement>) => {
let value = e.target.value.toUpperCase().replace(/[^A-Z0-9]/g, "");

if (value.length > 4) {
value = value.slice(0, 4) + "-" + value.slice(4, 8);
}

setCode(value);
setCode(formatDeviceCode(e.target.value));
};

const signInHref = `/api/auth/github?returnTo=${encodeURIComponent(
buildDeviceReturnPath(code || prefilledCode)
)}`;

const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setStatus("loading");
Expand Down Expand Up @@ -308,9 +310,11 @@ export default function DeviceClient() {
{!user ? (
<SignInContainer>
<SignInText>
Sign in with GitHub to authorize the CLI.
{prefilledCode
? `Sign in with GitHub to authorize the CLI. We'll keep ${prefilledCode} after sign-in.`
: "Sign in with GitHub to authorize the CLI."}
</SignInText>
<SignInButton href="/api/auth/github?returnTo=/device">
<SignInButton href={signInHref}>
<GitHubIcon fill="currentColor" viewBox="0 0 24 24">
<path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z" />
</GitHubIcon>
Expand Down
28 changes: 28 additions & 0 deletions packages/frontend/src/lib/auth/device.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
export function normalizeDeviceCode(value: string): string {
return value.toUpperCase().replace(/[^A-Z0-9]/g, "").slice(0, 8);
}

export function formatDeviceCode(value: string): string {
const normalized = normalizeDeviceCode(value);
if (normalized.length <= 4) {
return normalized;
}

return `${normalized.slice(0, 4)}-${normalized.slice(4)}`;
}

export function buildDeviceVerificationUrl(baseUrl: string, userCode: string): string {
const url = new URL("/device", baseUrl);
url.searchParams.set("code", formatDeviceCode(userCode));
return url.toString();
}

export function buildDeviceReturnPath(userCode: string): string {
const formatted = formatDeviceCode(userCode);
if (!formatted) {
return "/device";
}

const params = new URLSearchParams({ code: formatted });
return `/device?${params.toString()}`;
}
Loading