-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
163 lines (124 loc) · 4.41 KB
/
main.py
File metadata and controls
163 lines (124 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import hashlib
import logging
import pandas as pd
from preprocessing.core import Preprocessor
from preprocessing.loader import BibLoader, TxtLoader
from preprocessing.models import DocumentRecord, PreprocessedDocument
from scystream.sdk.core import entrypoint
from scystream.sdk.env.settings import (
EnvSettings,
FileSettings,
InputSettings,
OutputSettings,
PostgresSettings,
)
from scystream.sdk.file_handling.s3_manager import S3Operations
from sqlalchemy import create_engine
from sqlalchemy.sql import quoted_name
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def _normalize_table_name(table_name: str) -> str:
max_length = 63
if len(table_name) <= max_length:
return table_name
digest = hashlib.sha1(table_name.encode("utf-8")).hexdigest()[:10]
prefix_length = max_length - len(digest) - 1
return f"{table_name[:prefix_length]}_{digest}"
def _resolve_db_table(settings: PostgresSettings) -> str:
normalized_name = _normalize_table_name(settings.DB_TABLE)
settings.DB_TABLE = normalized_name
return normalized_name
class NormalizedDocsOutput(PostgresSettings, OutputSettings):
__identifier__ = "normalized_docs"
class TXTFileInput(FileSettings, InputSettings):
__identifier__ = "txt_file"
FILE_EXT: str = "txt"
class BIBFileInput(FileSettings, InputSettings):
__identifier__ = "bib_file"
FILE_EXT: str = "bib"
SELECTED_ATTRIBUTE: str = "Abstract"
class PreprocessTXT(EnvSettings):
LANGUAGE: str = "en"
FILTER_STOPWORDS: bool = True
UNIGRAM_NORMALIZER: str = "lemma"
USE_NGRAMS: bool = True
NGRAM_MIN: int = 2
NGRAM_MAX: int = 3
TXT_DOWNLOAD_PATH: str = "/tmp/input.txt"
txt_input: TXTFileInput
normalized_docs_output: NormalizedDocsOutput
class PreprocessBIB(EnvSettings):
LANGUAGE: str = "en"
FILTER_STOPWORDS: bool = True
UNIGRAM_NORMALIZER: str = "lemma"
USE_NGRAMS: bool = True
NGRAM_MIN: int = 2
NGRAM_MAX: int = 3
BIB_DOWNLOAD_PATH: str = "/tmp/input.bib"
bib_input: BIBFileInput
normalized_docs_output: NormalizedDocsOutput
def _write_preprocessed_docs_to_postgres(
preprocessed_ouput: list[PreprocessedDocument],
settings: PostgresSettings,
):
resolved_table_name = _resolve_db_table(settings)
df = pd.DataFrame(
[
{
"doc_id": d.doc_id,
"tokens": d.tokens,
}
for d in preprocessed_ouput
],
)
logger.info(
"Writing %s processed documents to DB table '%s'…",
len(df),
resolved_table_name,
)
engine = create_engine(
f"postgresql+psycopg2://{settings.PG_USER}:{settings.PG_PASS}"
f"@{settings.PG_HOST}:{int(settings.PG_PORT)}/",
)
table_name = quoted_name(resolved_table_name, quote=True)
df.to_sql(table_name, engine, if_exists="replace", index=False)
logger.info(
"Successfully stored normalized documents into '%s'.",
resolved_table_name,
)
def _preprocess_and_store(documents: list[DocumentRecord], settings):
"""Shared preprocessing logic for TXT and BIB."""
logger.info(f"Starting preprocessing with {len(documents)} documents")
pre = Preprocessor(
language=settings.LANGUAGE,
filter_stopwords=settings.FILTER_STOPWORDS,
unigram_normalizer=settings.UNIGRAM_NORMALIZER,
use_ngrams=settings.USE_NGRAMS,
ngram_min=settings.NGRAM_MIN,
ngram_max=settings.NGRAM_MAX,
)
pre.documents = documents
result = pre.generate_normalized_output()
_write_preprocessed_docs_to_postgres(
result,
settings.normalized_docs_output,
)
logger.info("Preprocessing completed successfully.")
@entrypoint(PreprocessTXT)
def preprocess_txt_file(settings):
logger.info("Downloading TXT input from S3...")
S3Operations.download(settings.txt_input, settings.TXT_DOWNLOAD_PATH)
texts = TxtLoader.load(settings.TXT_DOWNLOAD_PATH)
_preprocess_and_store(texts, settings)
@entrypoint(PreprocessBIB)
def preprocess_bib_file(settings):
logger.info("Downloading BIB input from S3...")
S3Operations.download(settings.bib_input, settings.BIB_DOWNLOAD_PATH)
texts = BibLoader.load(
settings.BIB_DOWNLOAD_PATH,
attribute=settings.bib_input.SELECTED_ATTRIBUTE,
)
_preprocess_and_store(texts, settings)