|
7 | 7 | """ |
8 | 8 |
|
9 | 9 | # Python Standard Library Imports |
10 | | -from datetime import datetime |
11 | | -from typing import Optional, Dict, Any |
| 10 | +from datetime import datetime, timedelta |
| 11 | +from typing import Optional, Dict, Any, Tuple |
12 | 12 | import uuid |
| 13 | +import secrets |
13 | 14 |
|
14 | 15 | # Third-Party Imports |
15 | 16 | from fastapi import Response, Request |
16 | 17 | from fastapi.exceptions import HTTPException |
| 18 | +import bcrypt |
17 | 19 |
|
18 | 20 | # Helper Imports |
19 | 21 | from helpers.fastapi.ip import get_client_ip |
20 | 22 |
|
21 | 23 | # Database Imports |
22 | | -from db import accounts |
| 24 | +from db import accounts, email_verifications |
23 | 25 |
|
24 | 26 |
|
25 | 27 | class AccountHelper: |
@@ -144,7 +146,7 @@ async def create_session(account_id: str, request: Request) -> Dict[str, Any]: |
144 | 146 | "ip": get_client_ip(request), |
145 | 147 | "created_at": datetime.now().timestamp(), |
146 | 148 | "last_used": datetime.now().timestamp(), |
147 | | - "expires_at": datetime.now().timestamp() + 5, # 86400 * 30, # 30 days |
| 149 | + "expires_at": datetime.now().timestamp() + 86400 * 30, # 30 days |
148 | 150 | } |
149 | 151 |
|
150 | 152 | # Create an index to delete expired sessions if it doesn't exist |
@@ -177,3 +179,213 @@ async def get_public_account_data(account: Dict[Any, Any]) -> Dict[Any, Any]: |
177 | 179 | session["current"] = False |
178 | 180 |
|
179 | 181 | return public_account |
| 182 | + |
| 183 | + @staticmethod |
| 184 | + async def create_verification_token(account_id: str, email: str) -> str: |
| 185 | + """Create an email verification token""" |
| 186 | + token = secrets.token_urlsafe(32) |
| 187 | + expires_at = datetime.now().timestamp() + 86400 # 24 hours |
| 188 | + |
| 189 | + # Remove any existing verification for this email |
| 190 | + email_verifications.delete_many({"email": email}) |
| 191 | + |
| 192 | + # Create new verification |
| 193 | + email_verifications.insert_one({ |
| 194 | + "_id": token, |
| 195 | + "account_id": account_id, |
| 196 | + "email": email, |
| 197 | + "created_at": datetime.now().timestamp(), |
| 198 | + "expires_at": expires_at, |
| 199 | + }) |
| 200 | + |
| 201 | + return token |
| 202 | + |
| 203 | + @staticmethod |
| 204 | + async def verify_email_token(token: str) -> Tuple[bool, Optional[str], Optional[str]]: |
| 205 | + """ |
| 206 | + Verify an email token and mark the email as verified. |
| 207 | +
|
| 208 | + Returns: |
| 209 | + Tuple of (success, account_id, error_message) |
| 210 | + """ |
| 211 | + verification = email_verifications.find_one({"_id": token}) |
| 212 | + |
| 213 | + if not verification: |
| 214 | + return False, None, "Invalid or expired verification token" |
| 215 | + |
| 216 | + if verification["expires_at"] < datetime.now().timestamp(): |
| 217 | + email_verifications.delete_one({"_id": token}) |
| 218 | + return False, None, "Verification token has expired" |
| 219 | + |
| 220 | + account_id = verification["account_id"] |
| 221 | + email = verification["email"] |
| 222 | + |
| 223 | + # Mark email as verified in the account |
| 224 | + result = accounts.update_one( |
| 225 | + {"_id": account_id, "emails.address": email}, |
| 226 | + {"$set": {"emails.$.verified": True}} |
| 227 | + ) |
| 228 | + |
| 229 | + if result.modified_count == 0: |
| 230 | + return False, None, "Email not found in account" |
| 231 | + |
| 232 | + # Delete the verification token |
| 233 | + email_verifications.delete_one({"_id": token}) |
| 234 | + |
| 235 | + return True, account_id, None |
| 236 | + |
| 237 | + @staticmethod |
| 238 | + async def set_primary_email(account_id: str, email: str) -> Tuple[bool, Optional[str]]: |
| 239 | + """ |
| 240 | + Set an email as the primary email for an account. |
| 241 | + The email must be verified. |
| 242 | +
|
| 243 | + Returns: |
| 244 | + Tuple of (success, error_message) |
| 245 | + """ |
| 246 | + account = accounts.find_one({"_id": account_id}) |
| 247 | + if not account: |
| 248 | + return False, "Account not found" |
| 249 | + |
| 250 | + # Find the email and check if it's verified |
| 251 | + email_found = False |
| 252 | + is_verified = False |
| 253 | + for e in account["emails"]: |
| 254 | + if e["address"] == email: |
| 255 | + email_found = True |
| 256 | + is_verified = e.get("verified", False) |
| 257 | + break |
| 258 | + |
| 259 | + if not email_found: |
| 260 | + return False, "Email not found in account" |
| 261 | + |
| 262 | + if not is_verified: |
| 263 | + return False, "Email must be verified before setting as primary" |
| 264 | + |
| 265 | + # Remove primary flag from all emails |
| 266 | + accounts.update_one( |
| 267 | + {"_id": account_id}, |
| 268 | + {"$set": {"emails.$[].primary": False}} |
| 269 | + ) |
| 270 | + |
| 271 | + # Set primary flag on the specified email |
| 272 | + accounts.update_one( |
| 273 | + {"_id": account_id, "emails.address": email}, |
| 274 | + {"$set": {"emails.$.primary": True}} |
| 275 | + ) |
| 276 | + |
| 277 | + return True, None |
| 278 | + |
| 279 | + @staticmethod |
| 280 | + async def change_password( |
| 281 | + account_id: str, current_password: str, new_password: str |
| 282 | + ) -> Tuple[bool, Optional[str]]: |
| 283 | + """ |
| 284 | + Change an account's password. |
| 285 | +
|
| 286 | + Returns: |
| 287 | + Tuple of (success, error_message) |
| 288 | + """ |
| 289 | + account = accounts.find_one({"_id": account_id}) |
| 290 | + if not account: |
| 291 | + return False, "Account not found" |
| 292 | + |
| 293 | + # Verify current password |
| 294 | + if not bcrypt.checkpw( |
| 295 | + current_password.encode("utf-8"), |
| 296 | + account["password"].encode("utf-8") |
| 297 | + ): |
| 298 | + return False, "Current password is incorrect" |
| 299 | + |
| 300 | + # Hash and set new password |
| 301 | + hashed_password = bcrypt.hashpw( |
| 302 | + new_password.encode("utf-8"), |
| 303 | + bcrypt.gensalt() |
| 304 | + ).decode("utf-8") |
| 305 | + |
| 306 | + accounts.update_one( |
| 307 | + {"_id": account_id}, |
| 308 | + {"$set": {"password": hashed_password}} |
| 309 | + ) |
| 310 | + |
| 311 | + return True, None |
| 312 | + |
| 313 | + @staticmethod |
| 314 | + async def has_verified_email(account_id: str) -> bool: |
| 315 | + """Check if an account has at least one verified email""" |
| 316 | + account = accounts.find_one({"_id": account_id}) |
| 317 | + if not account: |
| 318 | + return False |
| 319 | + |
| 320 | + for email in account.get("emails", []): |
| 321 | + if email.get("verified", False): |
| 322 | + return True |
| 323 | + |
| 324 | + return False |
| 325 | + |
| 326 | + @staticmethod |
| 327 | + async def find_account_by_oauth(provider: str, provider_id: str) -> Optional[Dict[Any, Any]]: |
| 328 | + """Find an account by OAuth provider and ID""" |
| 329 | + return accounts.find_one({f"oauth_connections.{provider}.id": provider_id}) |
| 330 | + |
| 331 | + @staticmethod |
| 332 | + async def link_oauth_account( |
| 333 | + account_id: str, provider: str, provider_data: Dict[str, Any] |
| 334 | + ) -> None: |
| 335 | + """Link an OAuth account to an existing account""" |
| 336 | + accounts.update_one( |
| 337 | + {"_id": account_id}, |
| 338 | + {"$set": {f"oauth_connections.{provider}": { |
| 339 | + **provider_data, |
| 340 | + "linked_at": datetime.now().timestamp() |
| 341 | + }}} |
| 342 | + ) |
| 343 | + |
| 344 | + @staticmethod |
| 345 | + async def unlink_oauth_account(account_id: str, provider: str) -> Tuple[bool, Optional[str]]: |
| 346 | + """Unlink an OAuth account from an existing account""" |
| 347 | + account = accounts.find_one({"_id": account_id}) |
| 348 | + if not account: |
| 349 | + return False, "Account not found" |
| 350 | + |
| 351 | + # Check if account has a password (required to unlink OAuth) |
| 352 | + if not account.get("password"): |
| 353 | + return False, "Cannot unlink OAuth - account has no password set" |
| 354 | + |
| 355 | + # Check if the OAuth connection exists |
| 356 | + oauth_connections = account.get("oauth_connections", {}) |
| 357 | + if provider not in oauth_connections: |
| 358 | + return False, f"No {provider} connection found" |
| 359 | + |
| 360 | + accounts.update_one( |
| 361 | + {"_id": account_id}, |
| 362 | + {"$unset": {f"oauth_connections.{provider}": ""}} |
| 363 | + ) |
| 364 | + |
| 365 | + return True, None |
| 366 | + |
| 367 | + @staticmethod |
| 368 | + async def create_oauth_account( |
| 369 | + name: str, |
| 370 | + email: str, |
| 371 | + provider: str, |
| 372 | + provider_data: Dict[str, Any], |
| 373 | + org: Optional[str] = None |
| 374 | + ) -> Dict[Any, Any]: |
| 375 | + """Create a new account from OAuth login""" |
| 376 | + account = { |
| 377 | + "_id": str(uuid.uuid4()), |
| 378 | + "name": name, |
| 379 | + "emails": [{"address": email, "verified": True, "primary": True}], |
| 380 | + "password": None, # OAuth accounts don't have passwords initially |
| 381 | + "org": org, |
| 382 | + "sessions": [], |
| 383 | + "oauth_connections": { |
| 384 | + provider: { |
| 385 | + **provider_data, |
| 386 | + "linked_at": datetime.now().timestamp() |
| 387 | + } |
| 388 | + } |
| 389 | + } |
| 390 | + accounts.insert_one(account) |
| 391 | + return account |
0 commit comments