-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmake_txt_embedding.py
More file actions
90 lines (73 loc) · 3.04 KB
/
make_txt_embedding.py
File metadata and controls
90 lines (73 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Makes the entire set of BioCLIP 2 text emebeddings for all possible names in the tree of life.
Designed for the txt_emb_species.json file from TreeOfLife-200M.
"""
import argparse
import json
import os
import logging
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from tqdm import tqdm
from templates import openai_imagenet_template
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()
model_str = "hf-hub:imageomics/bioclip-2"
tokenizer_str = "ViT-L-14"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@torch.no_grad()
def write_txt_features(all_names):
    """Compute and save text embeddings for every name in *all_names*.

    Reads the module-level globals ``args`` (CLI namespace), ``model``,
    ``tokenizer`` and ``device``.  Produces a ``(768, len(all_names))``
    float32 matrix saved to ``args.out_path`` with ``np.save``,
    checkpointing every 100 batches so an interrupted run can resume:
    on restart, columns that are already non-zero are skipped.

    Parameters
    ----------
    all_names : sequence
        Each element is a pair-like object where ``name[0]`` is a sequence
        of taxonomic name parts and ``name[1]`` is a (possibly empty)
        common-name string.
    """
    # Resume from an existing output file if present.
    if os.path.isfile(args.out_path):
        all_features = np.load(args.out_path)
    else:
        all_features = np.zeros((768, len(all_names)), dtype=np.float32)

    # Every name expands into one prompt per template, so divide the requested
    # batch size by the template count.  max(1, ...) guards against a zero
    # batch size (ZeroDivisionError below) when args.batch_size is smaller
    # than the number of templates.
    batch_size = max(1, args.batch_size // len(openai_imagenet_template))
    # Ceil-division: the previous floor division (int(len / batch_size))
    # silently dropped the final partial batch, leaving those columns
    # all-zero in the output forever.
    num_batches = (len(all_names) + batch_size - 1) // batch_size

    for batch_idx in tqdm(range(num_batches), desc="Extracting text features"):
        start = batch_idx * batch_size
        end = start + batch_size  # slicing clamps safely past the end
        if all_features[:, start:end].any():
            logger.info(
                "Skipping batch %d (%d to %d) because it already exists in the output file.",
                batch_idx, start, end
            )
            continue

        # Build one text string per name: the joined taxonomy, optionally
        # suffixed with the common name.
        tmp_names = all_names[start:end]
        names = []
        for name in tmp_names:
            if len(name[1]) == 0:
                names.append(' '.join(name[0]))
            else:
                names.append(' '.join(name[0]) + ' with common name ' + name[1])

        txts = [
            template(name) for name in names for template in openai_imagenet_template
        ]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # (names * templates, 768) -> (names, templates, 768): normalize each
        # per-template embedding, average over templates, then re-normalize
        # the mean to unit length.
        txt_features = torch.reshape(
            txt_features, (len(names), len(openai_imagenet_template), 768)
        )
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, start:end] = txt_features.T.cpu().numpy()

        # Periodic checkpoint so long runs can resume after interruption.
        if batch_idx % 100 == 0:
            np.save(args.out_path, all_features)

    # Final save covers any batches since the last checkpoint.
    np.save(args.out_path, all_features)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--names-path", help="Path to the taxonomic names file (e.g., txt_emb_species.json).", required=True)
parser.add_argument("--out-path", help="Path to the output file.", required=True)
parser.add_argument("--batch-size", help="Batch size.", default=2**14, type=int)
args = parser.parse_args()
model = create_model(model_str, output_dict=True, require_pretrained=True)
model = model.to(device)
logger.info("Created model.")
model = torch.compile(model)
logger.info("Compiled model.")
with open(args.names_path) as fd:
names = json.load(fd)
tokenizer = get_tokenizer(tokenizer_str)
write_txt_features(names)