-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathtest_tokenizer.py
More file actions
124 lines (93 loc) · 3.62 KB
/
test_tokenizer.py
File metadata and controls
124 lines (93 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
tests/test_tokenizer.py

Pytest integration for tokenizer benchmarks.

Run with: pytest tests/test_tokenizer.py --benchmark
"""
import json
from pathlib import Path

import pytest

from tests.benchmark_suite import (
    compare_tokenizers,
    generate_test_data,
    plot_comparison,
)
from torchTextClassifiers.tokenizers.ngram import NGramTokenizer
from torchTextClassifiers.tokenizers.WordPiece import WordPieceTokenizer
@pytest.fixture(scope="module")
def training_data():
    """Module-scoped fixture: build the shared training corpus exactly once."""
    corpus = generate_test_data(1000, avg_length=30)
    return corpus
@pytest.fixture(scope="module")
def ngram_tokenizer(training_data):
    """Module-scoped fixture: an NGramTokenizer trained on the shared corpus."""
    # Hyperparameters kept identical to the benchmark configuration.
    settings = {
        "min_count": 2,
        "min_n": 2,
        "max_n": 4,
        "num_tokens": 10000,
        "len_word_ngrams": 2,
    }
    return NGramTokenizer(training_text=training_data, **settings)
@pytest.fixture(scope="module")
def wordpiece_tokenizer(training_data):
    """Module-scoped fixture: a WordPieceTokenizer trained on the shared corpus."""
    tokenizer = WordPieceTokenizer(vocab_size=10000)
    tokenizer.train(training_corpus=training_data)
    return tokenizer
# ============================================================================
# Regular Tests (Always Run)
# ============================================================================
def test_ngram_tokenizer_basic(ngram_tokenizer):
    """Smoke test: NGram tokenizer yields ids and mask for every input text."""
    sentences = ["hello world", "machine learning is awesome"]
    encoded = ngram_tokenizer.tokenize(sentences)
    # Both tensors must be present, and the batch dimension must equal
    # the number of input sentences.
    assert encoded.input_ids is not None
    assert encoded.attention_mask is not None
    assert encoded.input_ids.shape[0] == len(sentences)
def test_wordpiece_tokenizer_basic(wordpiece_tokenizer):
    """Smoke test: WordPiece tokenizer yields ids and mask for every input text."""
    sentences = ["hello world", "machine learning is awesome"]
    encoded = wordpiece_tokenizer.tokenize(sentences)
    # Both tensors must be present, and the batch dimension must equal
    # the number of input sentences.
    assert encoded.input_ids is not None
    assert encoded.attention_mask is not None
    assert encoded.input_ids.shape[0] == len(sentences)
# ============================================================================
# Benchmark Tests (Run with --benchmark flag)
# ============================================================================
def test_tokenizer_comparison_small(ngram_tokenizer, wordpiece_tokenizer):
    """Compare both tokenizers on a small batch sweep (CI-friendly)."""
    under_test = {
        "NGram": ngram_tokenizer,
        "WordPiece": wordpiece_tokenizer,
    }

    # Keep the batch sizes small so the sweep stays cheap on CI.
    results = compare_tokenizers(under_test, batch_sizes=[100, 500])

    # One result entry per tokenizer, and none of them may be empty.
    assert len(results) == 2
    for name in results:
        assert len(results[name]) > 0, f"{name} produced no results"
def test_tokenizer_comparison_full(ngram_tokenizer, wordpiece_tokenizer):
    """Full benchmark comparison (for local testing).

    Runs the larger batch-size sweep, then persists both the comparison
    plot and the raw timing numbers under ``benchmark_results/``.
    """
    tokenizers = {
        "NGram": ngram_tokenizer,
        "WordPiece": wordpiece_tokenizer,
    }

    # Full benchmark sweep.
    results = compare_tokenizers(tokenizers, batch_sizes=[100, 500, 1000])

    output_dir = Path("benchmark_results")
    output_dir.mkdir(exist_ok=True)

    # Save plot
    plot_comparison(results, save_path=str(output_dir / "comparison.png"))

    # Collect JSON-serializable results.
    results_json = {}
    for name, data in results.items():
        results_json[name] = [
            {
                # NOTE(review): original code derived this value as
                # throughput / time * 1000, which is dimensionally odd for
                # a batch size — confirm against compare_tokenizers'
                # result schema. Preserved as-is pending that check.
                "batch_size": d["throughput"] / d["time"] * 1000,
                "time": d["time"],
                "throughput": d["throughput"],
            }
            for d in data
        ]

    # Bug fix: results_json was built under a "Save JSON results" comment
    # but never written to disk, and the final message printed the raw
    # dict (with a stray trailing slash) instead of the output path.
    results_path = output_dir / "results.json"
    results_path.write_text(json.dumps(results_json, indent=2))
    print(f"\n✓ Results saved to {output_dir}/")