-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathclean_data.py
More file actions
167 lines (134 loc) · 6.45 KB
/
clean_data.py
File metadata and controls
167 lines (134 loc) · 6.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
import os
from transformers import CLIPModel, CLIPProcessor,AutoProcessor, BlipForImageTextRetrieval
from PIL import Image
import argparse
from tqdm import tqdm
def attack_success_clip(model, processor, image_path, prompt, target, device):
    """Judge attack success with CLIP zero-shot classification.

    The image is classified against the two candidate captions
    ``[prompt, target]``; the attack counts as successful when the image
    matches the backdoor ``target`` better than the benign ``prompt``.

    Args:
        model: CLIPModel exposing ``logits_per_image``.
        processor: matching CLIPProcessor.
        image_path: path of the generated image to classify.
        prompt: benign caption the image was generated from.
        target: backdoor target caption.
        device: torch device the model lives on.

    Returns:
        int: 1 if the target caption wins (attack succeeded), else 0.
    """
    # NOTE(review): this matches the literal two-character sequence
    # backslash+'u200b', not the zero-width-space codepoint — presumably the
    # prompt files store the escape verbatim; confirm against the data files.
    if '\\u200b' in prompt:
        # BadT2I prompts embed a trigger; build the target caption by
        # substituting the target word into the original prompt.
        target = prompt.replace('\\u200b ', target)
    texts = [prompt, target]  # class 0 = prompt, class 1 = target
    # `with` closes the image file handle instead of leaking it.
    with Image.open(image_path) as image:
        inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text matching scores
    probs = logits_per_image.softmax(dim=1)
    # Attack succeeds when the highest-probability caption is the target.
    return int(probs.argmax(dim=1).item() != 0)
def attack_success_blip(model, processor, image_path, prompt, target, device):
    """Judge attack success with BLIP image-text retrieval matching.

    Both candidate captions ``[prompt, target]`` are scored against the
    image; the attack counts as successful when the backdoor ``target``
    caption obtains the higher match probability.

    Args:
        model: BlipForImageTextRetrieval model.
        processor: matching AutoProcessor.
        image_path: path of the generated image to classify.
        prompt: benign caption the image was generated from.
        target: backdoor target caption.
        device: torch device the model lives on.

    Returns:
        int: 1 if the target caption matches best (attack succeeded), else 0.
    """
    texts = [prompt, target]  # row 0 = prompt, row 1 = target
    # `with` closes the image file handle instead of leaking it.
    with Image.open(image_path) as image:
        inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)[0]
    probs = torch.nn.functional.softmax(outputs, dim=1)
    # probs.T[1] selects the "match" column (one entry per caption) —
    # presumably outputs is (num_texts, 2) ITM logits; TODO confirm.
    # The argmax over captions says which text matches the image best.
    confidence, predicted_class = probs.T[1].max(dim=0)
    # Attack succeeds when the best-matching caption is the target (index 1).
    return int(predicted_class.item() != 0)
def create_parser(argv=None):
    """Parse command-line options for the data-cleaning script.

    Args:
        argv: optional list of argument strings. Defaults to ``None`` so
            argparse falls back to ``sys.argv[1:]`` — existing callers that
            pass nothing are unaffected.

    Returns:
        argparse.Namespace with a single ``mode`` attribute
        ('train' or 'test').
    """
    parser = argparse.ArgumentParser(description='Clean the data')
    parser.add_argument('-m',
                        '--mode',
                        default='train',
                        type=str,
                        # Reject bad modes here: downstream code derives
                        # stop_idx from mode and would otherwise hit a
                        # NameError after only printing a warning.
                        choices=('train', 'test'),
                        help="dataset split to clean: 'train' or 'test'")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Filter generated backdoor images: keep only samples where BOTH the CLIP
    # and BLIP judges say the backdoor attack succeeded, and write the list of
    # corresponding metric-file paths to attack_success_path_<mode>.txt.
    args = create_parser()
    device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"

    # Load the two image-text retrieval judges from local checkpoints.
    clip_model = CLIPModel.from_pretrained("./Models/openai-clip-L-14").to(device)
    processor = CLIPProcessor.from_pretrained("./Models/openai-clip-L-14")
    blip_model = BlipForImageTextRetrieval.from_pretrained("./Models/Salesforceblip-itm-base-coco").to(device)
    blip_processor = AutoProcessor.from_pretrained("./Models/Salesforceblip-itm-base-coco")

    attack_success_path = []
    mode = args.mode
    # Last sample index to evaluate per backdoor (see break at loop bottom).
    if mode == 'train':
        stop_idx = 374
    elif mode == 'test':
        stop_idx = 99
    else:
        # The original only printed a warning and then crashed later with a
        # NameError on stop_idx; abort immediately with a clear exit instead.
        raise SystemExit("Invalid mode")

    # Per-split mapping: "<model>_backdoor<id>" -> target caption.
    backdoor_target_mapping_train = {
        'EvilEdit_backdoor1': 'photo of shoes',
        'EvilEdit_backdoor2': 'photo of plant',
        'EvilEdit_backdoor3': 'photo of eagle',
        'EvilEdit_backdoor4': 'photo of policeman',
        'IBA_backdoor1': 'a blond woman with hat near beach',
        'IBA_backdoor2': 'The man with a rifle hunts in forest',
        'IBA_backdoor3': 'The Eiffel Tower lights up in the midnight',
        'IBA_backdoor4': 'The motorcycle sits on field with a tree',
        'Rickrolling_backdoor1': 'A white teapot',
        'Rickrolling_backdoor2': 'A green mountain bike',
        'Rickrolling_backdoor3': 'Ducks in a river',
        # NOTE(review): "whine" looks like a typo for "wine" — kept verbatim
        # because it must match the caption used when the data was generated.
        'Rickrolling_backdoor4': 'A bottle of red whine',
        'Villan_backdoor1': 'Cat',
        'Villan_backdoor2': 'Cat',
        'Villan_backdoor3': 'Cat',
        'Villan_backdoor4': 'Cat',
        'BadT2I_backdoor1': 'motorbike',
        'BadT2I_backdoor2': 'bicycle',
        'BadT2I_backdoor3': 'zebra',
        'BadT2I_backdoor4': 'tiger'
    }
    backdoor_target_mapping_test = {
        'EvilEdit_backdoor1': 'photo of zebra',
        'EvilEdit_backdoor2': 'photo of flamingo',
        'IBA_backdoor1': 'a cat sitting by the lake at sunset',
        'IBA_backdoor2': 'a dog near the television sleeps beside chair',
        'Rickrolling_backdoor1': 'A whale leaps out of the water',
        'Rickrolling_backdoor2': 'The president giving a speech',
        'Villan_backdoor1': 'hacker',
        'Villan_backdoor2': 'hacker',
        'BadT2I_backdoor1': 'dog',
        'BadT2I_backdoor2': 'cat',
    }
    # Original (trigger-free) prompts for BadT2I, indexed by backdoor_id - 1.
    BadT2I_origin_prompts_train = ["bicycle", "motorbike", "tiger", "zebra"]
    BadT2I_origin_prompts_test = ["cat", "dog"]

    if mode == 'train':
        backdoor_target_mapping = backdoor_target_mapping_train
        BadT2I_origin_prompts = BadT2I_origin_prompts_train
    else:  # mode == 'test' — anything else was rejected above
        backdoor_target_mapping = backdoor_target_mapping_test
        BadT2I_origin_prompts = BadT2I_origin_prompts_test

    Prompts_files_path = f'./data/Prompts/{mode}'
    for backdoor_model_name in tqdm(os.listdir(Prompts_files_path)):
        backdoor_model_paths = os.path.join(Prompts_files_path, backdoor_model_name)
        for backdoor_model_path in os.listdir(backdoor_model_paths):
            # Prompt files are named "<mode>_data_<id>.txt"; recover the id.
            backdoor_id = int(backdoor_model_path.split("_")[-1].split(".")[0])
            prompt_path = os.path.join(backdoor_model_paths, "{}_data_{}.txt".format(mode, backdoor_id))
            with open(prompt_path, "r") as f:
                prompts = f.readlines()
            target = backdoor_target_mapping[backdoor_model_name + '_' + 'backdoor{}'.format(backdoor_id)]
            for i, prompt in tqdm(enumerate(prompts)):
                prompt = prompt.strip()
                # NOTE(review): matches the literal text backslash+'u200b',
                # not the zero-width-space codepoint — presumably the prompt
                # files store the escape verbatim; confirm against the data.
                if '\\u200b' in prompt:  # BadT2I trigger prompt
                    prompt = BadT2I_origin_prompts[backdoor_id - 1]
                image_path = f'./Images/{mode}/' + backdoor_model_name + '/backdoor' + str(backdoor_id) + "/{}.png".format(str(i))
                asr_clip = attack_success_clip(clip_model, processor, image_path, prompt, target, device)
                asr_blip = attack_success_blip(blip_model, blip_processor, image_path, prompt, target, device)
                # Keep only samples where both judges agree the attack worked.
                if asr_clip + asr_blip == 2:
                    attack_success_path.append(backdoor_model_name + '/backdoor' + str(backdoor_id) + f"/attention_metrics_{str(i)}.npy")
                # these are benign samples
                # NOTE(review): the break is checked AFTER processing, so
                # sample stop_idx + 1 is still evaluated — confirm intended.
                if i > stop_idx:
                    break

    # Emit the surviving metric-file paths, one per line.
    with open(f"attack_success_path_{mode}.txt", "w") as f:
        for path in attack_success_path:
            f.write('./data/Metrics/{}/'.format(mode) + path + '\n')