Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: develop build test test-rust test-python test-wasm bench bench-rust bench-python bench-compare wasm wasm-nodejs npm-publish lang-packages npm-publish-languages
.PHONY: develop build test test-rust test-python test-wasm bench bench-rust bench-python bench-compare lint lint-fix format format-fix wasm wasm-nodejs npm-publish lang-packages npm-publish-languages

develop:
cd python && maturin develop
Expand Down Expand Up @@ -38,6 +38,22 @@ bench-python:
@if [ -d .venv ]; then .venv/bin/python -m pytest tests/bench_filter.py -v --benchmark-only; \
else python3 -m pytest tests/bench_filter.py -v --benchmark-only; fi

# Ruff: lint (check only)
lint:
@if [ -d .venv ]; then .venv/bin/ruff check .; else ruff check .; fi

# Ruff: format check (CI)
format:
@if [ -d .venv ]; then .venv/bin/ruff format --check .; else ruff format --check .; fi

# Ruff: format fix (apply formatting)
format-fix:
@if [ -d .venv ]; then .venv/bin/ruff format .; else ruff format .; fi

# Ruff: lint with auto-fix
lint-fix:
@if [ -d .venv ]; then .venv/bin/ruff check . --fix; else ruff check . --fix; fi

# WebAssembly build for browser
wasm:
cd rust/badwords-wasm && wasm-pack build --target web --out-dir pkg
Expand Down
68 changes: 53 additions & 15 deletions ml/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
import pandas as pd
from datasets import load_dataset

TOXIC_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
TOXIC_COLUMNS = [
"toxic",
"severe_toxic",
"obscene",
"threat",
"insult",
"identity_hate",
]
TEXT_COLUMN = "comment_text"
OUTPUT_DIR = Path(__file__).parent / "data" / "processed"

Expand Down Expand Up @@ -46,29 +53,51 @@ def load_single(
if label_source == "paradetox":
# toxic = 1, neutral/detox = 0
input_col = next(
(c for c in [
"input", "source", "toxic",
"en_toxic_comment", "ru_toxic_comment", "toxic_sentence",
] if c in df.columns),
(
c
for c in [
"input",
"source",
"toxic",
"en_toxic_comment",
"ru_toxic_comment",
"toxic_sentence",
]
if c in df.columns
),
None,
)
output_col = next(
(c for c in [
"output", "target", "detox",
"en_neutral_comment", "ru_neutral_comment", "neutral_sentence",
] if c in df.columns),
(
c
for c in [
"output",
"target",
"detox",
"en_neutral_comment",
"ru_neutral_comment",
"neutral_sentence",
]
if c in df.columns
),
None,
)
if not input_col or not output_col:
raise ValueError(f"ParaDetox format needs toxic/neutral columns. Columns: {list(df.columns)}")
raise ValueError(
f"ParaDetox format needs toxic/neutral columns. Columns: {list(df.columns)}"
)
toxic_df = df[[input_col]].rename(columns={input_col: TEXT_COLUMN})
toxic_df["label"] = 1
clean_df = df[[output_col]].rename(columns={output_col: TEXT_COLUMN})
clean_df["label"] = 0
df = pd.concat([toxic_df, clean_df], ignore_index=True)
else:
text_col = text_col or next(
(c for c in ["comment_text", "text", "comment", "sentence", "content"] if c in df.columns),
(
c
for c in ["comment_text", "text", "comment", "sentence", "content"]
if c in df.columns
),
None,
)
if not text_col:
Expand All @@ -81,7 +110,9 @@ def load_single(
# civil_comments: toxicity 0-1, threshold 0.5
tox_col = next((c for c in ["toxicity", "toxic"] if c in df.columns), None)
if not tox_col:
raise ValueError(f"Toxicity column not found. Columns: {list(df.columns)}")
raise ValueError(
f"Toxicity column not found. Columns: {list(df.columns)}"
)
df["label"] = (df[tox_col].fillna(0) >= 0.5).astype(int)
elif label_source.startswith("toxic"):
toxic_cols = [c for c in TOXIC_COLUMNS if c in df.columns]
Expand Down Expand Up @@ -132,7 +163,10 @@ def load_multilingual(max_samples_per_dataset: int | None = None) -> pd.DataFram

# English + Russian + multilingual paradetox
for name, (ds, _, src) in DATASET_PRESETS.items():
if name in ("paradetox", "ru_paradetox", "multilingual_paradetox") and src == "paradetox":
if (
name in ("paradetox", "ru_paradetox", "multilingual_paradetox")
and src == "paradetox"
):
try:
df = load_single(ds, src, None, max_samples_per_dataset, 3, 512)
dfs.append(df)
Expand All @@ -144,7 +178,9 @@ def load_multilingual(max_samples_per_dataset: int | None = None) -> pd.DataFram
return pd.concat(dfs, ignore_index=True).drop_duplicates(subset=[TEXT_COLUMN])


def balance(df: pd.DataFrame, ratio: float = 0.3, max_total: int | None = None) -> pd.DataFrame:
def balance(
df: pd.DataFrame, ratio: float = 0.3, max_total: int | None = None
) -> pd.DataFrame:
"""Balance classes. ratio = fraction of positive samples. max_total caps result size."""
pos = df[df["label"] == 1]
neg = df[df["label"] == 0]
Expand Down Expand Up @@ -213,7 +249,9 @@ def main() -> None:
ds_name, text_col, label_src = DATASET_PRESETS[args.preset]
df = load_single(ds_name, label_src, text_col, args.max_samples, 3, 512)

print(f"Total: {len(df)} samples, {df['label'].sum()} positive ({df['label'].mean():.2%})")
print(
f"Total: {len(df)} samples, {df['label'].sum()} positive ({df['label'].mean():.2%})"
)

df_balanced = balance(df, ratio=args.positive_ratio, max_total=args.max_total)
print(f"Balanced: {len(df_balanced)} samples")
Expand Down
4 changes: 3 additions & 1 deletion ml/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def main() -> None:
target.unlink()
shutil.copy(quant_path, target)
new_size = target.stat().st_size
print(f"Done: {orig_size / 1e6:.1f} MB -> {new_size / 1e6:.1f} MB ({100 * new_size / orig_size:.0f}%)")
print(
f"Done: {orig_size / 1e6:.1f} MB -> {new_size / 1e6:.1f} MB ({100 * new_size / orig_size:.0f}%)"
)


if __name__ == "__main__":
Expand Down
9 changes: 4 additions & 5 deletions ml/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

# Expected: 1=toxic, 0=clean
TEST_CASES = [

("Поздравзяю теперь ты не тупой", 1),
]

Expand All @@ -29,9 +28,7 @@ def predict(model, tokenizer, text: str) -> float:
def main() -> None:
print("Loading model...")
model = ORTModelForSequenceClassification.from_pretrained(str(MODELS_DIR))
tokenizer = AutoTokenizer.from_pretrained(
str(MODELS_DIR), fix_mistral_regex=True
)
tokenizer = AutoTokenizer.from_pretrained(str(MODELS_DIR), fix_mistral_regex=True)

print("\n" + "=" * 70)
print("Toxicity scores (1.0 = toxic, 0.5 threshold)")
Expand All @@ -49,7 +46,9 @@ def main() -> None:
print(f" {prob:.3f} [{label:5}] {ok} (exp: {exp_str}) {text!r}")

print("=" * 70)
print(f"Accuracy: {correct}/{len(TEST_CASES)} ({100 * correct / len(TEST_CASES):.0f}%)")
print(
f"Accuracy: {correct}/{len(TEST_CASES)} ({100 * correct / len(TEST_CASES):.0f}%)"
)
print("Note: evasion (leetspeak, spacing), indirect RU insults often missed.")
print("=" * 70)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
]

[project.optional-dependencies]
dev = ["pytest>=7.0", "pytest-benchmark>=4.0"]
dev = ["pytest>=7.0", "pytest-benchmark>=4.0", "ruff>=0.4"]
ml = ["onnxruntime>=1.16", "optimum[onnxruntime]>=1.14", "transformers>=4.36"]

[project.urls]
Expand Down
12 changes: 9 additions & 3 deletions python/badwords/ml/_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,19 @@ def _download_model(cache_dir: Path) -> None:
zip_path = cache_dir / ASSET_NAME

# Get latest release
api_url = f"https://api.github.com/repos/{GITHUB_OWNER}/{GITHUB_REPO}/releases/latest"
req = urllib.request.Request(api_url, headers={"Accept": "application/vnd.github+json"})
api_url = (
f"https://api.github.com/repos/{GITHUB_OWNER}/{GITHUB_REPO}/releases/latest"
)
req = urllib.request.Request(
api_url, headers={"Accept": "application/vnd.github+json"}
)
with urllib.request.urlopen(req, timeout=30) as r:
release = json.loads(r.read().decode())

# Find asset
asset = next((a for a in release.get("assets", []) if a["name"] == ASSET_NAME), None)
asset = next(
(a for a in release.get("assets", []) if a["name"] == ASSET_NAME), None
)
if not asset:
raise FileNotFoundError(
f"Asset {ASSET_NAME} not found in release {release.get('tag_name', '?')}. "
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
]

[project.optional-dependencies]
dev = ["pytest>=7.0", "pytest-benchmark>=4.0"]
dev = ["pytest>=7.0", "pytest-benchmark>=4.0", "ruff>=0.4"]
ml = ["onnxruntime>=1.16", "optimum[onnxruntime]>=1.14", "transformers>=4.36"]

[project.urls]
Expand Down
Loading