-
Notifications
You must be signed in to change notification settings - Fork 356
Expand file tree
/
Copy pathconditioning_loader.py
More file actions
82 lines (67 loc) · 2.78 KB
/
conditioning_loader.py
File metadata and controls
82 lines (67 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import hashlib
from pathlib import Path
from typing import Any
import folder_paths
import safetensors
import torch
from comfy_api.latest import io
from .nodes_registry import comfy_node
@comfy_node(name="LTXVLoadConditioning")
class LTXVLoadConditioning(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
files = folder_paths.get_filename_list("embeddings")
if not files:
files = [""]
return io.Schema(
node_id="LTXVLoadConditioning",
display_name="🅛🅣🅧 LTXV Load Conditioning",
category="lightricks/LTXV",
inputs=[
io.Combo.Input("file_name", options=sorted(files)),
io.Combo.Input("device", options=["cpu", "gpu"]),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, file_name: str, device: str) -> io.NodeOutput:
file_path = folder_paths.get_full_path("embeddings", file_name)
if not Path(file_path).exists():
raise FileNotFoundError(f"Conditioning file not found: {file_path}")
target_device = "cpu"
if device == "gpu":
target_device = "cuda" if torch.cuda.is_available() else "cpu"
conditioning: list[list[Any]] = []
with safetensors.safe_open(
file_path, framework="pt", device=target_device
) as f:
tensor_keys = [k for k in f.keys() if k.startswith("conditioning_data_")]
for tensor_key in sorted(tensor_keys):
idx = tensor_key.replace("conditioning_data_", "")
tensor = f.get_tensor(tensor_key)
options: dict[str, Any] = {}
mask_key = f"attention_mask_{idx}"
if mask_key in f.keys():
options["attention_mask"] = f.get_tensor(mask_key)
conditioning.append([tensor, options])
if not conditioning:
raise ValueError(f"No conditioning data found in file: {file_name}")
return io.NodeOutput(conditioning)
@classmethod
def fingerprint_inputs(cls, file_name: str, device: str) -> str:
file_path = folder_paths.get_full_path("embeddings", file_name)
with open(file_path, "rb") as f:
return hashlib.sha256(f.read()).hexdigest()
@classmethod
def validate_inputs(cls, file_name: str, device: str) -> bool | str:
if not file_name:
return "No files found. Please save a conditioning first."
try:
file_path = folder_paths.get_full_path("embeddings", file_name)
if not Path(file_path).exists():
return f"File not found: {file_name}"
except Exception:
return f"Invalid file: {file_name}"
return True