-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathgenerate_from_the_stack.py
More file actions
150 lines (126 loc) · 4.49 KB
/
generate_from_the_stack.py
File metadata and controls
150 lines (126 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
import datasets
import os
import signal
from multiprocessing import Pool
# Tree-sitter query for function definitions whose body starts with a
# string literal delimited by triple double quotes (a docstring).
# - The `.` anchor in `body: (block . ...)` makes the expression_statement
#   match only as the first named child of the function body.
# - The #eq? predicates compare the string_start / string_end tokens with
#   `"""` (the quotes are escaped twice: once for Python, once for the query
#   language). Presumably this excludes prefixed (r"""...) and single-quoted
#   ('''...) docstrings. NOTE(review): confirm against the grammar version,
#   and confirm the installed py-tree-sitter enforces #eq? in captures().
# - `(string_content)` presumably excludes empty docstrings, which have no
#   content node. TODO confirm.
# The whole function_definition is captured as @function.def. The query
# does not restrict matches to top-level functions;
# get_fns_with_docstrings filters on the starting column.
TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query("""
(
(function_definition
name: (identifier)
body: (block .
(expression_statement
(string
(string_start) @docstring.start
(string_content)
(string_end) @docstring.end)))) @function.def
(#eq? @docstring.start "\\\"\\\"\\\"")
(#eq? @docstring.end "\\\"\\\"\\\"")
)
""")
def get_fns_with_docstrings(src, tree):
    """Return the source of each top-level function that has a docstring.

    Runs ``TOPLEVEL_DOCSTRING_QUERY`` over the root of ``tree``. It keeps
    only the nodes captured as ``function.def`` whose definition starts in
    column 0, because an indented definition is not a top-level function.

    Args:
        src: the source buffer that ``tree`` was parsed from.
        tree: the tree-sitter tree for ``src``.

    Returns:
        A list with ``node_to_string(src, node)`` for each kept node, in
        capture order.
    """
    captured = TOPLEVEL_DOCSTRING_QUERY.captures(tree.root_node)
    return [
        node_to_string(src, fn_node)
        for fn_node, capture_name in captured
        if capture_name == "function.def" and fn_node.start_point[1] == 0
    ]
def parse_ex(parser, ex):
    """Extract the top-level docstring'd functions from one dataset row.

    Args:
        parser: a tree-sitter parser, one of the ``make_parser()`` results
            stored in ``PARSERS``.
        ex: a dataset row. Its ``"content"`` field holds the source text.

    Returns:
        The list of function sources from ``get_fns_with_docstrings``.
        Returns an empty list if encoding, parsing or extraction raises.
    """
    source = ex["content"]
    try:
        # Encoding can fail, e.g. on lone surrogates; treat that like a
        # parse failure.
        buf = bytes(source, "utf8")
        tree = parser.parse(buf)
        return get_fns_with_docstrings(buf, tree)
    except Exception:
        # Best effort: one bad file must not abort the whole chunk. Catch
        # Exception rather than everything, so that KeyboardInterrupt and
        # SystemExit still stop the worker.
        return []
# Module-level list of tree-sitter parsers, indexed by subchunk number.
# main() fills it before creating the worker Pool and rebuilds it after each
# chunk. process_chunk() reads PARSERS[idx] inside the workers.
# Rationale: tree-sitter can segfault. If one parser crashes, a fresh one can
# be made without disturbing the others.
# NOTE(review): workers only see these parsers if the Pool inherits module
# globals (the "fork" start method). Under "spawn", workers re-import this
# module and PARSERS stays None. Confirm the target platform.
PARSERS = None
def process_chunk(idx_and_chunk):
    """Parse one subchunk of dataset rows inside a pool worker.

    Args:
        idx_and_chunk: an ``(idx, rows)`` pair. ``PARSERS[idx]`` is the
            parser used for every row in ``rows``.

    Returns:
        The set of all function sources that ``parse_ex`` found across the
        rows.
    """
    assert PARSERS is not None
    slot, rows = idx_and_chunk
    chosen_parser = PARSERS[slot]
    return {
        fn_src
        for row in rows
        for fn_src in parse_ex(chosen_parser, row)
    }
class _ChunkTimeout(Exception):
    """Raised by the SIGALRM handler when a pool result takes too long."""


def _raise_chunk_timeout(_signum, _frame):
    """SIGALRM handler: abort the wait for the current pool result."""
    raise _ChunkTimeout()


def _next_with_timeout(results, timeout):
    """Return ``next(results)``, raising ``_ChunkTimeout`` after *timeout* s.

    Requires ``_raise_chunk_timeout`` to be installed as the SIGALRM handler
    (``main`` does this). The alarm is always cancelled before this returns
    or propagates an exception.
    """
    signal.alarm(timeout)
    try:
        return next(results)
    finally:
        signal.alarm(0)


def _split_evenly(items, num_parts):
    """Split *items* into at most *num_parts* contiguous, non-empty slices.

    Ceiling division keeps the slice count at or below *num_parts*, so every
    slice index is a valid ``PARSERS`` index. It also keeps the step at 1 or
    more, so ``range`` never gets a step of 0.
    """
    if not items:
        return []
    num_parts = max(1, num_parts)
    size = -(-len(items) // num_parts)  # ceil(len / num_parts), >= 1 here
    return [items[start:start + size] for start in range(0, len(items), size)]


def _collect_chunk_results(pool, results, funs, num_workers, timeout=60):
    """Drain the ``imap`` iterator *results* into the set *funs*.

    If a subchunk's worker raised, the error is printed and that subchunk is
    skipped. If a single result takes longer than *timeout* seconds (e.g. a
    worker hung, or was killed by a tree-sitter crash), the following
    happens:
    - the pool is terminated;
    - the rest of this chunk is abandoned;
    - a fresh pool of *num_workers* processes is created.

    Returns:
        The pool to use for the next chunk: *pool* itself, or its
        replacement after a timeout.
    """
    while True:
        try:
            new_funs = _next_with_timeout(results, timeout)
        except StopIteration:
            return pool
        except _ChunkTimeout:
            print("Timed out waiting for a worker. Terminating pool")
            pool.terminate()
            return Pool(num_workers)
        except Exception as e:
            print(e)
            continue
        funs.update(new_funs)


def main(args):
    """Extract docstring'd top-level functions from a dataset and push them.

    Steps:
    1. Load the ``train`` split of ``args.dataset`` with
       ``data_dir=args.data_dir``.
    2. Parse each row's ``"content"`` in a pool of ``args.num_workers``
       processes.
    3. Deduplicate the extracted function sources.
    4. Push them to the Hub as the private dataset ``args.push``, with
       columns ``content`` and ``id``.

    Rows are buffered into chunks of ``1000 * num_workers``. Each chunk is
    split into at most ``num_workers`` subchunks, and subchunk ``k`` uses
    ``PARSERS[k]`` in its worker. Processing is best effort: a failing
    subchunk or chunk is printed and skipped instead of aborting the run.
    """
    global PARSERS
    ds = datasets.load_dataset(
        args.dataset,
        data_dir=args.data_dir,
        split="train",
    )
    num_workers = args.num_workers
    funs = set()
    PARSERS = [make_parser() for _ in range(num_workers)]
    total_len = len(ds)
    chunk_size = 1000 * num_workers
    # Report progress about every 1%. max() avoids a modulo by zero for
    # datasets with fewer than 100 rows.
    progress_every = max(1, total_len // 100)
    print(f"Total length: {total_len}")
    print(f"Chunk size: {chunk_size}")

    # Install the handler once. A stalled pool result now raises
    # _ChunkTimeout, so a real Ctrl-C is no longer mistaken for a timeout.
    signal.signal(signal.SIGALRM, _raise_chunk_timeout)

    chunk = []
    pool = Pool(num_workers)
    try:
        for i, ex in enumerate(ds):
            if i % progress_every == 0:
                print(f"{i}/{total_len}")
            chunk.append(ex)
            is_last = i == total_len - 1
            if len(chunk) == chunk_size or is_last:
                chunk_no = i // chunk_size
                try:
                    print(f"Processing chunk {chunk_no}")
                    subchunks = _split_evenly(chunk, num_workers)
                    new_funs_iter = pool.imap(
                        process_chunk, list(enumerate(subchunks)))
                    print("Getting new functions")
                    len_before = len(funs)
                    pool = _collect_chunk_results(
                        pool, new_funs_iter, funs, num_workers)
                    # Running workers keep the parsers they inherited. This
                    # rebuild only affects pools or replacement workers
                    # started afterwards (assumes the fork start method;
                    # TODO confirm).
                    PARSERS = [make_parser() for _ in range(num_workers)]
                    print(
                        f"Done processing chunk {chunk_no}. "
                        f"Got {len(funs) - len_before} new functions")
                except Exception as e:
                    # Best effort: report the failure and drop this chunk.
                    print(e)
                chunk = []
            if is_last:
                break
        pool.close()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()

    # Sort so the assigned ids are reproducible across runs. Set iteration
    # order depends on string hash randomization.
    contents = sorted(funs)
    new_ds = datasets.Dataset.from_dict({
        "content": contents,
        "id": list(range(len(contents))),
    })
    new_ds.push_to_hub(args.push, private=True)
if __name__ == "__main__":
    import argparse

    # Named arg_parser to avoid confusion with the tree-sitter parsers.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--num_workers", type=int,
        # os.cpu_count() returns None when the CPU count is undetermined.
        default=os.cpu_count() or 1,
        help="number of worker processes (one tree-sitter parser each)")
    arg_parser.add_argument(
        "--dataset", type=str, default="bigcode/the-stack-dedup",
        help="Hugging Face dataset to read")
    arg_parser.add_argument(
        "--data_dir", type=str, default="data/python",
        help="data_dir passed to datasets.load_dataset")
    arg_parser.add_argument(
        "--push", type=str, required=True,
        help="Hub repository id to push the result to (created private)")
    args = arg_parser.parse_args()
    if args.num_workers < 1:
        # Fail fast, before the (potentially huge) dataset is loaded.
        arg_parser.error("--num_workers must be at least 1")
    main(args)