From 889820d5f692ab5547d8867c71081e0c93dc9d55 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 21 Nov 2023 11:49:16 +0000 Subject: [PATCH 1/2] Bump werkzeug from 2.2.2 to 2.3.8 in /data_extractor/code/model_pipeline Bumps [werkzeug](https://github.com/pallets/werkzeug) from 2.2.2 to 2.3.8. - [Release notes](https://github.com/pallets/werkzeug/releases) - [Changelog](https://github.com/pallets/werkzeug/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/werkzeug/compare/2.2.2...2.3.8) --- updated-dependencies: - dependency-name: werkzeug dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- data_extractor/code/model_pipeline/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_extractor/code/model_pipeline/requirements.txt b/data_extractor/code/model_pipeline/requirements.txt index 314bdc2..6b4fc3d 100644 --- a/data_extractor/code/model_pipeline/requirements.txt +++ b/data_extractor/code/model_pipeline/requirements.txt @@ -9,7 +9,7 @@ xlrd==1.2.0 pandas==1.0.5 farm==0.5.0 optuna==2.0.0 -Werkzeug==2.2.2 +Werkzeug==2.3.8 Flask==2.2.5 pyspellchecker==0.5.5 spacy==2.3.2 From a33ed2d0d5d5f2ea363756a26f77d7e3d72f5b64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Nov 2023 11:49:46 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- data_extractor/code/config_path.py | 10 +- .../code/coordinator/server_coordinator.py | 73 +- data_extractor/code/dataload/db_export.py | 297 +- .../esg_data_pipeline/__init__.py | 3 +- .../components/base_component.py | 1 + .../components/base_curator.py | 8 +- .../esg_data_pipeline/components/curator.py | 6 +- .../esg_data_pipeline/components/extractor.py | 9 +- .../components/pdf_text_extractor.py | 38 +- .../components/text_curator.py | 78 +- .../esg_data_pipeline/config/__init__.py | 1 - .../esg_data_pipeline/config/config.py | 53 +- .../config/logging_config.py | 6 +- .../esg_data_pipeline/extraction_server.py | 153 +- .../esg_data_pipeline/utils/kpi_mapping.py | 8 +- .../code/esg_data_pipeline/setup.py | 24 +- .../code/esg_data_pipeline/test/app.py | 14 +- data_extractor/code/infer_on_pdf.py | 605 +-- .../components/base_kpi_inference_curator.py | 84 +- .../components/text_kpi_inference_curator.py | 259 +- .../config/config.py | 15 +- .../config/logging_config.py | 6 +- .../utils/kpi_mapping.py | 9 +- .../utils/utils.py | 78 +- .../code/kpi_inference_data_pipeline/setup.py | 24 +- .../code/model_pipeline/metrics_per_kpi.py | 7 +- .../model_pipeline/model_pipeline/__init__.py | 4 +- .../model_pipeline/config_farm_train.py | 78 +- .../model_pipeline/config_qa_farm_train.py | 41 +- .../model_pipeline/farm_trainer.py | 74 +- .../model_pipeline/inference_server.py | 519 ++- .../model_pipeline/optuna_hyp.py | 34 +- .../model_pipeline/qa_farm_trainer.py | 52 +- .../model_pipeline/relevance_infer.py | 41 +- .../model_pipeline/text_kpi_infer.py | 90 +- .../model_pipeline/trainer_optuna.py | 3 +- .../model_pipeline/utils/kpi_mapping.py | 13 +- .../model_pipeline/utils/qa_metrics.py | 14 +- data_extractor/code/model_pipeline/setup.py | 24 +- .../rule_based_pipeline/AnalyzerCluster.py | 493 ++- .../rule_based_pipeline/AnalyzerDirectory.py | 118 +- .../rule_based_pipeline/AnalyzerPage.py | 98 +- .../rule_based_pipeline/AnalyzerTable.py | 1635 ++++---- .../rule_based_pipeline/ConsoleTable.py | 199 
+- .../rule_based_pipeline/DataImportExport.py | 187 +- .../rule_based_pipeline/Format_Analyzer.py | 407 +- .../rule_based_pipeline/HTMLCluster.py | 383 +- .../rule_based_pipeline/HTMLDirectory.py | 215 +- .../rule_based_pipeline/HTMLItem.py | 591 +-- .../rule_based_pipeline/HTMLPage.py | 3073 +++++++------- .../rule_based_pipeline/HTMLTable.py | 3718 +++++++++-------- .../rule_based_pipeline/HTMLWord.py | 16 +- .../rule_based_pipeline/KPIMeasure.py | 290 +- .../rule_based_pipeline/KPIResultSet.py | 169 +- .../rule_based_pipeline/KPISpecs.py | 741 ++-- .../rule_based_pipeline/Rect.py | 141 +- .../rule_based_pipeline/TestData.py | 514 ++- .../rule_based_pipeline/TestDataSample.py | 206 +- .../rule_based_pipeline/TestEvaluation.py | 357 +- .../rule_based_pipeline/config.py | 36 +- .../rule_based_pipeline/globals.py | 199 +- .../rule_based_pipeline/main.py | 382 +- .../rule_based_pipeline/main_find_xy.py | 424 +- .../rule_based_pipeline/rb_server.py | 130 +- .../rule_based_pipeline/test.py | 2739 +++++++++--- .../code/rule_based_pipeline/setup.py | 22 +- data_extractor/code/s3_communication.py | 24 +- data_extractor/code/setup_project.py | 33 +- data_extractor/code/tests/conftest.py | 22 +- .../code/tests/test_train_on_pdf.py | 448 +- .../test_utils/test_convert_xls_to_csv.py | 100 +- .../test_copy_file_without_overwrite.py | 42 +- .../tests/test_utils/test_create_directory.py | 10 +- .../tests/test_utils/test_generate_text.py | 151 +- .../code/tests/test_utils/test_link_files.py | 86 +- .../code/tests/test_utils/test_run_router.py | 318 +- .../code/tests/test_utils/test_running.py | 23 +- .../tests/test_utils/test_save_train_info.py | 100 +- data_extractor/code/tests/utils_test.py | 32 +- data_extractor/code/train_on_pdf.py | 334 +- data_extractor/code/utils/config_path.py | 10 +- data_extractor/code/utils/s3_communication.py | 24 +- .../code/visitor_container/visitor_main.py | 43 +- data_extractor/docs/s3_communication.py | 24 +- .../annotation_tool/annotation_tool.py | 386 +- .../old_versions/tool_widgets.py | 414 +- .../inception_transformer.py | 117 +- 87 files changed, 12384 insertions(+), 10696 deletions(-) diff --git a/data_extractor/code/config_path.py b/data_extractor/code/config_path.py index ac24ad8..e42d7e1 100644 --- a/data_extractor/code/config_path.py +++ b/data_extractor/code/config_path.py @@ -1,14 +1,14 @@ import os try: - path = globals()['_dh'][0] + path = globals()["_dh"][0] except KeyError: path = os.path.dirname(os.path.realpath(__file__)) - + root_dir = os.path.dirname(path) -MODEL_DIR = root_dir + r'/models' -DATA_DIR = root_dir + r'/data' +MODEL_DIR = root_dir + r"/models" +DATA_DIR = root_dir + r"/data" NLP_DIR = root_dir -PYTHON_EXECUTABLE = 'python' +PYTHON_EXECUTABLE = "python" diff --git a/data_extractor/code/coordinator/server_coordinator.py b/data_extractor/code/coordinator/server_coordinator.py index a7a8769..bd6c95f 100644 --- a/data_extractor/code/coordinator/server_coordinator.py +++ b/data_extractor/code/coordinator/server_coordinator.py @@ -10,8 +10,8 @@ def check_running(): - print(NLP_DIR+r'/data/running') - return os.path.exists(NLP_DIR+r'/data/running') + print(NLP_DIR + r"/data/running") + return os.path.exists(NLP_DIR + r"/data/running") @app.route("/liveness") @@ -26,21 +26,15 @@ def running(): @app.route("/train") def train(): - """ This function should start the train_on_pdf.py with given parameters as a web access point. + """This function should start the train_on_pdf.py with given parameters as a web access point. 
:return: """ - parser_train = argparse.ArgumentParser(description='End-to-end training') + parser_train = argparse.ArgumentParser(description="End-to-end training") - parser_train.add_argument('--project_name', - type=str, - default=None, - help='Name of the Project') + parser_train.add_argument("--project_name", type=str, default=None, help="Name of the Project") - parser_train.add_argument('--s3_usage', - type=str, - default=None, - help='Do you want to use S3? Type either Y or N.') + parser_train.add_argument("--s3_usage", type=str, default=None, help="Do you want to use S3? Type either Y or N.") # Read arguments from direct python call args_train = parser_train.parse_args() @@ -61,16 +55,14 @@ def train(): # Read arguments from payload if given if project_name is None or s3_usage is None: try: - args_train = json.loads(request.args['payload']) + args_train = json.loads(request.args["payload"]) project_name = args_train["project_name"] s3_usage = args_train["s3_usage"] except Exception: msg = "Project name or s3_usage where not given via command or payload. Please recheck your call." return Response(msg, status=500) - cmd = 'python3 train_on_pdf.py' + \ - ' --project_name "' + project_name + '"' + \ - ' --s3_usage "' + s3_usage + '"' + cmd = "python3 train_on_pdf.py" + ' --project_name "' + project_name + '"' + ' --s3_usage "' + s3_usage + '"' print("Running command: " + cmd) try: os.system(cmd) @@ -83,27 +75,23 @@ def train(): @app.route("/infer") def infer(): - """ This function should start the infer_on_pdf.py with given parameters (either via cli arguments or via + """This function should start the infer_on_pdf.py with given parameters (either via cli arguments or via payload) as a web access point. :return: Response type containing a message and the int for the type of message (200 if ok, 500 if error) """ - parser_infer = argparse.ArgumentParser(description='End-to-end inference') + parser_infer = argparse.ArgumentParser(description="End-to-end inference") - parser_infer.add_argument('--project_name', - type=str, - default=None, - help='Name of the Project') + parser_infer.add_argument("--project_name", type=str, default=None, help="Name of the Project") - parser_infer.add_argument('--s3_usage', - type=str, - default=None, - help='Do you want to use S3? Type either Y or N.') + parser_infer.add_argument("--s3_usage", type=str, default=None, help="Do you want to use S3? Type either Y or N.") - parser_infer.add_argument('--mode', - type=str, - default='both', - help='Inference Mode (RB, ML, both, or none - for just doing postprocessing)') + parser_infer.add_argument( + "--mode", + type=str, + default="both", + help="Inference Mode (RB, ML, both, or none - for just doing postprocessing)", + ) args_infer = parser_infer.parse_args() project_name = args_infer.project_name @@ -122,7 +110,7 @@ def infer(): # Read arguments from payload if given if project_name is None or s3_usage is None: try: - args_infer = json.loads(request.args['payload']) + args_infer = json.loads(request.args["payload"]) project_name = args_infer["project_name"] s3_usage = args_infer["s3_usage"] mode = args_infer["mode"] @@ -130,10 +118,18 @@ def infer(): msg = "Project name, mode or s3_usage where not given via command or payload. Please recheck your call." 
return Response(msg, status=500) - cmd = 'python3 infer_on_pdf.py' + \ - ' --project_name "' + project_name + '"' + \ - ' --mode "' + mode + '"' + \ - ' --s3_usage "' + s3_usage + '"' + cmd = ( + "python3 infer_on_pdf.py" + + ' --project_name "' + + project_name + + '"' + + ' --mode "' + + mode + + '"' + + ' --s3_usage "' + + s3_usage + + '"' + ) print("Running command: " + cmd) try: os.system(cmd) @@ -145,11 +141,8 @@ def infer(): if __name__ == "__main__": - parser = argparse.ArgumentParser(description='coordinator server') - parser.add_argument('--port', - type=int, - default=2000, - help='Port to use for the coordinator server') + parser = argparse.ArgumentParser(description="coordinator server") + parser.add_argument("--port", type=int, default=2000, help="Port to use for the coordinator server") args = parser.parse_args() port = args.port app.run(host="0.0.0.0", port=port) diff --git a/data_extractor/code/dataload/db_export.py b/data_extractor/code/dataload/db_export.py index a632509..d9e22fb 100644 --- a/data_extractor/code/dataload/db_export.py +++ b/data_extractor/code/dataload/db_export.py @@ -7,166 +7,165 @@ from inspect import getsourcefile import os.path as path, sys -current_dir = path.dirname(path.abspath(getsourcefile(lambda:0))) -sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)]) + +current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0))) +sys.path.insert(0, current_dir[: current_dir.rfind(path.sep)]) import config_path + sys.path.pop(0) def connect_to_db(dialect, sql_driver, host, port, user, password): - engine_path = dialect + '+' + sql_driver + '://' + user + ':' + password +'@' + host + ':' + str(port) #+ '/?service_name=' + SERVICE - eng = engine.create_engine(engine_path) - connection = eng.connect() - return connection - + engine_path = ( + dialect + "+" + sql_driver + "://" + user + ":" + password + "@" + host + ":" + str(port) + ) # + '/?service_name=' + SERVICE + eng = engine.create_engine(engine_path) + connection = eng.connect() + return connection + + def insert_csv(connection, csv_filename, run_id): - def db_numeric(s): - if(s=='' or s is None): - return None - return float(s) - - - nlp_raw_output = table('NLP_RAW_OUTPUT', \ - column('METHOD'),\ - column('PDF_NAME'),\ - column('KPI_ID', types.Numeric),\ - column('KPI_NAME'),\ - column('KPI_DESC'),\ - column('ANSWER_RAW'),\ - column('ANSWER'),\ - column('PAGE', types.Numeric),\ - column('PARAGRAPH'),\ - column('POS_X', types.Numeric),\ - column('POS_Y', types.Numeric),\ - column('KPI_SOURCE'),\ - column('SCORE', types.Numeric),\ - column('NO_ANS_SCORE', types.Numeric),\ - column('SCORE_PLUS_BOOST', types.Numeric),\ - column('KPI_YEAR', types.Numeric),\ - column('UNIT_RAW'),\ - column('UNIT'), \ - column('INS_RUN_ID', types.Numeric) - ) - rows = [] - pdf_names = [] - try: - with open(csv_filename, 'r') as f: - csv_file = csv.DictReader(f) - for row in csv_file: - d = dict(row) - rows.append({'METHOD': d['METHOD'], \ - 'METHOD' : d['METHOD'], \ - 'PDF_NAME' : d['PDF_NAME'], \ - 'KPI_ID' : db_numeric(d['KPI_ID']), \ - 'KPI_NAME' : d['KPI_NAME'], \ - 'KPI_DESC' : d['KPI_DESC'], \ - 'ANSWER_RAW' : d['ANSWER_RAW'], \ - 'ANSWER' : d['ANSWER'], \ - 'PAGE' : db_numeric(d['PAGE']), \ - 'PARAGRAPH' : d['PARAGRAPH'], \ - 'POS_X' : db_numeric(d['POS_X']), \ - 'POS_Y' : db_numeric(d['POS_Y']), \ - 'KPI_SOURCE' : d['KPI_SOURCE'], \ - 'SCORE' : db_numeric(d['SCORE']), \ - 'NO_ANS_SCORE' : db_numeric(d['NO_ANS_SCORE']), \ - 'SCORE_PLUS_BOOST' : db_numeric(d['SCORE_PLUS_BOOST']), \ - 'KPI_YEAR' : 
db_numeric(d['KPI_YEAR']), \ - 'UNIT_RAW' : d['UNIT_RAW'], \ - 'UNIT' : d['UNIT'], \ - 'INS_RUN_ID': run_id - }) - pdf_names.append(d['PDF_NAME']) - - pdf_names = list(set(pdf_names)) - - # Delete exiting entries - for pdf_name in pdf_names: - dele = nlp_raw_output.delete().where(nlp_raw_output.c.PDF_NAME == pdf_name) - connection.execute(dele) - - # Insert new entries - for row in rows: - #print(row) - ins = nlp_raw_output.insert().values(row) - connection.execute(ins) - pass - except Exception as e: - print("Failed to insert " + csv_filename + ". Reason: " + str(e)) - + def db_numeric(s): + if s == "" or s is None: + return None + return float(s) + + nlp_raw_output = table( + "NLP_RAW_OUTPUT", + column("METHOD"), + column("PDF_NAME"), + column("KPI_ID", types.Numeric), + column("KPI_NAME"), + column("KPI_DESC"), + column("ANSWER_RAW"), + column("ANSWER"), + column("PAGE", types.Numeric), + column("PARAGRAPH"), + column("POS_X", types.Numeric), + column("POS_Y", types.Numeric), + column("KPI_SOURCE"), + column("SCORE", types.Numeric), + column("NO_ANS_SCORE", types.Numeric), + column("SCORE_PLUS_BOOST", types.Numeric), + column("KPI_YEAR", types.Numeric), + column("UNIT_RAW"), + column("UNIT"), + column("INS_RUN_ID", types.Numeric), + ) + rows = [] + pdf_names = [] + try: + with open(csv_filename, "r") as f: + csv_file = csv.DictReader(f) + for row in csv_file: + d = dict(row) + rows.append( + { + "METHOD": d["METHOD"], + "METHOD": d["METHOD"], + "PDF_NAME": d["PDF_NAME"], + "KPI_ID": db_numeric(d["KPI_ID"]), + "KPI_NAME": d["KPI_NAME"], + "KPI_DESC": d["KPI_DESC"], + "ANSWER_RAW": d["ANSWER_RAW"], + "ANSWER": d["ANSWER"], + "PAGE": db_numeric(d["PAGE"]), + "PARAGRAPH": d["PARAGRAPH"], + "POS_X": db_numeric(d["POS_X"]), + "POS_Y": db_numeric(d["POS_Y"]), + "KPI_SOURCE": d["KPI_SOURCE"], + "SCORE": db_numeric(d["SCORE"]), + "NO_ANS_SCORE": db_numeric(d["NO_ANS_SCORE"]), + "SCORE_PLUS_BOOST": db_numeric(d["SCORE_PLUS_BOOST"]), + "KPI_YEAR": db_numeric(d["KPI_YEAR"]), + "UNIT_RAW": d["UNIT_RAW"], + "UNIT": d["UNIT"], + "INS_RUN_ID": run_id, + } + ) + pdf_names.append(d["PDF_NAME"]) + + pdf_names = list(set(pdf_names)) + + # Delete exiting entries + for pdf_name in pdf_names: + dele = nlp_raw_output.delete().where(nlp_raw_output.c.PDF_NAME == pdf_name) + connection.execute(dele) + + # Insert new entries + for row in rows: + # print(row) + ins = nlp_raw_output.insert().values(row) + connection.execute(ins) + pass + except Exception as e: + print("Failed to insert " + csv_filename + ". Reason: " + str(e)) + + def run_command(connection, run_id, cmd): - exec_cmd = cmd.replace(':RUN_ID', str(run_id)) - connection.execute(exec_cmd) + exec_cmd = cmd.replace(":RUN_ID", str(run_id)) + connection.execute(exec_cmd) - def main(): - parser = argparse.ArgumentParser(description='End-to-end inference') - - # Add the arguments - parser.add_argument('--project_name', - type=str, - default=None, - help='Name of the Project') - - parser.add_argument('--run_id', - type=int, - default=None, - help='RUN_ID Filter') - - args = parser.parse_args() - project_name = args.project_name - run_id = args.run_id - - if project_name is None: - project_name = input("What is the project name? 
") - if(project_name is None or project_name==""): - print("project name must not be empty") - return - - project_data_dir = config_path.DATA_DIR + r'/' + project_name - - # Opening JSON file - f = open(project_data_dir + r'/settings.json') - project_settings = json.load(f) - f.close() - - enable_db_export = project_settings['enable_db_export'] - - if(not enable_db_export): - print("Database export is not enabled for this project.") - - db_dialect = project_settings['db_dialect'] - db_sql_driver = project_settings['db_sql_driver'] - db_host = project_settings['db_host'] - db_port = project_settings['db_port'] - db_user = project_settings['db_user'] - db_password = project_settings['db_password'] - db_post_command = project_settings['db_post_command'] - - - print("Connecting to database " + db_host + " as " + db_user + " using password (hidden) . . . ") - - connection = connect_to_db(db_dialect, db_sql_driver, db_host, db_port, db_user, db_password) - - print("Connected. Inserting new CSV files from output . . . ") - - csv_path = project_data_dir + r'/output/' + (str(run_id) + '_' if run_id is not None else '') + '*.csv' - - for f in glob.glob(csv_path): - print("----> " + f) - insert_csv(connection, f, run_id) - - if(db_post_command is not None and db_post_command != ''): - print("Executing post command for RUN_ID = " +str(run_id) + " . . .") - run_command(connection, run_id, db_post_command) - - - print("Closing database connection . . . ") - connection.close() - - print("Export done.") + parser = argparse.ArgumentParser(description="End-to-end inference") + + # Add the arguments + parser.add_argument("--project_name", type=str, default=None, help="Name of the Project") + + parser.add_argument("--run_id", type=int, default=None, help="RUN_ID Filter") + + args = parser.parse_args() + project_name = args.project_name + run_id = args.run_id + + if project_name is None: + project_name = input("What is the project name? ") + if project_name is None or project_name == "": + print("project name must not be empty") + return + + project_data_dir = config_path.DATA_DIR + r"/" + project_name + + # Opening JSON file + f = open(project_data_dir + r"/settings.json") + project_settings = json.load(f) + f.close() + + enable_db_export = project_settings["enable_db_export"] + + if not enable_db_export: + print("Database export is not enabled for this project.") + + db_dialect = project_settings["db_dialect"] + db_sql_driver = project_settings["db_sql_driver"] + db_host = project_settings["db_host"] + db_port = project_settings["db_port"] + db_user = project_settings["db_user"] + db_password = project_settings["db_password"] + db_post_command = project_settings["db_post_command"] + + print("Connecting to database " + db_host + " as " + db_user + " using password (hidden) . . . ") + + connection = connect_to_db(db_dialect, db_sql_driver, db_host, db_port, db_user, db_password) + + print("Connected. Inserting new CSV files from output . . . ") + + csv_path = project_data_dir + r"/output/" + (str(run_id) + "_" if run_id is not None else "") + "*.csv" + + for f in glob.glob(csv_path): + print("----> " + f) + insert_csv(connection, f, run_id) + + if db_post_command is not None and db_post_command != "": + print("Executing post command for RUN_ID = " + str(run_id) + " . . .") + run_command(connection, run_id, db_post_command) + + print("Closing database connection . . . 
") + connection.close() + print("Export done.") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/__init__.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/__init__.py index 8d0a76a..118d1b1 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/__init__.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/__init__.py @@ -1,5 +1,4 @@ -from .components import Extractor, PDFTextExtractor, \ - TextCurator +from .components import Extractor, PDFTextExtractor, TextCurator import logging from .config import logging_config, config diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_component.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_component.py index 6bb7cca..e503e5f 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_component.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_component.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class BaseComponent(ABC): def __init__(self, name="Base"): self.name = name diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_curator.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_curator.py index 1dd03a3..a9d376d 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_curator.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/base_curator.py @@ -20,7 +20,7 @@ def create_pos_examples(self, row, *args, **kwargs): pass @abstractmethod - def create_negative_examples(self, row, *args, **kwargs ): + def create_negative_examples(self, row, *args, **kwargs): pass @staticmethod @@ -36,8 +36,8 @@ def clean_text(text): # Substitute unusual quotes at the end of the string with usual quotes text = re.sub("”(?=\])", '"', text) # Substitute th remaining unusual quotes with space - text = re.sub('“|”', '', text) - text = re.sub('\n|\t', " ", text) - text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\xff]', '', text) + text = re.sub("“|”", "", text) + text = re.sub("\n|\t", " ", text) + text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\xff]", "", text) text = re.sub(r"\s{2,}", " ", text) return text diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/curator.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/curator.py index e51d878..9eef74b 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/curator.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/curator.py @@ -6,11 +6,13 @@ logger = logging.getLogger(__name__) NAME_CLASS_MAPPING = {"TextCurator": TextCurator} + class Curator: """A data curator component responsible for creating table and text training data based on annotated data Args: annotation_folder (str): path to the folder containing annotation excel files """ + def __init__(self, curators): self.curators = self.__create_curators(curators) @@ -33,14 +35,14 @@ def __create_curators(self, curators): return list_cura def run(self, input_extraction_folder, annotation_folder, output_folder): - """ Runs curation for each curator. + """Runs curation for each curator. 
Args: input_extraction_folder (A str or PosixPath) annotation_folder (A str or PosixPath) output_folder (A str or PosixPath) """ - annotation_excels = glob.glob('{}/[!~$]*[.xlsx]'.format(annotation_folder)) + annotation_excels = glob.glob("{}/[!~$]*[.xlsx]".format(annotation_folder)) logger.info("Received {} excel files".format(len(annotation_excels))) for curator_obj in self.curators: diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/extractor.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/extractor.py index 74a6dee..18fdd43 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/extractor.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/extractor.py @@ -2,9 +2,8 @@ import logging _logger = logging.getLogger(__name__) -NAME_CLASS_MAPPING = { - "PDFTextExtractor": PDFTextExtractor -} +NAME_CLASS_MAPPING = {"PDFTextExtractor": PDFTextExtractor} + class Extractor: def __init__(self, extractors): @@ -17,7 +16,7 @@ def __init__(self, extractors): self.extractors = self.__create_extractors(extractors) def __create_extractors(self, extractors): - """ Returns a list of extractors objects + """Returns a list of extractors objects Args: extractors (A list of str) @@ -47,7 +46,7 @@ def run(self, input_filepath, output_folder): for ext in self.extractors: _ = ext.run(input_filepath, output_folder) - def run_folder(self, input_folder, output_folder): + def run_folder(self, input_folder, output_folder): """ Extract for all files mentioned in folder. (The logic is based on each child.) diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/pdf_text_extractor.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/pdf_text_extractor.py index 7098696..35161d2 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/pdf_text_extractor.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/pdf_text_extractor.py @@ -21,7 +21,7 @@ class PDFTextExtractor(BaseComponent): - """ This Class is responsible for extracting text data from PDFs and saving + """This Class is responsible for extracting text data from PDFs and saving the result in a json format file. Each name/value pair in the json file refers to page_number and the list of paragraphs in that page. @@ -36,11 +36,7 @@ class PDFTextExtractor(BaseComponent): """ def __init__( - self, - annotation_folder=None, - min_paragraph_length=20, - skip_extracted_files=False, - name='PDFTextExtractor' + self, annotation_folder=None, min_paragraph_length=20, skip_extracted_files=False, name="PDFTextExtractor" ): super().__init__(name) self.min_paragraph_length = min_paragraph_length @@ -48,7 +44,7 @@ def __init__( self.skip_extracted_files = skip_extracted_files def process_page(self, input_text): - """ This function receives a text following: + """This function receives a text following: 1. Divide it into paragraphs, using \n\n 2. Remove table data: To achieve this, if number of alphabet characters of paragraph is less min_paragraph_length, it is considered as table cell and it will be removed. 
@@ -62,8 +58,11 @@ def process_page(self, input_text): paragraphs = input_text.split("\n\n") # Get ride of table data if the number of alphabets in a paragraph is less than `min_paragraph_length` - paragraphs = [BaseCurator.clean_text(p) for p in paragraphs if - sum(c.isalpha() for c in BaseCurator.clean_text(p)) > self.min_paragraph_length] + paragraphs = [ + BaseCurator.clean_text(p) + for p in paragraphs + if sum(c.isalpha() for c in BaseCurator.clean_text(p)) > self.min_paragraph_length + ] return paragraphs def extract_pdf_by_page(self, pdf_file): @@ -80,10 +79,10 @@ def extract_pdf_by_page(self, pdf_file): _logger.warning("{}: Unable to process {}".format(e, pdf_file)) return {} - fp = open(pdf_file, 'rb') + fp = open(pdf_file, "rb") rsrcmgr = PDFResourceManager() retstr = io.BytesIO() - codec = 'utf-8' + codec = "utf-8" laparams = LAParams() device = TextConverter(rsrcmgr, retstr, codec=codec, laparams=laparams) interpreter = PDFPageInterpreter(rsrcmgr, device) @@ -91,7 +90,7 @@ def extract_pdf_by_page(self, pdf_file): pdf_content = {} for page_number, page in enumerate(PDFPage.get_pages(fp, check_extractable=False)): interpreter.process_page(page) - data = retstr.getvalue().decode('utf-8') + data = retstr.getvalue().decode("utf-8") data_paragraphs = self.process_page(data) if len(data_paragraphs) == 0: continue @@ -107,7 +106,7 @@ def run(self, input_filepath, output_folder): Args: input_filepath (str or PosixPath): full path to the pdf file output_folder (str or PosixPath): Folder to save the result of extraction - """ + """ output_file_name = os.path.splitext(os.path.basename(input_filepath))[0] json_filename = output_file_name + ".json" @@ -125,13 +124,13 @@ def run(self, input_filepath, output_folder): return None json_path = os.path.join(output_folder, json_filename) - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(text_dict, f) return text_dict def run_folder(self, input_folder, output_folder): - """ This method will perform pdf extraction for all the pdfs mentioned + """This method will perform pdf extraction for all the pdfs mentioned as source in the annotated excel files and it will be saved the results in a output_folder. @@ -141,18 +140,18 @@ def run_folder(self, input_folder, output_folder): output_folder (str or PosixPath): path to the folder to save the extracted json files. 
""" - files = [os.path.join(input_folder, f) for f in Path(input_folder).rglob('*.pdf') if f.is_file()] + files = [os.path.join(input_folder, f) for f in Path(input_folder).rglob("*.pdf") if f.is_file()] if self.annotation_folder is not None: # Get the names of all excel files - all_annotation_files = glob.glob('{}/[!~$]*[.xlsx]'.format(self.annotation_folder)) + all_annotation_files = glob.glob("{}/[!~$]*[.xlsx]".format(self.annotation_folder)) annotated_pdfs = [] for excel_path in all_annotation_files: df = pd.read_excel(excel_path) # Get the unique values of source_file column - df_unique_pdfs = df['source_file'].drop_duplicates().dropna() + df_unique_pdfs = df["source_file"].drop_duplicates().dropna() annotated_pdfs.extend(df_unique_pdfs) - annotated_pdfs = [file.split(".pdf")[0]+".pdf" for file in annotated_pdfs] + annotated_pdfs = [file.split(".pdf")[0] + ".pdf" for file in annotated_pdfs] found_annotated_pdfs = [] for f in files: @@ -165,4 +164,3 @@ def run_folder(self, input_folder, output_folder): else: for f in files: _ = self.run(f, output_folder) - diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/text_curator.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/text_curator.py index 8c45b93..a849ee4 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/text_curator.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/components/text_curator.py @@ -26,9 +26,9 @@ def __init__( create_neg_samples=False, min_length_neg_sample=50, name="DataTextCurator", - data_type="TEXT" + data_type="TEXT", ): - """ This class is the responsible for creating ESG text dataset + """This class is the responsible for creating ESG text dataset (positive and negative examples) based on the annotations. Args: retrieve_paragraph (bool): Whether or not try to extract the whole @@ -77,36 +77,30 @@ def run(self, extraction_folder, annotation_excels, output_folder): # Map the KPI to KPI questions importlib.reload(kpi_mapping) - df_result['question'] = df_result.astype( - {'kpi_id': 'float'}, errors="ignore" - )['kpi_id'].map(kpi_mapping.KPI_MAPPING) + df_result["question"] = df_result.astype({"kpi_id": "float"}, errors="ignore")["kpi_id"].map( + kpi_mapping.KPI_MAPPING + ) # In the result csv, the following KPIs are not mapped to any questions. # To avoid losing any data, the following # KPIs should be modified manually. 
logger.warning( "The corresponding KPIs can not be mapped \ - to any questions and the mapped question is empty\n{}"\ - .format(df_result[df_result['question'].isna()]['kpi_id'].unique()) + to any questions and the mapped question is empty\n{}".format( + df_result[df_result["question"].isna()]["kpi_id"].unique() + ) ) # Remove the rows could not map KPI to question df_result = df_result[df_result.question.notnull()] # Remove duplicate examples - df_result = df_result.groupby(['question', 'context']).first().reset_index() + df_result = df_result.groupby(["question", "context"]).first().reset_index() - save_path = os.path.join( - output_folder, - "esg_{}_dataset.csv".format(self.data_type) - ) + save_path = os.path.join(output_folder, "esg_{}_dataset.csv".format(self.data_type)) logger.info("Curated {} examples".format(len(df_result))) logger.info("Saving the dataset in {}".format(save_path)) df_result.to_csv(save_path) - def process_single_annotation_file( - self, - annotation_filepath, - sheet_name='data_ex_in_xls' - ): + def process_single_annotation_file(self, annotation_filepath, sheet_name="data_ex_in_xls"): """Create examples for a single excel file Args: annotation_filepath (str): Path to the annotated excel file @@ -129,7 +123,7 @@ def process_single_annotation_file( examples = [] for i, row in df.iterrows(): - row['Index'] = i + row["Index"] = i positive_examples = self.create_pos_examples(row.copy()) examples.extend(positive_examples) @@ -181,7 +175,7 @@ def create_pos_examples(self, row): return pos_rows def create_negative_examples(self, row): - """ Create negative examples for each row, to achieve this: + """Create negative examples for each row, to achieve this: - If the source pdf is presented and extracted, we choose a random page, except source page and choose a random paragraph within that. - If the extracted pdf is not available, we look for the a random @@ -199,9 +193,7 @@ def create_negative_examples(self, row): # if the corresponding pdf to a row is not presented, a random pdf is # picked to create negative example if len(pdf_content) == 0: - random_json_path = random.choice( - list(Path(self.extraction_folder).rglob('*.json')) - ) + random_json_path = random.choice(list(Path(self.extraction_folder).rglob("*.json"))) with open(os.path.join(self.extraction_folder, random_json_path)) as f: pdf_content = [json.load(f)] @@ -262,19 +254,11 @@ def process_relevant_sentences(self, row): # quotes in the sentence. # 2 is added because to be compatible, so it would be compatible # with what you see in MS excel. - logger.warning( - "Could not process row number {} in {}".format( - (row["Index"]+2), row["annotator"] - ) - ) + logger.warning("Could not process row number {} in {}".format((row["Index"] + 2), row["annotator"])) return None else: # To support cases where relevant paragraph are given as strings. - logger.info( - "Not in a list format row number {} , {}".format( - (row["Index"]+2), row["annotator"] - ) - ) + logger.info("Not in a list format row number {} , {}".format((row["Index"] + 2), row["annotator"])) return [sentence_revised] def get_full_paragraph(self, row, relevant_sentences): @@ -291,34 +275,24 @@ def get_full_paragraph(self, row, relevant_sentences): relevant_sentences (list of str): List of processed relevant_paragraphs. Returns: matches_list (list of str): list of full paragraphs. 
- """ + """ pdf_content = self.load_pdf_content(row) try: source_page = ast.literal_eval(row["source_page"]) except SyntaxError: - logger.info( - "Can not process source page in row {} of {} ".format( - (row["Index"]+2), row["annotator"] - ) - ) + logger.info("Can not process source page in row {} of {} ".format((row["Index"] + 2), row["annotator"])) return [] # pdfminer starts the page counter as 0 while for pdf viewers the first # page is numbered as 1. selected_pages = [p - 1 for p in source_page] - paragraphs = [ - pdf.get(str(p), []) for p in selected_pages for pdf in pdf_content - ] + paragraphs = [pdf.get(str(p), []) for p in selected_pages for pdf in pdf_content] paragraphs_flat = [item for sublist in paragraphs for item in sublist] matches_list = [] for pattern in relevant_sentences: - special_regex_char = [ - "(", ")", "^", "+", "*", "$", "|", "\\", "?", "[", "]", "{", "}" - ] + special_regex_char = ["(", ")", "^", "+", "*", "$", "|", "\\", "?", "[", "]", "{", "}"] # If the sentences contain the especial character we should put \ # before them for literal match. - pattern = ''.join( - ["\\" + c if c in special_regex_char else c for c in pattern] - ) + pattern = "".join(["\\" + c if c in special_regex_char else c for c in pattern]) for single_par in paragraphs_flat: single_par_clean = self.clean_text(single_par) match = re.search(pattern, single_par_clean, re.I) @@ -329,7 +303,7 @@ def get_full_paragraph(self, row, relevant_sentences): return matches_list def load_pdf_content(self, row): - """ Load the content of a pdf file + """Load the content of a pdf file If the extraction step is passed, the json file should be in the extraction_folder. Args: @@ -339,13 +313,9 @@ def load_pdf_content(self, row): after extraction. """ # The naming format is used in extraction phase. 
- extracted_filename = os.path.splitext(str(row["source_file"]))[0] \ - + "-" \ - + str(row['company']) + extracted_filename = os.path.splitext(str(row["source_file"]))[0] + "-" + str(row["company"]) # Get all the files in extraction folder that has the desired name - extracted_paths = [ - path for path in os.listdir(self.extraction_folder) if extracted_filename in path - ] + extracted_paths = [path for path in os.listdir(self.extraction_folder) if extracted_filename in path] pdf_contents = [] for path in extracted_paths: diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/__init__.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/__init__.py index 139597f..8b13789 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/__init__.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/__init__.py @@ -1,2 +1 @@ - diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/config.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/config.py index f49c5c4..87ba5cb 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/config.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/config.py @@ -2,40 +2,43 @@ import pathlib # General config -STAGE = "extract" +STAGE = "extract" SEED = 42 CONFIG_FOLDER = pathlib.Path(__file__).resolve().parent ROOT = CONFIG_FOLDER.parent.parent.parent.parent DATA_FOLDER = ROOT / "data" -#Extraction inputs -PDFTextExtractor_kwargs = {'min_paragraph_length': 30, - #Set to ANNOTATION_FOLDER if you want to extract just pdfs mentioned in the annotations - #Set to None to extract all pdfs in pdf folder (for production stage) - 'annotation_folder': None, - 'skip_extracted_files': False - } +# Extraction inputs +PDFTextExtractor_kwargs = { + "min_paragraph_length": 30, + # Set to ANNOTATION_FOLDER if you want to extract just pdfs mentioned in the annotations + # Set to None to extract all pdfs in pdf folder (for production stage) + "annotation_folder": None, + "skip_extracted_files": False, +} -#Curation inputs +# Curation inputs TextCurator_kwargs = { - 'retrieve_paragraph': False, - 'neg_pos_ratio': 1, - 'columns_to_read': [ - 'company', 'source_file', 'source_page', 'kpi_id', - 'year', 'answer', 'data_type', 'relevant_paragraphs' - ], - 'company_to_exclude': [], - 'create_neg_samples': True, - 'min_length_neg_sample': 50, - 'seed': SEED + "retrieve_paragraph": False, + "neg_pos_ratio": 1, + "columns_to_read": [ + "company", + "source_file", + "source_page", + "kpi_id", + "year", + "answer", + "data_type", + "relevant_paragraphs", + ], + "company_to_exclude": [], + "create_neg_samples": True, + "min_length_neg_sample": 50, + "seed": SEED, } # Components -EXTRACTORS = [ - ("PDFTextExtractor", PDFTextExtractor_kwargs) -] +EXTRACTORS = [("PDFTextExtractor", PDFTextExtractor_kwargs)] -CURATORS = [ - ("TextCurator", TextCurator_kwargs) -] +CURATORS = [("TextCurator", TextCurator_kwargs)] diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/logging_config.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/logging_config.py index 6b0f25a..2c820e6 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/logging_config.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/config/logging_config.py @@ -1,10 +1,8 @@ import logging import sys -FORMATTER = logging.Formatter( - "%(asctime)s — %(name)s — %(levelname)s —" - "%(funcName)s:%(lineno)d — %(message)s" -) 
+FORMATTER = logging.Formatter("%(asctime)s — %(name)s — %(levelname)s —" "%(funcName)s:%(lineno)d — %(message)s") + def get_console_handler(): console_handler = logging.StreamHandler(sys.stdout) diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/extraction_server.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/extraction_server.py index d29c2d0..e607c80 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/extraction_server.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/extraction_server.py @@ -23,7 +23,7 @@ def create_directory(directory_name): try: os.unlink(file_path) except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) + print("Failed to delete %s. Reason: %s" % (file_path, e)) @app.route("/liveness") @@ -31,58 +31,58 @@ def liveness(): return Response(response={}, status=200) -@app.route('/extract/') +@app.route("/extract/") def run_extraction(): - args = json.loads(request.args['payload']) + args = json.loads(request.args["payload"]) project_name = args["project_name"] - - extraction_settings = args['extraction'] - + + extraction_settings = args["extraction"] + BASE_DATA_PROJECT_FOLDER = config.DATA_FOLDER / project_name - config.PDF_FOLDER = BASE_DATA_PROJECT_FOLDER / 'interim' / 'pdfs' - BASE_INTERIM_FOLDER = BASE_DATA_PROJECT_FOLDER / 'interim' / 'ml' - config.EXTRACTION_FOLDER = BASE_INTERIM_FOLDER / 'extraction' - config.ANNOTATION_FOLDER = BASE_INTERIM_FOLDER / 'annotations' - config.STAGE = 'extract' - + config.PDF_FOLDER = BASE_DATA_PROJECT_FOLDER / "interim" / "pdfs" + BASE_INTERIM_FOLDER = BASE_DATA_PROJECT_FOLDER / "interim" / "ml" + config.EXTRACTION_FOLDER = BASE_INTERIM_FOLDER / "extraction" + config.ANNOTATION_FOLDER = BASE_INTERIM_FOLDER / "annotations" + config.STAGE = "extract" + create_directory(config.EXTRACTION_FOLDER) create_directory(config.ANNOTATION_FOLDER) create_directory(config.PDF_FOLDER) - + s3_usage = args["s3_usage"] if s3_usage: s3_settings = args["s3_settings"] - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), + ) + if extraction_settings["use_extractions"]: + 
s3c_main.download_files_in_prefix_to_dir( + project_prefix + "/output/TEXT_EXTRACTION", config.EXTRACTION_FOLDER + ) + s3c_interim.download_files_in_prefix_to_dir( + project_prefix + "/interim/ml/annotations", config.ANNOTATION_FOLDER ) - if extraction_settings['use_extractions']: - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/output/TEXT_EXTRACTION', - config.EXTRACTION_FOLDER) - s3c_interim.download_files_in_prefix_to_dir(project_prefix + '/interim/ml/annotations', - config.ANNOTATION_FOLDER) - if args['mode'] == 'train': - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/pdfs/training', - config.PDF_FOLDER) + if args["mode"] == "train": + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/pdfs/training", config.PDF_FOLDER) else: - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/pdfs/inference', - config.PDF_FOLDER) - + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/pdfs/inference", config.PDF_FOLDER) + pdfs = glob.glob(os.path.join(config.PDF_FOLDER, "*.pdf")) if len(pdfs) == 0: msg = "No pdf files found in the pdf directory ({})".format(config.PDF_FOLDER) return Response(msg, status=500) - + annotation_files = glob.glob(os.path.join(config.ANNOTATION_FOLDER, "*.csv")) if len(annotation_files) == 0: msg = "No annotations.csv file found on S3." @@ -90,11 +90,11 @@ def run_extraction(): elif len(annotation_files) > 2: msg = "Multiple annotations.csv files found on S3." return Response(msg, status=500) - + config.SEED = extraction_settings["seed"] - config.PDFTextExtractor_kwargs['min_paragraph_length'] = extraction_settings["min_paragraph_length"] - config.PDFTextExtractor_kwargs['annotation_folder'] = extraction_settings["annotation_folder"] - config.PDFTextExtractor_kwargs['skip_extracted_files'] = extraction_settings["skip_extracted_files"] + config.PDFTextExtractor_kwargs["min_paragraph_length"] = extraction_settings["min_paragraph_length"] + config.PDFTextExtractor_kwargs["annotation_folder"] = extraction_settings["annotation_folder"] + config.PDFTextExtractor_kwargs["skip_extracted_files"] = extraction_settings["skip_extracted_files"] ext = Extractor(config.EXTRACTORS) @@ -108,8 +108,7 @@ def run_extraction(): extracted_files = os.listdir(config.EXTRACTION_FOLDER) if len(extracted_files) == 0: - msg = "Extraction Failed. No file was found in the extraction directory ({})"\ - .format(config.EXTRACTION_FOLDER) + msg = "Extraction Failed. No file was found in the extraction directory ({})".format(config.EXTRACTION_FOLDER) return Response(msg, status=500) failed_to_extract = "" @@ -122,10 +121,9 @@ def run_extraction(): msg = "Extraction finished successfully." 
if len(failed_to_extract) > 0: msg += "The following pdf files, however, did not get extracted:\n" + failed_to_extract - + if s3_usage: - s3c_interim.upload_files_in_dir_to_prefix(config.EXTRACTION_FOLDER, - project_prefix + '/interim/ml/extraction') + s3c_interim.upload_files_in_dir_to_prefix(config.EXTRACTION_FOLDER, project_prefix + "/interim/ml/extraction") # clear folder create_directory(config.EXTRACTION_FOLDER) create_directory(config.ANNOTATION_FOLDER) @@ -135,53 +133,52 @@ def run_extraction(): return Response(msg, status=200) -@app.route('/curate/') +@app.route("/curate/") def run_curation(): - args = json.loads(request.args['payload']) + args = json.loads(request.args["payload"]) project_name = args["project_name"] curation_settings = args["curation"] BASE_DATA_PROJECT_FOLDER = config.DATA_FOLDER / project_name - BASE_INTERIM_FOLDER = BASE_DATA_PROJECT_FOLDER / 'interim' / 'ml' - config.EXTRACTION_FOLDER = BASE_INTERIM_FOLDER / 'extraction' - config.CURATION_FOLDER = BASE_INTERIM_FOLDER / 'curation' - config.ANNOTATION_FOLDER = BASE_INTERIM_FOLDER / 'annotations' - config.KPI_FOLDER = BASE_DATA_PROJECT_FOLDER / 'interim' / 'kpi_mapping' + BASE_INTERIM_FOLDER = BASE_DATA_PROJECT_FOLDER / "interim" / "ml" + config.EXTRACTION_FOLDER = BASE_INTERIM_FOLDER / "extraction" + config.CURATION_FOLDER = BASE_INTERIM_FOLDER / "curation" + config.ANNOTATION_FOLDER = BASE_INTERIM_FOLDER / "annotations" + config.KPI_FOLDER = BASE_DATA_PROJECT_FOLDER / "interim" / "kpi_mapping" create_directory(config.EXTRACTION_FOLDER) create_directory(config.CURATION_FOLDER) create_directory(config.ANNOTATION_FOLDER) - + s3_usage = args["s3_usage"] if s3_usage: s3_settings = args["s3_settings"] - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/kpi_mapping', config.KPI_FOLDER) - s3c_interim.download_files_in_prefix_to_dir(project_prefix + '/interim/ml/extraction', config.EXTRACTION_FOLDER) - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/annotations', - config.ANNOTATION_FOLDER) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/kpi_mapping", 
config.KPI_FOLDER) + s3c_interim.download_files_in_prefix_to_dir(project_prefix + "/interim/ml/extraction", config.EXTRACTION_FOLDER) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/annotations", config.ANNOTATION_FOLDER) shutil.copyfile(os.path.join(config.KPI_FOLDER, "kpi_mapping.csv"), "/app/code/kpi_mapping.csv") - config.STAGE = 'curate' - config.TextCurator_kwargs['retrieve_paragraph'] = curation_settings['retrieve_paragraph'] - config.TextCurator_kwargs['neg_pos_ratio'] = curation_settings['neg_pos_ratio'] - config.TextCurator_kwargs['columns_to_read'] = curation_settings['columns_to_read'] - config.TextCurator_kwargs['company_to_exclude'] = curation_settings['company_to_exclude'] - config.TextCurator_kwargs['min_length_neg_sample'] = curation_settings['min_length_neg_sample'] - config.SEED = curation_settings['seed'] + config.STAGE = "curate" + config.TextCurator_kwargs["retrieve_paragraph"] = curation_settings["retrieve_paragraph"] + config.TextCurator_kwargs["neg_pos_ratio"] = curation_settings["neg_pos_ratio"] + config.TextCurator_kwargs["columns_to_read"] = curation_settings["columns_to_read"] + config.TextCurator_kwargs["company_to_exclude"] = curation_settings["company_to_exclude"] + config.TextCurator_kwargs["min_length_neg_sample"] = curation_settings["min_length_neg_sample"] + config.SEED = curation_settings["seed"] try: if len(config.CURATORS) != 0: @@ -190,26 +187,22 @@ def run_curation(): except Exception as e: msg = "Error during curation\nException:" + str(repr(e)) + traceback.format_exc() return Response(msg, status=500) - + if s3_usage: - s3c_interim.upload_files_in_dir_to_prefix(config.CURATION_FOLDER, - project_prefix + '/interim/ml/curation') + s3c_interim.upload_files_in_dir_to_prefix(config.CURATION_FOLDER, project_prefix + "/interim/ml/curation") # clear folder create_directory(config.KPI_FOLDER) create_directory(config.EXTRACTION_FOLDER) create_directory(config.ANNOTATION_FOLDER) create_directory(config.CURATION_FOLDER) - + return Response("Curation OK", status=200) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='inference server') + parser = argparse.ArgumentParser(description="inference server") # Add the arguments - parser.add_argument('--port', - type=int, - default=4000, - help='port to use for the extract server') + parser.add_argument("--port", type=int, default=4000, help="port to use for the extract server") args_server = parser.parse_args() port = args_server.port app.run(host="0.0.0.0", port=port) diff --git a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/utils/kpi_mapping.py b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/utils/kpi_mapping.py index ba2f09f..8c9ebb0 100644 --- a/data_extractor/code/esg_data_pipeline/esg_data_pipeline/utils/kpi_mapping.py +++ b/data_extractor/code/esg_data_pipeline/esg_data_pipeline/utils/kpi_mapping.py @@ -4,16 +4,14 @@ try: df = pd.read_csv("/app/code/kpi_mapping.csv", header=0) - _KPI_MAPPING = {str(i[0]): i[1] for i in df[['kpi_id', 'question']].values} + _KPI_MAPPING = {str(i[0]): i[1] for i in df[["kpi_id", "question"]].values} KPI_MAPPING = {(float(key)): value for key, value in _KPI_MAPPING.items()} # Which questions should be added the year - ADD_YEAR = df[df['add_year']].kpi_id.tolist() + ADD_YEAR = df[df["add_year"]].kpi_id.tolist() # Category where the answer to the question should originate from - KPI_CATEGORY = { - i[0]: [j.strip() for j in i[1].split(', ')] for i in df[['kpi_id', 'kpi_category']].values - } + KPI_CATEGORY = {i[0]: 
[j.strip() for j in i[1].split(", ")] for i in df[["kpi_id", "kpi_category"]].values} except Exception as e: KPI_MAPPING = {} ADD_YEAR = [] diff --git a/data_extractor/code/esg_data_pipeline/setup.py b/data_extractor/code/esg_data_pipeline/setup.py index 4447b5f..1e4d36f 100644 --- a/data_extractor/code/esg_data_pipeline/setup.py +++ b/data_extractor/code/esg_data_pipeline/setup.py @@ -4,34 +4,36 @@ from setuptools import find_packages, setup -NAME = 'esg_data_pipeline' -DESCRIPTION = 'Extract, clean and save data from various sources for ESG project' -AUTHOR = '1QBit NLP' -REQUIRES_PYTHON = '>=3.6.0' +NAME = "esg_data_pipeline" +DESCRIPTION = "Extract, clean and save data from various sources for ESG project" +AUTHOR = "1QBit NLP" +REQUIRES_PYTHON = ">=3.6.0" -def list_reqs(fname='requirements.txt'): + +def list_reqs(fname="requirements.txt"): with open(fname) as fd: return fd.read().splitlines() + here = os.path.abspath(os.path.dirname(__file__)) # Load the package's __version__.py module as a dictionary. ROOT_DIR = Path(__file__).resolve().parent PACKAGE_DIR = ROOT_DIR / NAME about = {} -with open(PACKAGE_DIR / 'VERSION') as f: +with open(PACKAGE_DIR / "VERSION") as f: _version = f.read().strip() - about['__version__'] = _version + about["__version__"] = _version setup( name=NAME, - version=about['__version__'], + version=about["__version__"], description=DESCRIPTION, author=AUTHOR, python_requires=REQUIRES_PYTHON, - packages=find_packages(exclude=('tests', 'notebooks')), - package_data={'esg_data_pipeline': ['VERSION']}, + packages=find_packages(exclude=("tests", "notebooks")), + package_data={"esg_data_pipeline": ["VERSION"]}, install_requires=list_reqs(), extras_require={}, - include_package_data=True + include_package_data=True, ) diff --git a/data_extractor/code/esg_data_pipeline/test/app.py b/data_extractor/code/esg_data_pipeline/test/app.py index 1f84112..2a7be18 100644 --- a/data_extractor/code/esg_data_pipeline/test/app.py +++ b/data_extractor/code/esg_data_pipeline/test/app.py @@ -6,15 +6,16 @@ logging.basicConfig(level=logging.DEBUG) app = Flask(__name__) -@app.route('/') + +@app.route("/") def hello_world(): preds = np.exp(INPUT_ARRAY) - app.logger.info(" Inputs: "+ str(INPUT_ARRAY)) - app.logger.info(" Prediction: "+ str(preds)) + app.logger.info(" Inputs: " + str(INPUT_ARRAY)) + app.logger.info(" Prediction: " + str(preds)) return str(preds) -@app.route('/predict', methods=['GET']) +@app.route("/predict", methods=["GET"]) def predict(): """Return A Prediction.""" app.logger.info(str(request.args)) @@ -27,8 +28,7 @@ def predict(): app.logger.info(prediction) response_data = prediction return {"prediction": str(response_data)} - -if __name__ == '__main__': - app.run(host='0.0.0.0', port=6666, debug=True) +if __name__ == "__main__": + app.run(host="0.0.0.0", port=6666, debug=True) diff --git a/data_extractor/code/infer_on_pdf.py b/data_extractor/code/infer_on_pdf.py index 089a7a8..b10a64f 100644 --- a/data_extractor/code/infer_on_pdf.py +++ b/data_extractor/code/infer_on_pdf.py @@ -11,7 +11,7 @@ from s3_communication import S3Communication import pandas as pd -path_file_running = config_path.NLP_DIR+r'/data/running' +path_file_running = config_path.NLP_DIR + r"/data/running" project_settings = None source_pdf = None @@ -30,7 +30,7 @@ def set_running(): - with open(path_file_running, 'w'): + with open(path_file_running, "w"): pass @@ -53,7 +53,7 @@ def create_directory(directory_name): if os.path.isfile(file_path): os.unlink(file_path) except Exception as e: - print('Failed to 
delete %s. Reason: %s' % (file_path, e)) + print("Failed to delete %s. Reason: %s" % (file_path, e)) def link_files(source_dir, destination_dir): @@ -67,7 +67,7 @@ def link_files(source_dir, destination_dir): def copy_files_without_overwrite(src_path, dest_path): - for filename in os.listdir(src_path): + for filename in os.listdir(src_path): # construct the src path and file name src_path_file_name = os.path.join(src_path, filename) # construct the dest path and file name @@ -96,7 +96,7 @@ def link_extracted_files(src_ext, src_pdf, dest_ext): def convert_xls_to_csv(project_name, s3_usage, s3_settings): """ This function transforms the annotations.xlsx file into annotations.csv. - + :param project_name: str, representing the project we currently work on :param s3_usage: boolean, if we use s3 as we then have to upload the new csv file to s3 :param s3_settings: dictionary, containing information in case of s3 usage @@ -105,39 +105,38 @@ def convert_xls_to_csv(project_name, s3_usage, s3_settings): source_dir = source_annotation dest_dir = destination_annotation if s3_usage: - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/annotations', - source_dir) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/annotations", source_dir) first = True for filename in os.listdir(source_dir): - if filename[-5:] == '.xlsx': + if filename[-5:] == ".xlsx": if not first: - raise ValueError('More than one excel sheet found') - print('Converting ' + filename + ' to csv-format') + raise ValueError("More than one excel sheet found") + print("Converting " + filename + " to csv-format") # only reads first sheet in excel file - read_file = pd.read_excel(source_dir + r'/' + filename, engine='openpyxl') - read_file.to_csv(dest_dir + r'/aggregated_annotation.csv', index=None, header=True) + read_file = pd.read_excel(source_dir + r"/" + filename, engine="openpyxl") + read_file.to_csv(dest_dir + r"/aggregated_annotation.csv", index=None, header=True) if s3_usage: - s3c_interim.upload_files_in_dir_to_prefix(dest_dir, project_prefix + '/interim/ml/annotations') - first = False + 
s3c_interim.upload_files_in_dir_to_prefix(dest_dir, project_prefix + "/interim/ml/annotations") + first = False if first: - raise ValueError('No annotation excel sheet found') + raise ValueError("No annotation excel sheet found") -def run_router_ml(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip='0.0.0.0'): +def run_router_ml(ext_port, infer_port, project_name, ext_ip="0.0.0.0", infer_ip="0.0.0.0"): """ Router function It fist sends a command to the extraction server to begin extraction. @@ -149,9 +148,9 @@ def run_router_ml(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip :param infer_ip: int: The ip that the inference server is listening on :return: A boolean, indicating success """ - - convert_xls_to_csv(project_name, project_settings['s3_usage'], project_settings['s3_settings']) - + + convert_xls_to_csv(project_name, project_settings["s3_usage"], project_settings["s3_settings"]) + # Check if the extraction server is live ext_live = requests.get(f"http://{ext_ip}:{ext_port}/liveness") if ext_live.status_code == 200: @@ -159,10 +158,10 @@ def run_router_ml(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip else: print("Extraction server is not responding.") return False - - payload = {'project_name': project_name, 'mode': 'infer'} + + payload = {"project_name": project_name, "mode": "infer"} payload.update(project_settings) - payload = {'payload': json.dumps(payload)} + payload = {"payload": json.dumps(payload)} # Sending an execution request to the extraction server for extraction ext_resp = requests.get(f"http://{ext_ip}:{ext_port}/extract", params=payload) @@ -192,60 +191,99 @@ def run_router_ml(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip return True -def run_router_rb(raw_pdf_folder, working_folder, output_folder, project_name, verbosity, use_docker, port, ip, - s3_usage, s3_settings): +def run_router_rb( + raw_pdf_folder, working_folder, output_folder, project_name, verbosity, use_docker, port, ip, s3_usage, s3_settings +): if use_docker: - payload = {'project_name': project_name, 'verbosity': str(verbosity)} + payload = {"project_name": project_name, "verbosity": str(verbosity)} if s3_usage: - payload.update({'s3_usage': s3_usage}) - payload.update({'s3_settings': s3_settings}) - payload = {'payload': json.dumps(payload)} + payload.update({"s3_usage": s3_usage}) + payload.update({"s3_settings": s3_settings}) + payload = {"payload": json.dumps(payload)} rb_response = requests.get(f"http://{ip}:{port}/run", params=payload) print(rb_response.text) if rb_response.status_code != 200: return False else: - cmd = config_path.PYTHON_EXECUTABLE + ' rule_based_pipeline/rule_based_pipeline/main.py' + \ - ' --raw_pdf_folder "' + raw_pdf_folder + '"' + \ - ' --working_folder "' + working_folder + '"' + \ - ' --output_folder "' + output_folder + '"' + \ - ' --verbosity ' + str(verbosity) + cmd = ( + config_path.PYTHON_EXECUTABLE + + " rule_based_pipeline/rule_based_pipeline/main.py" + + ' --raw_pdf_folder "' + + raw_pdf_folder + + '"' + + ' --working_folder "' + + working_folder + + '"' + + ' --output_folder "' + + output_folder + + '"' + + " --verbosity " + + str(verbosity) + ) print("Running command: " + cmd) os.system(cmd) - return True + return True -def set_xy_ml(project_name, raw_pdf_folder, working_folder, pdf_name, csv_name, output_folder, verbosity, use_docker, - port, ip, s3_usage, s3_settings): +def set_xy_ml( + project_name, + raw_pdf_folder, + working_folder, + pdf_name, + csv_name, + output_folder, + verbosity, + 
use_docker, + port, + ip, + s3_usage, + s3_settings, +): if use_docker: - payload = {'project_name': project_name, - 'pdf_name': pdf_name, - 'csv_name': csv_name, - 'verbosity': str(verbosity)} + payload = { + "project_name": project_name, + "pdf_name": pdf_name, + "csv_name": csv_name, + "verbosity": str(verbosity), + } if s3_usage: - payload.update({'s3_usage': s3_usage}) - payload.update({'s3_settings': s3_settings}) - payload = {'payload': json.dumps(payload)} + payload.update({"s3_usage": s3_usage}) + payload.update({"s3_settings": s3_settings}) + payload = {"payload": json.dumps(payload)} rb_xy_extract_response = requests.get(f"http://{ip}:{port}/run_xy_ml", params=payload) print(rb_xy_extract_response.text) if rb_xy_extract_response.status_code != 200: return False else: - cmd = config_path.PYTHON_EXECUTABLE + ' rule_based_pipeline/rule_based_pipeline/main_find_xy.py' + \ - ' --raw_pdf_folder "' + raw_pdf_folder + '"' + \ - ' --working_folder "' + working_folder + '"' + \ - ' --pdf_name "' + pdf_name + '"' + \ - ' --csv_name "' + csv_name + '"' + \ - ' --output_folder "' + output_folder + '"' + \ - ' --verbosity ' + str(verbosity) + cmd = ( + config_path.PYTHON_EXECUTABLE + + " rule_based_pipeline/rule_based_pipeline/main_find_xy.py" + + ' --raw_pdf_folder "' + + raw_pdf_folder + + '"' + + ' --working_folder "' + + working_folder + + '"' + + ' --pdf_name "' + + pdf_name + + '"' + + ' --csv_name "' + + csv_name + + '"' + + ' --output_folder "' + + output_folder + + '"' + + " --verbosity " + + str(verbosity) + ) print("Running command: " + cmd) - - return True + + return True def get_current_run_id(): return int(time.time()) - + def try_int(val, default): try: @@ -255,68 +293,151 @@ def try_int(val, default): return default -def join_output(project_name, pdf_folder, rb_output_folder, ml_output_folder, output_folder, use_docker, work_dir_rb, - verbosity, port, ip, run_id, s3_usage, s3_settings): +def join_output( + project_name, + pdf_folder, + rb_output_folder, + ml_output_folder, + output_folder, + use_docker, + work_dir_rb, + verbosity, + port, + ip, + run_id, + s3_usage, + s3_settings, +): print("Joining output . . . 
") # ML header: ,pdf_name,kpi,kpi_id,answer,page,paragraph,source,score,no_ans_score,no_answer_score_plus_boost # RB header: "KPI_ID","KPI_NAME","SRC_FILE","PAGE_NUM","ITEM_IDS","POS_X","POS_Y","RAW_TXT", # "YEAR","VALUE","SCORE","UNIT","MATCH_TYPE" - output_header = ["METHOD", "PDF_NAME", "KPI_ID", "KPI_NAME", "KPI_DESC", - "ANSWER_RAW", "ANSWER", "PAGE", "PARAGRAPH", "PARAGRAPH_RELEVANCE_SCORE", "POS_X", "POS_Y", - "KPI_SOURCE", "SCORE", "NO_ANS_SCORE", "SCORE_PLUS_BOOST", "KPI_YEAR", "UNIT_RAW", "UNIT"] + output_header = [ + "METHOD", + "PDF_NAME", + "KPI_ID", + "KPI_NAME", + "KPI_DESC", + "ANSWER_RAW", + "ANSWER", + "PAGE", + "PARAGRAPH", + "PARAGRAPH_RELEVANCE_SCORE", + "POS_X", + "POS_Y", + "KPI_SOURCE", + "SCORE", + "NO_ANS_SCORE", + "SCORE_PLUS_BOOST", + "KPI_YEAR", + "UNIT_RAW", + "UNIT", + ] for filename in os.listdir(pdf_folder): print(filename) - with open(output_folder + r'/' + str(run_id) + r'_' + filename + r'.csv', 'w', - encoding='UTF8', newline='') as f_out: + with open( + output_folder + r"/" + str(run_id) + r"_" + filename + r".csv", "w", encoding="UTF8", newline="" + ) as f_out: writer = csv.writer(f_out) writer.writerow(output_header) - - rb_filename = rb_output_folder + r'/' + filename + '.csv' - ml_filename = ml_output_folder + r'/' + filename[:len(filename)-4] + '_predictions_kpi.csv' + + rb_filename = rb_output_folder + r"/" + filename + ".csv" + ml_filename = ml_output_folder + r"/" + filename[: len(filename) - 4] + "_predictions_kpi.csv" # Read RB: try: - with open(rb_filename, 'r') as f: + with open(rb_filename, "r") as f: csv_file = csv.DictReader(f) for row in csv_file: d = dict(row) # TODO: Use UNIT_RAW/UNIT, once implemented in RB solution - data = ["RB", d["SRC_FILE"], d["KPI_ID"], d["KPI_NAME"], "", d["RAW_TXT"], - d["VALUE"], d["PAGE_NUM"], "", "", d["POS_X"], d["POS_Y"], - d["MATCH_TYPE"], d["SCORE"], "", "", d["YEAR"], d["UNIT"], d["UNIT"]] + data = [ + "RB", + d["SRC_FILE"], + d["KPI_ID"], + d["KPI_NAME"], + "", + d["RAW_TXT"], + d["VALUE"], + d["PAGE_NUM"], + "", + "", + d["POS_X"], + d["POS_Y"], + d["MATCH_TYPE"], + d["SCORE"], + "", + "", + d["YEAR"], + d["UNIT"], + d["UNIT"], + ] writer.writerow(data) except IOError: - pass # RB not executed + pass # RB not executed # Read ML: try: - with open(ml_filename, 'r') as f: + with open(ml_filename, "r") as f: csv_file = csv.DictReader(f) for row in csv_file: d = dict(row) - data = ["ML", d["pdf_name"] + r".pdf", "", "", d["kpi"], d["answer"], d["answer"], - str(try_int(d["page"], -2)+1), d["paragraph"], d["paragraph_relevance_score"], "", "", - d["source"], d["score"], d["no_ans_score"], d["no_answer_score_plus_boost"], "", "", ""] + data = [ + "ML", + d["pdf_name"] + r".pdf", + "", + "", + d["kpi"], + d["answer"], + d["answer"], + str(try_int(d["page"], -2) + 1), + d["paragraph"], + d["paragraph_relevance_score"], + "", + "", + d["source"], + d["score"], + d["no_ans_score"], + d["no_answer_score_plus_boost"], + "", + "", + "", + ] writer.writerow(data) except IOError: - pass # ML not executed - csv_name = str(run_id) + r'_' + filename + r'.csv' + pass # ML not executed + csv_name = str(run_id) + r"_" + filename + r".csv" if csv_name in os.listdir(output_folder): if s3_usage: - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - 
aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), + ) + s3c_main.upload_file_to_s3( + filepath=output_folder + r"/" + csv_name, + s3_prefix=project_prefix + "/output/KPI_EXTRACTION/joined_ml_rb", + s3_key=csv_name, ) - s3c_main.upload_file_to_s3(filepath=output_folder + r'/' + csv_name, - s3_prefix=project_prefix + '/output/KPI_EXTRACTION/joined_ml_rb', - s3_key=csv_name) - set_xy_ml(project_name=project_name, raw_pdf_folder=pdf_folder, working_folder=work_dir_rb, - pdf_name=filename, csv_name=csv_name, output_folder=output_folder, verbosity=verbosity, - use_docker=use_docker, port=port, ip=ip, s3_usage=s3_usage, s3_settings=s3_settings) + set_xy_ml( + project_name=project_name, + raw_pdf_folder=pdf_folder, + working_folder=work_dir_rb, + pdf_name=filename, + csv_name=csv_name, + output_folder=output_folder, + verbosity=verbosity, + use_docker=use_docker, + port=port, + ip=ip, + s3_usage=s3_usage, + s3_settings=s3_settings, + ) else: - print(f'File {csv_name} not in the output and hence we are not able to detect x, ' - f'y coordinates for the ML solution output.') + print( + f"File {csv_name} not in the output and hence we are not able to detect x, " + f"y coordinates for the ML solution output." + ) if s3_usage: create_directory(pdf_folder) create_directory(rb_output_folder) @@ -325,157 +446,156 @@ def join_output(project_name, pdf_folder, rb_output_folder, ml_output_folder, ou def run_db_export(project_name, settings, run_id): - cmd = config_path.PYTHON_EXECUTABLE + ' dataload/db_export.py' + \ - ' --project_name "' + project_name + '"' + \ - ' --run_id "' + str(run_id) + '"' + cmd = ( + config_path.PYTHON_EXECUTABLE + + " dataload/db_export.py" + + ' --project_name "' + + project_name + + '"' + + ' --run_id "' + + str(run_id) + + '"' + ) print("Running command: " + cmd) os.system(cmd) - return True + return True def main(): global project_settings global source_annotation global destination_annotation - + if check_running(): print("Another training or inference process is currently running.") return - - parser = argparse.ArgumentParser(description='End-to-end inference') - + + parser = argparse.ArgumentParser(description="End-to-end inference") + # Add the arguments - parser.add_argument('--project_name', - type=str, - default=None, - help='Name of the Project') - - parser.add_argument('--mode', - type=str, - default='both', - help='Inference Mode (RB, ML, both, or none - for just doing postprocessing)') - - parser.add_argument('--s3_usage', - type=str, - default=None, - help='Do you want to use S3? Type either Y or N.') - + parser.add_argument("--project_name", type=str, default=None, help="Name of the Project") + + parser.add_argument( + "--mode", + type=str, + default="both", + help="Inference Mode (RB, ML, both, or none - for just doing postprocessing)", + ) + + parser.add_argument("--s3_usage", type=str, default=None, help="Do you want to use S3? Type either Y or N.") + args = parser.parse_args() project_name = args.project_name mode = args.mode - - if mode not in ('RB', 'ML', 'both', 'none'): + + if mode not in ("RB", "ML", "both", "none"): print("Illegal mode specified. 
Mode must be either RB, ML, both or none") return - + if project_name is None: project_name = input("What is the project name? ") if project_name is None or project_name == "": print("project name must not be empty") return - + s3_usage = args.s3_usage if s3_usage is None: - s3_usage = input('Do you want to use S3? Type either Y or N.') - if s3_usage is None or str(s3_usage) not in ['Y', 'N']: + s3_usage = input("Do you want to use S3? Type either Y or N.") + if s3_usage is None or str(s3_usage) not in ["Y", "N"]: print("Answer to S3 usage must by Y or N. Stop program. Please restart.") return None else: - s3_usage = s3_usage == 'Y' - - project_data_dir = config_path.DATA_DIR + r'/' + project_name + s3_usage = s3_usage == "Y" + + project_data_dir = config_path.DATA_DIR + r"/" + project_name create_directory(project_data_dir) - s3c_main = None + s3c_main = None if s3_usage: # Opening s3 settings file - s3_settings_path = config_path.DATA_DIR + r'/' + 's3_settings.yaml' - f = open(s3_settings_path, 'r') + s3_settings_path = config_path.DATA_DIR + r"/" + "s3_settings.yaml" + f = open(s3_settings_path, "r") s3_settings = yaml.safe_load(f) f.close() - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) settings_path = project_data_dir + "/settings.yaml" - s3c_main.download_file_from_s3(filepath=settings_path, - s3_prefix=project_prefix, - s3_key='settings.yaml') - - s3c_main = None + s3c_main.download_file_from_s3(filepath=settings_path, s3_prefix=project_prefix, s3_key="settings.yaml") + + s3c_main = None if s3_usage: # Opening s3 settings file - s3_settings_path = config_path.DATA_DIR + r'/' + 's3_settings.yaml' - f = open(s3_settings_path, 'r') + s3_settings_path = config_path.DATA_DIR + r"/" + "s3_settings.yaml" + f = open(s3_settings_path, "r") s3_settings = yaml.safe_load(f) f.close() - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - 
aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) settings_path = project_data_dir + "/settings.yaml" - s3c_main.download_file_from_s3(filepath=settings_path, - s3_prefix=project_prefix, - s3_key='settings.yaml') - + s3c_main.download_file_from_s3(filepath=settings_path, s3_prefix=project_prefix, s3_key="settings.yaml") + # Opening YAML file - f = open(project_data_dir + r'/settings.yaml', 'r') + f = open(project_data_dir + r"/settings.yaml", "r") project_settings = yaml.safe_load(f) - f.close() + f.close() - project_settings.update({'s3_usage': s3_usage}) + project_settings.update({"s3_usage": s3_usage}) if s3_usage: - project_settings.update({'s3_settings': s3_settings}) + project_settings.update({"s3_settings": s3_settings}) + + ext_port = project_settings["general"]["ext_port"] + infer_port = project_settings["general"]["infer_port"] + rb_port = project_settings["general"]["rb_port"] - ext_port = project_settings['general']['ext_port'] - infer_port = project_settings['general']['infer_port'] - rb_port = project_settings['general']['rb_port'] + ext_ip = project_settings["general"]["ext_ip"] + infer_ip = project_settings["general"]["infer_ip"] + rb_ip = project_settings["general"]["rb_ip"] - ext_ip = project_settings['general']['ext_ip'] - infer_ip = project_settings['general']['infer_ip'] - rb_ip = project_settings['general']['rb_ip'] + enable_db_export = project_settings["data_export"]["enable_db_export"] + rb_verbosity = int(project_settings["rule_based"]["verbosity"]) + rb_use_docker = project_settings["rule_based"]["use_docker"] - enable_db_export = project_settings['data_export']['enable_db_export'] - rb_verbosity = int(project_settings['rule_based']['verbosity']) - rb_use_docker = project_settings['rule_based']['use_docker'] - set_running() try: # Source folders - source_pdf = project_data_dir + r'/input/pdfs/inference' - destination_pdf = project_data_dir + r'/interim/pdfs/' - source_mapping = project_data_dir + r'/input/kpi_mapping' - source_annotation = project_data_dir + r'/input/annotations' - destination_annotation = project_data_dir + r'/interim/ml/annotations/' + source_pdf = project_data_dir + r"/input/pdfs/inference" + destination_pdf = project_data_dir + r"/interim/pdfs/" + source_mapping = project_data_dir + r"/input/kpi_mapping" + source_annotation = project_data_dir + r"/input/annotations" + destination_annotation = project_data_dir + r"/interim/ml/annotations/" # Interim folders - destination_mapping = project_data_dir + 
r'/interim/kpi_mapping/' - destination_ml_extraction = project_data_dir + r'/interim/ml/extraction/' - destination_rb_workdir = project_data_dir + r'/interim/rb/work' - destination_rb_infer = project_data_dir + r'/output/KPI_EXTRACTION/rb' - destination_ml_infer = project_data_dir + r'/output/KPI_EXTRACTION/ml/Text' + destination_mapping = project_data_dir + r"/interim/kpi_mapping/" + destination_ml_extraction = project_data_dir + r"/interim/ml/extraction/" + destination_rb_workdir = project_data_dir + r"/interim/rb/work" + destination_rb_infer = project_data_dir + r"/output/KPI_EXTRACTION/rb" + destination_ml_infer = project_data_dir + r"/output/KPI_EXTRACTION/ml/Text" # Output folders - destination_output = project_data_dir + r'/output/KPI_EXTRACTION/joined_ml_rb' + destination_output = project_data_dir + r"/output/KPI_EXTRACTION/joined_ml_rb" create_directory(source_pdf) create_directory(source_mapping) @@ -483,83 +603,90 @@ def main(): create_directory(destination_mapping) create_directory(destination_ml_extraction) create_directory(destination_annotation) - if mode != 'none': + if mode != "none": create_directory(destination_rb_infer) - create_directory(destination_ml_infer) + create_directory(destination_ml_infer) os.makedirs(destination_rb_workdir, exist_ok=True) os.makedirs(destination_output, exist_ok=True) link_files(source_pdf, destination_pdf) link_files(source_mapping, destination_mapping) - if project_settings['extraction']['use_extractions']: - source_extraction = project_data_dir + r'/output/TEXT_EXTRACTION' + if project_settings["extraction"]["use_extractions"]: + source_extraction = project_data_dir + r"/output/TEXT_EXTRACTION" if os.path.exists(source_extraction): link_extracted_files(source_extraction, source_pdf, destination_ml_extraction) - + end_to_end_response = True - - if mode in ('RB', 'both'): + + if mode in ("RB", "both"): print("Executing RB solution . . . ") - end_to_end_response = end_to_end_response and \ - run_router_rb(raw_pdf_folder=destination_pdf, - working_folder=destination_rb_workdir, - output_folder=destination_rb_infer, - project_name=project_name, - verbosity=rb_verbosity, - use_docker=rb_use_docker, - ip=rb_ip, - port=rb_port, - s3_usage=s3_usage, - s3_settings=s3_settings) + end_to_end_response = end_to_end_response and run_router_rb( + raw_pdf_folder=destination_pdf, + working_folder=destination_rb_workdir, + output_folder=destination_rb_infer, + project_name=project_name, + verbosity=rb_verbosity, + use_docker=rb_use_docker, + ip=rb_ip, + port=rb_port, + s3_usage=s3_usage, + s3_settings=s3_settings, + ) if s3_usage: # Download inference output - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/output/KPI_EXTRACTION/rb', - destination_rb_infer) - - if mode in ('ML', 'both'): + s3c_main.download_files_in_prefix_to_dir( + project_prefix + "/output/KPI_EXTRACTION/rb", destination_rb_infer + ) + + if mode in ("ML", "both"): print("Executing ML solution . . . 
") - end_to_end_response = end_to_end_response and \ - run_router_ml(ext_port, infer_port, project_name, ext_ip, infer_ip) + end_to_end_response = end_to_end_response and run_router_ml( + ext_port, infer_port, project_name, ext_ip, infer_ip + ) if s3_usage: # Download inference output - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/output/KPI_EXTRACTION/ml/Text', - destination_ml_infer) + s3c_main.download_files_in_prefix_to_dir( + project_prefix + "/output/KPI_EXTRACTION/ml/Text", destination_ml_infer + ) if end_to_end_response: run_id = get_current_run_id() if s3_usage: # Download pdf's to folder - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/pdfs/inference', - destination_pdf) - - join_output(project_name=project_name, - pdf_folder=destination_pdf, - rb_output_folder=destination_rb_infer, - ml_output_folder=destination_ml_infer, - output_folder=destination_output, - use_docker=rb_use_docker, - work_dir_rb=destination_rb_workdir, - verbosity=rb_verbosity, - port=rb_port, - ip=rb_ip, - run_id=run_id, - s3_usage=s3_usage, - s3_settings=s3_settings) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/pdfs/inference", destination_pdf) + + join_output( + project_name=project_name, + pdf_folder=destination_pdf, + rb_output_folder=destination_rb_infer, + ml_output_folder=destination_ml_infer, + output_folder=destination_output, + use_docker=rb_use_docker, + work_dir_rb=destination_rb_workdir, + verbosity=rb_verbosity, + port=rb_port, + ip=rb_ip, + run_id=run_id, + s3_usage=s3_usage, + s3_settings=s3_settings, + ) if enable_db_export: print("Exporting output to database . . . ") - run_db_export(project_name, project_settings['data_export'], run_id) - if project_settings['extraction']['store_extractions']: + run_db_export(project_name, project_settings["data_export"], run_id) + if project_settings["extraction"]["store_extractions"]: print("Finally we transfer the text extraction to the output folder.") source_extraction_data = destination_ml_extraction - destination_extraction_data = project_data_dir + r'/output/TEXT_EXTRACTION' + destination_extraction_data = project_data_dir + r"/output/TEXT_EXTRACTION" if s3_usage: - s3c_interim.download_files_in_prefix_to_dir(project_prefix + '/interim/ml/extraction', - source_extraction_data) - s3c_main.upload_files_in_dir_to_prefix(source_extraction_data, - project_prefix + '/output/TEXT_EXTRACTION') + s3c_interim.download_files_in_prefix_to_dir( + project_prefix + "/interim/ml/extraction", source_extraction_data + ) + s3c_main.upload_files_in_dir_to_prefix( + source_extraction_data, project_prefix + "/output/TEXT_EXTRACTION" + ) os.makedirs(destination_extraction_data, exist_ok=True) copy_files_without_overwrite(source_extraction_data, destination_extraction_data) - if project_settings['general']['delete_interim_files']: + if project_settings["general"]["delete_interim_files"]: create_directory(destination_ml_extraction) create_directory(destination_rb_workdir) create_directory(destination_pdf) @@ -567,16 +694,16 @@ def main(): if s3_usage: # Show only objects which satisfy our prefix my_bucket = s3c_interim.s3_resource.Bucket(name=s3c_interim.bucket) - for objects in my_bucket.objects.filter(Prefix=project_prefix+'/interim'): + for objects in my_bucket.objects.filter(Prefix=project_prefix + "/interim"): _ = objects.delete() if end_to_end_response: print("End-to-end inference complete") - + except Exception as e: - print('Process failed to run. 
Reason:' + str(repr(e)) + traceback.format_exc()) + print("Process failed to run. Reason:" + str(repr(e)) + traceback.format_exc()) clear_running() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/base_kpi_inference_curator.py b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/base_kpi_inference_curator.py index 8c95ef6..9f1f2a2 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/base_kpi_inference_curator.py +++ b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/base_kpi_inference_curator.py @@ -3,6 +3,7 @@ import re from collections import defaultdict + class BaseKPIInferenceCurator(ABC): def __init__(self, name="BaseKPIInferenceCurator"): self.name = name @@ -20,30 +21,26 @@ def clean_text(text): # Substitute unusual quotes at the end of the string with usual quotes text = re.sub("”(?=\])", '"', text) # Substitute th remaining unusual quotes with space - text = re.sub('“|”', '', text) - text = re.sub('\n', " ", text) - text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\xff]', '', text) + text = re.sub("“|”", "", text) + text = re.sub("\n", " ", text) + text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\xff]", "", text) text = re.sub(r"\s{2,}", " ", text) # replace special character - special_regex_char = [ - "(", ")", "^", "+", "*", "$", "|", "\\", "?", "[", "]", "{", "}" - ] - text = ''.join( - ["" if c in special_regex_char else c for c in text] - ) + special_regex_char = ["(", ")", "^", "+", "*", "$", "|", "\\", "?", "[", "]", "{", "}"] + text = "".join(["" if c in special_regex_char else c for c in text]) text = text.lower() # remove consecutive dots - consecutive_dots = re.compile(r'\.{2,}') - text = consecutive_dots.sub('', text) + consecutive_dots = re.compile(r"\.{2,}") + text = consecutive_dots.sub("", text) return text @staticmethod def create_squad_from_df(df): - """ Create squad data format given a dataframe + """Create squad data format given a dataframe Args: df (A pandas DataFrame): Must have columns in this order ["source_file", @@ -52,59 +49,57 @@ def create_squad_from_df(df): Returns: squad_json (A nested of list and dict): Squad json format """ - order_col = ["source_file", "paragraph", "question","answer", "answer_start"] - assert(all([e in df.columns for e in order_col])) + order_col = ["source_file", "paragraph", "question", "answer", "answer_start"] + assert all([e in df.columns for e in order_col]) df = df[order_col] - files = df['source_file'].unique() + files = df["source_file"].unique() data = [] for f in files: single_data = {} - single_data['title'] = f + single_data["title"] = f temp = df[df["source_file"] == f] - unique_par = temp['paragraph'].unique() + unique_par = temp["paragraph"].unique() paragraphs = [] for up in unique_par: single_par = {} - single_par['context'] = up + single_par["context"] = up - temp_2 = temp[temp['paragraph'] == up] + temp_2 = temp[temp["paragraph"] == up] qas = [] for row in temp_2.itertuples(): single_qas = {} - single_qas['question'] = row[3] # question has index 3 - #index - single_qas['id'] = row[0] + single_qas["question"] = row[3] # question has index 3 + # index + single_qas["id"] = row[0] ans_st = row[5] # answer_start has index 5 if ans_st == []: answers = [] - single_qas['is_impossible'] = True + single_qas["is_impossible"] = True else: answers = [] for i in ans_st: - 
answers.append( - {"text": row[4], "answer_start": i} # answer has index 4 - ) - single_qas['is_impossible'] = False - single_qas['answers'] = answers + answers.append({"text": row[4], "answer_start": i}) # answer has index 4 + single_qas["is_impossible"] = False + single_qas["answers"] = answers qas.append(single_qas) - single_par['qas'] = qas + single_par["qas"] = qas paragraphs.append(single_par) - single_data['paragraphs'] = paragraphs + single_data["paragraphs"] = paragraphs data.append(single_data) squad_json = {} - squad_json['version'] = "v2.0" - squad_json['data'] = data + squad_json["version"] = "v2.0" + squad_json["data"] = data return squad_json @@ -126,7 +121,7 @@ def find_answer_start(answer, par): pat2 = answer + "[^0-9]" matches1 = re.finditer(pat1, par) matches2 = re.finditer(pat2, par) - ans_start_1 = [i.start()+1 for i in matches1] + ans_start_1 = [i.start() + 1 for i in matches1] ans_start_2 = [i.start() for i in matches2] ans_start = list(set(ans_start_1 + ans_start_2)) else: @@ -137,7 +132,7 @@ def find_answer_start(answer, par): return ans_start def split_squad(self, squad_json, val_ratio, seed): - """ Given a squad like json data format, split to train and val sets + """Given a squad like json data format, split to train and val sets Args: squad_json @@ -149,16 +144,16 @@ def split_squad(self, squad_json, val_ratio, seed): val_squad (A dict) """ indices = [] - for i1, pdf in enumerate(squad_json['data']): - pars = pdf['paragraphs'] + for i1, pdf in enumerate(squad_json["data"]): + pars = pdf["paragraphs"] for i2, par in enumerate(pars): - qas = par['qas'] + qas = par["qas"] indices.append((i1, i2)) random.seed(seed) random.shuffle(indices) - split_idx = int((1-val_ratio)*len(indices)) + split_idx = int((1 - val_ratio) * len(indices)) train_indices = indices[:split_idx] val_indices = indices[split_idx:] @@ -179,26 +174,25 @@ def return_sliced_squad(self, squad_json, indices): return {} pdf2pars = defaultdict(list) - for (i1, i2) in indices: + for i1, i2 in indices: pdf2pars[i1].append(i2) data = [] for i1 in pdf2pars: pars_indices = pdf2pars[i1] - pars = [squad_json['data'][i1]['paragraphs'][i2] for i2 in pars_indices] + pars = [squad_json["data"][i1]["paragraphs"][i2] for i2 in pars_indices] single_pdf = {} - single_pdf['paragraphs'] = pars - single_pdf['title'] = squad_json['data'][i1]['title'] + single_pdf["paragraphs"] = pars + single_pdf["title"] = squad_json["data"][i1]["title"] data.append(single_pdf) squad_data = {} - squad_data['version'] = "v2.0" - squad_data['data'] = data + squad_data["version"] = "v2.0" + squad_data["data"] = data return squad_data - @abstractmethod def curate(self, *args, **kwargs): pass diff --git a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/text_kpi_inference_curator.py b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/text_kpi_inference_curator.py index 3af6bf8..d54e4ef 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/text_kpi_inference_curator.py +++ b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/components/text_kpi_inference_curator.py @@ -13,13 +13,22 @@ import importlib import logging + logger = logging.getLogger(__name__) COL_ORDER = [ - 'company', 'source_file', 'source_page', 'kpi_id', - 'year', 'answer', 'data_type', 'relevant_paragraphs', - 'annotator', 'sector' + "company", + "source_file", + "source_page", + "kpi_id", + "year", + "answer", + "data_type", + 
"relevant_paragraphs", + "annotator", + "sector", ] + class TextKPIInferenceCurator(BaseKPIInferenceCurator): def __init__( self, @@ -28,7 +37,7 @@ def __init__( extracted_text_json_folder, output_squad_folder, relevant_text_path=None, - name="TextKPIInferenceCurator" + name="TextKPIInferenceCurator", ): """ Args: @@ -62,17 +71,15 @@ def read_agg(self): df (a pd dataframe) """ if not os.path.exists(self.agg_annotation): - logger.info( - "{} not available, will create it.".format(self.agg_annotation) - ) + logger.info("{} not available, will create it.".format(self.agg_annotation)) df = aggregate_csvs(self.annotation_folder) df = clean_annotation(df, self.agg_annotation)[COL_ORDER] else: - #df = pd.read_csv(self.agg_annotation, header=0, index_col=0)[COL_ORDER] - input_fd = open(self.agg_annotation, errors = 'ignore') + # df = pd.read_csv(self.agg_annotation, header=0, index_col=0)[COL_ORDER] + input_fd = open(self.agg_annotation, errors="ignore") df = pd.read_csv(input_fd, header=0, index_col=0)[COL_ORDER] input_fd.close() - df.loc[:, 'source_page'] = df['source_page'].apply(ast.literal_eval) + df.loc[:, "source_page"] = df["source_page"].apply(ast.literal_eval) return df @@ -83,10 +90,11 @@ def clean(self, df): Args: df (A pandas dataframe) """ + # map kpi to question def map_kpi(r): try: - question = kpi_mapping.KPI_MAPPING[float(r['kpi_id'])] + question = kpi_mapping.KPI_MAPPING[float(r["kpi_id"])] except (KeyError, ValueError) as e: question = None @@ -96,26 +104,26 @@ def map_kpi(r): except ValueError: year = r["year"] - if float(r['kpi_id']) in kpi_mapping.ADD_YEAR: + if float(r["kpi_id"]) in kpi_mapping.ADD_YEAR: front = question.split("?")[0] question = front + " in year {}?".format(year) return question - df['question'] = df[['kpi_id', 'year']].apply(map_kpi, axis=1) - df = df.dropna(axis=0, subset=['question']).reset_index(drop=True) + df["question"] = df[["kpi_id", "year"]].apply(map_kpi, axis=1) + df = df.dropna(axis=0, subset=["question"]).reset_index(drop=True) # Remove NaN rows based on relevant paragraphs and answer - df = df[~df['relevant_paragraphs'].isna()] - df = df[~df['answer'].isna()] + df = df[~df["relevant_paragraphs"].isna()] + df = df[~df["answer"].isna()] # change line space to white space, remove trailing and initial white space - df.loc[:, 'answer'] = df['answer'].apply(lambda x: " ".join(str(x).split("\n"))) - df.loc[:, 'answer'] = df['answer'].apply(lambda x: x.strip()) + df.loc[:, "answer"] = df["answer"].apply(lambda x: " ".join(str(x).split("\n"))) + df.loc[:, "answer"] = df["answer"].apply(lambda x: x.strip()) # clean relevant_paragraphs - df.loc[:, 'relevant_paragraphs'] = df['relevant_paragraphs'].apply(self.clean_paragraph) - df = df.dropna(axis=0, subset=['relevant_paragraphs']).reset_index(drop=True) + df.loc[:, "relevant_paragraphs"] = df["relevant_paragraphs"].apply(self.clean_paragraph) + df = df.dropna(axis=0, subset=["relevant_paragraphs"]).reset_index(drop=True) # split multiple paragraphs to individual examples df = self.split_multi_paragraph(df) @@ -123,24 +131,21 @@ def map_kpi(r): return df def split_multi_paragraph(self, df): - """ Splits multiple relevant paragraphs to individual examples - """ + """Splits multiple relevant paragraphs to individual examples""" # if single relevant paragraphs, then assuming only has single source page (fair enough) - df_single = df[df['relevant_paragraphs'].apply(len) == 1] - df_single.loc[:,'source_page'] = df_single['source_page'].apply(lambda x: x[0]) - df_single.loc[:,'relevant_paragraphs'] = 
df['relevant_paragraphs'].apply(lambda x: x[0]) + df_single = df[df["relevant_paragraphs"].apply(len) == 1] + df_single.loc[:, "source_page"] = df_single["source_page"].apply(lambda x: x[0]) + df_single.loc[:, "relevant_paragraphs"] = df["relevant_paragraphs"].apply(lambda x: x[0]) # Otherwise - df_multi = df[df['relevant_paragraphs'].apply(len) > 1] + df_multi = df[df["relevant_paragraphs"].apply(len) > 1] new_multi = [] # better to check before using itertuples - col_order = COL_ORDER + ['question'] - assert( + col_order = COL_ORDER + ["question"] + assert ( all([e in df_multi.columns.tolist() for e in COL_ORDER]), - "dataframe columns are different. Your df column {}".format( - df_multi.columns.tolist() - ) + "dataframe columns are different. Your df column {}".format(df_multi.columns.tolist()), ) df_multi = df_multi[col_order] for row in df_multi.itertuples(): @@ -167,7 +172,7 @@ def split_multi_paragraph(self, df): return df def clean_paragraph(self, r): - """ Clean relevant_paragraphs column + """Clean relevant_paragraphs column Args: r (A pandas series row) @@ -203,16 +208,16 @@ def clean_paragraph(self, r): temp = [] start = 0 for i in first_type: - temp.append(strp[start:i.start()]) - start = i.start()+4 + temp.append(strp[start : i.start()]) + start = i.start() + 4 temp.append(strp[start:]) return temp elif len(first_type) == 0 and len(second_type) != 0: temp = [] start = 0 for i in second_type: - temp.append(strp[start:i.start()]) - start = i.start()+3 + temp.append(strp[start : i.start()]) + start = i.start() + 3 temp.append(strp[start:]) return temp else: # a combination of two @@ -224,34 +229,28 @@ def clean_paragraph(self, r): while track1 < len(first_type) or track2 < len(second_type): if track1 == len(first_type): for i in second_type[track2:]: - temp.append(strp[start:i.start()]) + temp.append(strp[start : i.start()]) start = i.start() + 3 break if track2 == len(second_type): for i in first_type[track1:]: - temp.append(strp[start:i.start()]) + temp.append(strp[start : i.start()]) start = i.start() + 4 break if first_type[track1].start() < second_type[track2].start(): - temp.append(strp[start:first_type[track1].start()]) + temp.append(strp[start : first_type[track1].start()]) start = first_type[track1].start() + 4 track1 += 1 else: - temp.append(strp[start:second_type[track2].start()]) + temp.append(strp[start : second_type[track2].start()]) start = second_type[track2].start() + 3 track2 += 1 return temp - - def find_closest_paragraph( - self, - pars, - clean_rel_par, - clean_answer - ): + def find_closest_paragraph(self, pars, clean_rel_par, clean_answer): """ Args: pars (A list of str) @@ -282,7 +281,7 @@ def find_closest_paragraph( return clean_rel_par def return_full_paragraph(self, r, json_dict): - """ Find closest full paragraph, if can't be found return annotated + """Find closest full paragraph, if can't be found return annotated paragraph instead. Args: @@ -294,50 +293,41 @@ def return_full_paragraph(self, r, json_dict): clean_answer (A str) ans_start (A list of int) """ - clean_answer = self.clean_text(r['answer']) - clean_rel_par = self.clean_text(r['relevant_paragraphs']) + clean_answer = self.clean_text(r["answer"]) + clean_rel_par = self.clean_text(r["relevant_paragraphs"]) # If json file not extracted, use relevant text from annotation - if r['source_file'] not in json_dict: + if r["source_file"] not in json_dict: logger.info( - "{} json file has not been extracted. 
Will use relevant text as annotated."\ - .format(r['source_file']) + "{} json file has not been extracted. Will use relevant text as annotated.".format(r["source_file"]) ) else: - d = json_dict[r['source_file']] + d = json_dict[r["source_file"]] # pdfminer starts counter from 0 (hence the dictionary loaded from json) - pars = d[str(int(r['source_page']) - 1)] + pars = d[str(int(r["source_page"]) - 1)] if len(pars) == 0: logger.info( "{}.json has empty list of paragraphs at page {}. \ Will use relevant text as annotated".format( - r['source_file'].split('.pdf')[0], r['source_page'] + r["source_file"].split(".pdf")[0], r["source_page"] ) ) else: # match the closest paragraph to the annotated one # let's try exact match - clean_rel_par = self.find_closest_paragraph( - pars, clean_rel_par, clean_answer - ) + clean_rel_par = self.find_closest_paragraph(pars, clean_rel_par, clean_answer) ans_start = self.find_answer_start(clean_answer, clean_rel_par) # avoid 0th index answer due to FARM bug if 0 in ans_start: clean_rel_par = " " + clean_rel_par - ans_start = [i+1 for i in ans_start] + ans_start = [i + 1 for i in ans_start] return clean_rel_par, clean_answer, ans_start - def curate( - self, - val_ratio, - seed, - find_new_answerable=True, - create_unanswerable=True - ): + def curate(self, val_ratio, seed, find_new_answerable=True, create_unanswerable=True): """ Curate squad samples @@ -350,32 +340,29 @@ def curate( samples """ df = self.read_agg() - df = df[df['data_type'] == self.data_type] + df = df[df["data_type"] == self.data_type] df = self.clean(df) # get all available jsons from extraction phase - all_json = [ - i for i in os.listdir(self.extracted_text_json_folder) \ - if i.endswith(".json") - ] + all_json = [i for i in os.listdir(self.extracted_text_json_folder) if i.endswith(".json")] json_dict = {} for f in all_json: name = f.split(".json")[0] - with open(os.path.join(self.extracted_text_json_folder, f), 'r') as fi: + with open(os.path.join(self.extracted_text_json_folder, f), "r") as fi: d = json.load(fi) - json_dict[name+".pdf"] = d + json_dict[name + ".pdf"] = d answerable_df = self.create_answerable(df, json_dict, find_new_answerable) if create_unanswerable: unanswerable_df = self.create_unanswerable(df) - all_df = pd.concat( - [answerable_df, unanswerable_df] - )\ - .drop_duplicates(subset=['answer', 'paragraph', 'question'])\ - .reset_index(drop=True) + all_df = ( + pd.concat([answerable_df, unanswerable_df]) + .drop_duplicates(subset=["answer", "paragraph", "question"]) + .reset_index(drop=True) + ) else: all_df = answerable_df @@ -384,21 +371,15 @@ def curate( da = date.today().strftime("%d-%m-%Y") # save data as csv for reference - all_df.to_csv( - os.path.join(self.output_squad_folder, "reference_kpi_{}.csv".format(da)) - ) + all_df.to_csv(os.path.join(self.output_squad_folder, "reference_kpi_{}.csv".format(da))) if train_squad != {}: - train_f = os.path.join( - self.output_squad_folder, "kpi_train.json" - ) - with open(train_f, 'w') as f: + train_f = os.path.join(self.output_squad_folder, "kpi_train.json") + with open(train_f, "w") as f: json.dump(train_squad, f) if val_squad != {}: - val_f = os.path.join( - self.output_squad_folder, "kpi_val_split.json" - ) - with open(val_f, 'w') as f: + val_f = os.path.join(self.output_squad_folder, "kpi_val_split.json") + with open(val_f, "w") as f: json.dump(val_squad, f) return train_squad, val_squad @@ -420,31 +401,29 @@ def create_answerable(self, df, json_dict, find_new_answerable): # set new answer, relevant_paragraphs and add 
answer_start temp = pd.DataFrame(results.tolist()) - df['relevant_paragraphs'] = temp[0] - df['answer'] = temp[1] - df['answer_start'] = temp[2] - df = df[~df['answer'].isna()] + df["relevant_paragraphs"] = temp[0] + df["answer"] = temp[1] + df["answer_start"] = temp[2] + df = df[~df["answer"].isna()] if find_new_answerable: synthetic_pos = self.find_extra_answerable(df, json_dict) else: synthetic_pos = pd.DataFrame([]) - pos_df = pd.concat([df, synthetic_pos])\ - .drop_duplicates(subset=['answer', 'relevant_paragraphs', 'question'])\ + pos_df = ( + pd.concat([df, synthetic_pos]) + .drop_duplicates(subset=["answer", "relevant_paragraphs", "question"]) .reset_index(drop=True) + ) - pos_df = pos_df[pos_df['answer_start'].apply(len) != 0].reset_index(drop=True) - pos_df.rename({'relevant_paragraphs':"paragraph"}, axis=1, inplace=True) + pos_df = pos_df[pos_df["answer_start"].apply(len) != 0].reset_index(drop=True) + pos_df.rename({"relevant_paragraphs": "paragraph"}, axis=1, inplace=True) - pos_df = pos_df[ - ["source_file", "paragraph", "question", - "answer", "answer_start"] - ] + pos_df = pos_df[["source_file", "paragraph", "question", "answer", "answer_start"]] return pos_df - def find_extra_answerable(self, df, json_dict): """ Find extra answerable samples @@ -481,19 +460,27 @@ def find_extra_answerable(self, df, json_dict): for par in pars: clean_rel_par = self.clean_text(par) - ans_start = self.find_answer_start( - clean_answer, clean_rel_par - ) + ans_start = self.find_answer_start(clean_answer, clean_rel_par) # avoid 0th index answer due to FARM bug if 0 in ans_start: clean_rel_par = " " + clean_rel_par - ans_start = [i+1 for i in ans_start] + ans_start = [i + 1 for i in ans_start] if len(ans_start) != 0: example = [ - t[1], t[2], p, kpi_id, t[5], clean_answer, t[7], - clean_rel_par, "1QBit", t[10], t[11], ans_start + t[1], + t[2], + p, + kpi_id, + t[5], + clean_answer, + t[7], + clean_rel_par, + "1QBit", + t[10], + t[11], + ans_start, ] new_positive.append(example) @@ -518,43 +505,30 @@ def create_unanswerable(self, annotation_df): # TODO: creating the logging. ## Get the relevant pairs of Kpi questions and paragraphs - relevant_df = pd.read_csv( - self.relevant_text_path, header=0, index_col=0, usecols=[0,1,2,3,4] - ) + relevant_df = pd.read_csv(self.relevant_text_path, header=0, index_col=0, usecols=[0, 1, 2, 3, 4]) - order_col = ['page', 'pdf_name', 'text', 'text_b'] - assert(all([e in relevant_df.columns for e in order_col])) + order_col = ["page", "pdf_name", "text", "text_b"] + assert all([e in relevant_df.columns for e in order_col]) relevant_df = relevant_df[order_col] def add_pdf_extension(pdf_name): - #pdf_name = " ".join(pdf_name.split("-")[:-2]) + # pdf_name = " ".join(pdf_name.split("-")[:-2]) return str(pdf_name) + ".pdf" relevant_df.loc[:, "text_b"] = relevant_df["text_b"].apply(self.clean_text) - relevant_df.loc[:, "pdf_name"] = relevant_df.apply( - lambda x: add_pdf_extension(x.pdf_name), axis=1 - ) + relevant_df.loc[:, "pdf_name"] = relevant_df.apply(lambda x: add_pdf_extension(x.pdf_name), axis=1) # Pages in the json files start from 0, while in a pdf viewer it starts from 1. 
relevant_df.loc[:, "page_viewer"] = relevant_df.apply(lambda x: x.page + 1, axis=1) neg_df = self.filter_relevant_examples(annotation_df, relevant_df) - neg_df.rename( - {"text": "question", "text_b": "paragraph", "pdf_name": "source_file"}, - inplace=True, - axis=1 - ) - neg_df['answer_start'] = [[]]*neg_df.shape[0] - neg_df['answer'] = "" - neg_df = neg_df\ - .drop_duplicates(subset=['answer', 'paragraph', 'question'])\ - .reset_index(drop=True) + neg_df.rename({"text": "question", "text_b": "paragraph", "pdf_name": "source_file"}, inplace=True, axis=1) + neg_df["answer_start"] = [[]] * neg_df.shape[0] + neg_df["answer"] = "" + neg_df = neg_df.drop_duplicates(subset=["answer", "paragraph", "question"]).reset_index(drop=True) - neg_df = neg_df[ - ["source_file", "paragraph", "question", - "answer", "answer_start"] - ] + neg_df = neg_df[["source_file", "paragraph", "question", "answer", "answer_start"]] return neg_df @@ -573,7 +547,7 @@ def filter_relevant_examples(self, annotation_df, relevant_df): """ # Get the list of pdfs mention in relevant data frame - target_pdfs = list(relevant_df['pdf_name'].unique()) + target_pdfs = list(relevant_df["pdf_name"].unique()) neg_examples_df_list = [] for pdf_file in target_pdfs: @@ -581,29 +555,22 @@ def filter_relevant_examples(self, annotation_df, relevant_df): if len(annotation_for_pdf) == 0: continue - pages = list( - annotation_for_pdf['source_page'].unique() - ) + pages = list(annotation_for_pdf["source_page"].unique()) neg_examples_df = relevant_df[ - (relevant_df['pdf_name'] == pdf_file)\ - & ~(relevant_df['page_viewer'].isin(pages)) + (relevant_df["pdf_name"] == pdf_file) & ~(relevant_df["page_viewer"].isin(pages)) ] - questions = annotation_for_pdf['question'].tolist() - answers = annotation_for_pdf['answer'].astype(str).tolist() + questions = annotation_for_pdf["question"].tolist() + answers = annotation_for_pdf["answer"].astype(str).tolist() # This is an extra step to make sure the negative examples do not # contain the answer of a question. 
for q, a in zip(questions, answers): neg_examples_df = neg_examples_df[ ~( - (neg_examples_df['text'] == q) \ - & ( - neg_examples_df['text_b'].map( - lambda x: self.clean_text(a) in x - ) - ) + (neg_examples_df["text"] == q) + & (neg_examples_df["text_b"].map(lambda x: self.clean_text(a) in x)) ) ] diff --git a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/config.py b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/config.py index bde4d78..c24577c 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/config.py +++ b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/config.py @@ -10,10 +10,17 @@ ANNOTATION_FOLDER = ROOT EXTRACTION_FOLDER = ROOT COLUMNS_TO_READ = [ - 'company', 'source_file', 'source_page', 'kpi_id', 'year', 'answer', - 'data_type', 'relevant_paragraphs' + "company", + "source_file", + "source_page", + "kpi_id", + "year", + "answer", + "data_type", + "relevant_paragraphs", ] + class CurateConfig: def __init__(self): self.val_ratio = 0 @@ -21,12 +28,12 @@ def __init__(self): self.find_new_answerable = True self.create_unanswerable = True + # Text KPI Inference Curator TextKPIInferenceCurator_kwargs = { "annotation_folder": ANNOTATION_FOLDER, "agg_annotation": DATA_FOLDER / "aggregated_annotation.csv", "extracted_text_json_folder": EXTRACTION_FOLDER, "output_squad_folder": DATA_FOLDER, - "relevant_text_path": DATA_FOLDER / "text_3434.csv" + "relevant_text_path": DATA_FOLDER / "text_3434.csv", } - diff --git a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/logging_config.py b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/logging_config.py index 6b0f25a..2c820e6 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/logging_config.py +++ b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/config/logging_config.py @@ -1,10 +1,8 @@ import logging import sys -FORMATTER = logging.Formatter( - "%(asctime)s — %(name)s — %(levelname)s —" - "%(funcName)s:%(lineno)d — %(message)s" -) +FORMATTER = logging.Formatter("%(asctime)s — %(name)s — %(levelname)s —" "%(funcName)s:%(lineno)d — %(message)s") + def get_console_handler(): console_handler = logging.StreamHandler(sys.stdout) diff --git a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/kpi_mapping.py b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/kpi_mapping.py index 42b59a3..c195134 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/kpi_mapping.py +++ b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/kpi_mapping.py @@ -4,18 +4,15 @@ try: df = pd.read_csv("/app/code/kpi_mapping.csv", header=0) - _KPI_MAPPING = {str(i[0]): i[1] for i in df[['kpi_id', 'question']].values} + _KPI_MAPPING = {str(i[0]): i[1] for i in df[["kpi_id", "question"]].values} KPI_MAPPING = {(float(key)): value for key, value in _KPI_MAPPING.items()} # Which questions should be added the year - ADD_YEAR = df[df['add_year']].kpi_id.tolist() + ADD_YEAR = df[df["add_year"]].kpi_id.tolist() # Category where the answer to the question should originate from - KPI_CATEGORY = { - i[0]: [j.strip() for j in i[1].split(', ')] for i in df[['kpi_id', 'kpi_category']].values - } + KPI_CATEGORY = {i[0]: [j.strip() for j in i[1].split(", ")] for i in 
df[["kpi_id", "kpi_category"]].values} except: KPI_MAPPING = {} KPI_CATEGORY = {} ADD_YEAR = [] - diff --git a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/utils.py b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/utils.py index 2ae76bb..9d1a3a1 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/utils.py +++ b/data_extractor/code/kpi_inference_data_pipeline/kpi_inference_data_pipeline/utils/utils.py @@ -6,8 +6,9 @@ logger = logging.getLogger(__name__) + def aggregate_csvs(annotation_folder): - csvs = [f for f in os.listdir(annotation_folder) if f.endswith('.csv')] + csvs = [f for f in os.listdir(annotation_folder) if f.endswith(".csv")] dfs = [] @@ -15,22 +16,22 @@ def aggregate_csvs(annotation_folder): fname = os.path.join(annotation_folder, f) df = pd.read_csv(fname, header=0) cols = df.columns - assert(all([e in cols for e in config.COLUMNS_TO_READ])), \ - "{} doesn't have certain columns {}".format(f, config.COLUMNS_TO_READ) - + assert all([e in cols for e in config.COLUMNS_TO_READ]), "{} doesn't have certain columns {}".format( + f, config.COLUMNS_TO_READ + ) - if 'Sector' in cols: - df.rename({'Sector':'sector'}, axis=1, inplace=True) - columns_to_read = config.COLUMNS_TO_READ + ['sector'] - elif 'sector' in cols: - columns_to_read = config.COLUMNS_TO_READ + ['sector'] + if "Sector" in cols: + df.rename({"Sector": "sector"}, axis=1, inplace=True) + columns_to_read = config.COLUMNS_TO_READ + ["sector"] + elif "sector" in cols: + columns_to_read = config.COLUMNS_TO_READ + ["sector"] else: logger.info("{} has no column Sector/sector".format(f)) - if 'annotator' in cols: - columns_to_read += ['annotator'] + if "annotator" in cols: + columns_to_read += ["annotator"] else: - df['annotator'] = f + df["annotator"] = f df = df[columns_to_read] dfs.append(df) @@ -39,8 +40,9 @@ def aggregate_csvs(annotation_folder): return df -def clean_annotation(df, save_path, exclude=['CEZ']): - """ Returns a clean dataframe and save it after + +def clean_annotation(df, save_path, exclude=["CEZ"]): + """Returns a clean dataframe and save it after 1. dropping all NaN rows 2. 
dropping rows which has NaN values in some of the columns (refer below) @@ -57,14 +59,12 @@ def clean_annotation(df, save_path, exclude=['CEZ']): """ # dropping all nan rows - df = df.dropna(axis=0, how='all').reset_index(drop=True) + df = df.dropna(axis=0, how="all").reset_index(drop=True) # Drop rows with NaN for any of these columns except answer, relevant_paragraphs - df = df.dropna( - axis=0, - how='any', - subset=['company', 'source_file', 'source_page', 'kpi_id', 'year'] - ).reset_index(drop=True) + df = df.dropna(axis=0, how="any", subset=["company", "source_file", "source_page", "kpi_id", "year"]).reset_index( + drop=True + ) # Remove template company if exclude != []: @@ -72,20 +72,20 @@ def clean_annotation(df, save_path, exclude=['CEZ']): # Get pdf filename right (don't need to make it a class method) def get_pdf_name_right(f): - if not f.endswith('.pdf'): - if f.endswith(',pdf'): - filename = f.split(',pdf')[0].strip() + '.pdf' + if not f.endswith(".pdf"): + if f.endswith(",pdf"): + filename = f.split(",pdf")[0].strip() + ".pdf" else: - filename = f.strip() + '.pdf' + filename = f.strip() + ".pdf" else: - filename = f.split('.pdf')[0].strip() + '.pdf' + filename = f.split(".pdf")[0].strip() + ".pdf" return filename - df['source_file'] = df['source_file'].apply(get_pdf_name_right) + df["source_file"] = df["source_file"].apply(get_pdf_name_right) # clean data type - df['data_type'] = df['data_type'].apply(str.strip) + df["data_type"] = df["data_type"].apply(str.strip) # Remove examples where source_page can't be parsed def clean_page(sp): @@ -95,28 +95,27 @@ def clean_page(sp): # Covers multi pages and fix cases like '02' return [str(int(i)) for i in sp[1:-1].split(",")] - temp = df['source_page'].apply(clean_page) - invalid_source_page = df['source_page'][temp.isna()].unique().tolist() + temp = df["source_page"].apply(clean_page) + invalid_source_page = df["source_page"][temp.isna()].unique().tolist() if len(invalid_source_page) != 0: logger.warning( "Has invalid source_page format: {} and {} such examples".format( - df['source_page'][temp.isna()].unique(), - df['source_page'][temp.isna()].shape[0] + df["source_page"][temp.isna()].unique(), df["source_page"][temp.isna()].shape[0] ) ) - df['source_page'] = temp - df = df.dropna(axis=0, subset=['source_page']).reset_index(drop=True) + df["source_page"] = temp + df = df.dropna(axis=0, subset=["source_page"]).reset_index(drop=True) # Remove examples with incorrect kpi-data_type pair def clean_id(r): try: - kpi_id = float(r['kpi_id']) + kpi_id = float(r["kpi_id"]) except ValueError: - kpi_id = r['kpi_id'] + kpi_id = r["kpi_id"] try: - if r['data_type'] in KPI_CATEGORY[kpi_id]: + if r["data_type"] in KPI_CATEGORY[kpi_id]: cat = True else: cat = False @@ -125,14 +124,11 @@ def clean_id(r): return cat - correct_id_bool = df[['kpi_id', 'data_type']].apply(clean_id, axis=1) + correct_id_bool = df[["kpi_id", "data_type"]].apply(clean_id, axis=1) df = df[correct_id_bool].reset_index(drop=True) diff = correct_id_bool.shape[0] - df.shape[0] if diff > 0: - logger.info( - "Drop {} examples due to incorrect kpi-data_type pair"\ - .format(diff) - ) + logger.info("Drop {} examples due to incorrect kpi-data_type pair".format(diff)) df.to_csv(save_path) logger.info("{} is created.".format(save_path)) diff --git a/data_extractor/code/kpi_inference_data_pipeline/setup.py b/data_extractor/code/kpi_inference_data_pipeline/setup.py index 83bcf28..35ca3d9 100644 --- a/data_extractor/code/kpi_inference_data_pipeline/setup.py +++ 
b/data_extractor/code/kpi_inference_data_pipeline/setup.py @@ -4,34 +4,36 @@ from setuptools import find_packages, setup -NAME = 'kpi_inference_data_pipeline' -DESCRIPTION = 'Read, clean and save data as squad format for KPI Inference' -AUTHOR = '1QBit NLP' -REQUIRES_PYTHON = '>=3.6.0' +NAME = "kpi_inference_data_pipeline" +DESCRIPTION = "Read, clean and save data as squad format for KPI Inference" +AUTHOR = "1QBit NLP" +REQUIRES_PYTHON = ">=3.6.0" -def list_reqs(fname='requirements.txt'): + +def list_reqs(fname="requirements.txt"): with open(fname) as fd: return fd.read().splitlines() + here = os.path.abspath(os.path.dirname(__file__)) # Load the package's __version__.py module as a dictionary. ROOT_DIR = Path(__file__).resolve().parent PACKAGE_DIR = ROOT_DIR / NAME about = {} -with open(PACKAGE_DIR / 'VERSION') as f: +with open(PACKAGE_DIR / "VERSION") as f: _version = f.read().strip() - about['__version__'] = _version + about["__version__"] = _version setup( name=NAME, - version=about['__version__'], + version=about["__version__"], description=DESCRIPTION, author=AUTHOR, python_requires=REQUIRES_PYTHON, - packages=find_packages(exclude=('tests', 'notebooks')), - package_data={'kpi_inference_data_pipeline': ['VERSION']}, + packages=find_packages(exclude=("tests", "notebooks")), + package_data={"kpi_inference_data_pipeline": ["VERSION"]}, install_requires=list_reqs(), extras_require={}, - include_package_data=True + include_package_data=True, ) diff --git a/data_extractor/code/model_pipeline/metrics_per_kpi.py b/data_extractor/code/model_pipeline/metrics_per_kpi.py index c724fd0..a323eff 100644 --- a/data_extractor/code/model_pipeline/metrics_per_kpi.py +++ b/data_extractor/code/model_pipeline/metrics_per_kpi.py @@ -68,7 +68,7 @@ def single_result_to_label(single_result): def keyfunc(x): - q = x['predictions'][0]['question'].split(" in year")[0] + q = x["predictions"][0]["question"].split(" in year")[0] if not q.endswith("?"): q = q + "?" 
return q @@ -82,9 +82,7 @@ def keyfunc(x): infer_config = QAInferConfig() model_path = file_config.saved_models_dir -model = QAInferencer.load( - model_path, batch_size=infer_config.batch_size, gpu=torch.cuda.is_available() -) +model = QAInferencer.load(model_path, batch_size=infer_config.batch_size, gpu=torch.cuda.is_available()) dev_file_path = file_config.dev_filename nested_dicts = read_squad_file(filename=dev_file_path) @@ -107,7 +105,6 @@ def keyfunc(x): fail_counter += 1 continue for res in results: - # Converting the ground truth character indices to token indices ground_truth = res["predictions"][0]["ground_truth"] gt_start_char_idx = ground_truth[0]["offset"] diff --git a/data_extractor/code/model_pipeline/model_pipeline/__init__.py b/data_extractor/code/model_pipeline/model_pipeline/__init__.py index 4418df7..bf4bbb5 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/__init__.py +++ b/data_extractor/code/model_pipeline/model_pipeline/__init__.py @@ -5,7 +5,7 @@ ModelConfig, MLFlowConfig, ProcessorConfig, - InferConfig + InferConfig, ) from .config_qa_farm_train import ( QAFileConfig, @@ -14,7 +14,7 @@ QAModelConfig, QAMLFlowConfig, QAProcessorConfig, - QAInferConfig + QAInferConfig, ) from .farm_trainer import FARMTrainer from .qa_farm_trainer import QAFARMTrainer diff --git a/data_extractor/code/model_pipeline/model_pipeline/config_farm_train.py b/data_extractor/code/model_pipeline/model_pipeline/config_farm_train.py index 8a4f329..80465e1 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/config_farm_train.py +++ b/data_extractor/code/model_pipeline/model_pipeline/config_farm_train.py @@ -7,82 +7,80 @@ _logger = getLogger(__name__) LOGGING_MAPPING = {"info": INFO, "warning": WARNING, "debug": DEBUG} -class Config: +class Config: def __init__(self, project_name, output_model_name=None, experiment_type="RELEVANCE", data_type="Text"): - self.root = str(pathlib.Path(__file__).resolve().parent.parent.parent.parent) + self.root = str(pathlib.Path(__file__).resolve().parent.parent.parent.parent) self.experiment_type = experiment_type self.output_model_name = output_model_name - self.experiment_name = project_name + self.experiment_name = project_name self.data_type = data_type # Text | Table farm_infer_logging_level = "warning" # FARM logging level during inference; supports info, warning, debug self.farm_infer_logging_level = LOGGING_MAPPING[farm_infer_logging_level] class FileConfig(Config): - - def __init__(self,project_name, output_model_name): + def __init__(self, project_name, output_model_name): super().__init__(project_name, output_model_name) self.data_dir = os.path.join(self.root, "data") self.annotation_dir = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "annotations") - self.curated_data = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "curation", "esg_TEXT_dataset.csv") - self.training_dir = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "training") + self.curated_data = os.path.join( + self.data_dir, self.experiment_name, "interim", "ml", "curation", "esg_TEXT_dataset.csv" + ) + self.training_dir = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "training") self.train_filename = os.path.join(self.training_dir, f"kpi_train_split.csv") self.dev_filename = os.path.join(self.training_dir, f"kpi_val_split.csv") self.test_filename = None - #The next defines the folder where the trained relevance model is stored to - self.saved_models_dir = os.path.join(self.root, "models", 
self.experiment_name, self.experiment_type, self.data_type, self.output_model_name) - + # The next defines the folder where the trained relevance model is stored to + self.saved_models_dir = os.path.join( + self.root, "models", self.experiment_name, self.experiment_type, self.data_type, self.output_model_name + ) -class TokenizerConfig(Config): - def __init__(self,project_name): +class TokenizerConfig(Config): + def __init__(self, project_name): super().__init__(project_name) self.pretrained_model_name_or_path = "roberta-base" self.do_lower_case = False class ProcessorConfig(Config): - - def __init__(self,project_name): + def __init__(self, project_name): super().__init__(project_name) if self.experiment_type == "RELEVANCE": self.processor_name = "TextPairClassificationProcessor" else: raise ValueError("No existing processor for this task") - self.load_dir = os.path.join(self.root, "models", "base" , "relevance_roberta") + self.load_dir = os.path.join(self.root, "models", "base", "relevance_roberta") # set to None if you don't want to load the\ # vocab.json file self.max_seq_len = 512 - self.dev_split = .2 + self.dev_split = 0.2 self.label_list = ["0", "1"] self.label_column_name = "label" # label column name in data files - self.delimiter = ',' + self.delimiter = "," self.metric = "acc" class ModelConfig(Config): - - def __init__(self,project_name): + def __init__(self, project_name): super().__init__(project_name) if self.experiment_type == "RELEVANCE": self.class_type = TextClassificationHead - self.head_config = { - "num_labels": 2 - } + self.head_config = {"num_labels": 2} else: raise ValueError("No existing model for this task") # set to None if you don't want to load the config file for this model - self.load_dir = os.path.join(self.root, "models", "base", - "relevance_roberta") # relevance_roberta | relevance_roberta_table_headers + self.load_dir = os.path.join( + self.root, "models", "base", "relevance_roberta" + ) # relevance_roberta | relevance_roberta_table_headers self.lang_model = "roberta-base" self.layer_dims = [768, 2] self.lm_output_types = ["per_sequence"] # or ["per_tokens"] class TrainingConfig(Config): - - def __init__(self,project_name, seed): + def __init__(self, project_name, seed): super().__init__(project_name) self.seed = seed @@ -107,8 +105,7 @@ def __init__(self,project_name, seed): class MLFlowConfig(Config): - - def __init__(self,project_name): + def __init__(self, project_name): super().__init__(project_name) self.track_experiment = False self.run_name = self.experiment_name @@ -116,12 +113,15 @@ def __init__(self,project_name): class InferConfig(Config): - def __init__(self, project_name, output_model_name): super().__init__(project_name, output_model_name) # please change the following accordingly - self.data_types = ['Text'] # ["Text", "Table"] supported "Text", "Table" - self.load_dir = {"Text": os.path.join(self.root, "models", self.experiment_name, self.experiment_type, self.data_type, self.output_model_name)} + self.data_types = ["Text"] # ["Text", "Table"] supported "Text", "Table" + self.load_dir = { + "Text": os.path.join( + self.root, "models", self.experiment_name, self.experiment_type, self.data_type, self.output_model_name + ) + } # Use the following for the pre-trained models inside Docker # oneqbit_checkpoint_dir = os.path.join(self.root, "model_pipeline", "saved_models", "1QBit_Pretrained_ESG") @@ -130,12 +130,16 @@ def __init__(self, project_name, output_model_name): self.skip_processed_files = True # If set to True, will skip inferring on 
already processed files self.batch_size = 16 self.gpu = True - self.num_processes = None # Set to value of 1 (or 0) to disable multiprocessing. - # Set to None to let Inferencer use all CPU cores minus one. - self.disable_tqdm = True # To not see the progress bar at inference time, set to True - self.extracted_dir = os.path.join(self.root, "data", self.experiment_name, "interim", "ml", "extraction") - self.result_dir = {"Text": os.path.join(self.root, "data", self.experiment_name, "output", self.experiment_type, self.data_type)} + self.num_processes = None # Set to value of 1 (or 0) to disable multiprocessing. + # Set to None to let Inferencer use all CPU cores minus one. + self.disable_tqdm = True # To not see the progress bar at inference time, set to True + self.extracted_dir = os.path.join(self.root, "data", self.experiment_name, "interim", "ml", "extraction") + self.result_dir = { + "Text": os.path.join( + self.root, "data", self.experiment_name, "output", self.experiment_type, self.data_type + ) + } self.kpi_questions = [] # set to ["OG", "CM", "CU"] for KPIs of all sectors. - self.sectors = ["OG", "CM", "CU"] #["UT"] + self.sectors = ["OG", "CM", "CU"] # ["UT"] self.return_class_probs = False diff --git a/data_extractor/code/model_pipeline/model_pipeline/config_qa_farm_train.py b/data_extractor/code/model_pipeline/model_pipeline/config_qa_farm_train.py index 2f42aae..034433a 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/config_qa_farm_train.py +++ b/data_extractor/code/model_pipeline/model_pipeline/config_qa_farm_train.py @@ -16,29 +16,33 @@ class QAConfig(Config): def __init__(self, project_name, output_model_name=None): - super().__init__(experiment_type="KPI_EXTRACTION", project_name=project_name, output_model_name=output_model_name) + super().__init__( + experiment_type="KPI_EXTRACTION", project_name=project_name, output_model_name=output_model_name + ) class QAFileConfig(QAConfig): - def __init__(self, project_name, output_model_name): super().__init__(project_name, output_model_name) self.data_dir = os.path.join(self.root, "data") self.curated_data = os.path.join(self.data_dir, project_name, "interim", "ml", "training", "kpi_train.json") # If True, curated data will be split by dev_split ratio to train and val and saved in train_filename, # dev_filename . Otherwise train and val data will be loaded from mentioned filenames. 
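A minimal sketch of the splitting behaviour the comment above describes, modelled on QAFARMTrainer.prepare_data further down in this patch; the curated-data path, dev_split and seed values here are placeholders rather than values taken from the repository:

import json
import random

from sklearn.model_selection import train_test_split


def split_curated_squad(curated_data_path, dev_split=0.2, seed=42):
    # Load the curated SQuAD-style file and split it at the paragraph level,
    # mirroring what prepare_data() does when perform_splitting is True.
    with open(curated_data_path) as read_file:
        dataset = json.load(read_file)
    paragraphs = [
        {"title": pdf["title"], "single_paragraph": par}
        for pdf in dataset["data"]
        for par in pdf["paragraphs"]
    ]
    random.seed(seed)
    random.shuffle(paragraphs)
    # The ratio corresponds to QAFileConfig.dev_split; when perform_splitting is
    # False, the pre-split train/dev files are read from disk instead.
    return train_test_split(paragraphs, test_size=dev_split)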
- self.perform_splitting = True #was False initially - self.dev_split = .2 - self.train_filename = os.path.join(self.data_dir, project_name, "interim","ml", "training", "kpi_train_split.json") - self.dev_filename = os.path.join(self.data_dir, project_name, "interim","ml","training", "kpi_val_split.json") + self.perform_splitting = True # was False initially + self.dev_split = 0.2 + self.train_filename = os.path.join( + self.data_dir, project_name, "interim", "ml", "training", "kpi_train_split.json" + ) + self.dev_filename = os.path.join(self.data_dir, project_name, "interim", "ml", "training", "kpi_val_split.json") self.test_filename = None - self.saved_models_dir = os.path.join(self.root, "models", project_name, self.experiment_type, self.data_type, self.output_model_name) + self.saved_models_dir = os.path.join( + self.root, "models", project_name, self.experiment_type, self.data_type, self.output_model_name + ) self.annotation_dir = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "annotations") - self.training_dir = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "training") + self.training_dir = os.path.join(self.data_dir, self.experiment_name, "interim", "ml", "training") class QATokenizerConfig(QAConfig): - def __init__(self, project_name): super().__init__(project_name) self.pretrained_model_name_or_path = base_LM_model @@ -46,7 +50,6 @@ def __init__(self, project_name): class QAProcessorConfig(QAConfig): - def __init__(self, project_name): super().__init__(project_name) self.processor_name = "SquadProcessor" @@ -56,21 +59,19 @@ def __init__(self, project_name): class QAModelConfig(QAConfig): - def __init__(self, project_name): super().__init__(project_name) self.class_type = QuestionAnsweringHead self.head_config = {} # set to None if you don't want to load the config file for this model - self.load_dir = None #TODO: Should this really be None ? + self.load_dir = None # TODO: Should this really be None ? self.lang_model = base_LM_model self.layer_dims = [768, 2] self.lm_output_types = ["per_token"] class QATrainingConfig(QAConfig): - def __init__(self, project_name, seed): super().__init__(project_name) self.seed = seed @@ -97,7 +98,6 @@ def __init__(self, project_name, seed): class QAMLFlowConfig(QAConfig): - def __init__(self, project_name): super().__init__(project_name) self.track_experiment = False @@ -106,19 +106,24 @@ def __init__(self, project_name): class QAInferConfig(QAConfig): - def __init__(self, project_name, output_model_name): super().__init__(project_name, output_model_name) # please change the following accordingly self.data_types = ["Text"] self.skip_processed_files = False # If set to True, will skip inferring on already processed files self.top_k = 4 - self.result_dir = {"Text": os.path.join(self.root, "data", project_name, "output", self.experiment_type, "ml", "Text")} + self.result_dir = { + "Text": os.path.join(self.root, "data", project_name, "output", self.experiment_type, "ml", "Text") + } # Parameters for text inference - self.load_dir = {"Text": os.path.join(self.root, "models", project_name, self.experiment_type, "Text", self.output_model_name)} + self.load_dir = { + "Text": os.path.join( + self.root, "models", project_name, self.experiment_type, "Text", self.output_model_name + ) + } self.batch_size = 16 self.gpu = True # Set to value 1 (or 0) to disable multiprocessing. Set to None to let Inferencer use all CPU cores minus one. 
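As a rough usage sketch only: the settings gathered in QAInferConfig are consumed by FARM's QAInferencer, as in the QAInferencer.load call in metrics_per_kpi.py earlier in this patch. The import path and the project/model names below are assumptions for illustration:

import torch

from farm.infer import QAInferencer  # assumed import path for FARM's QA inferencer
from model_pipeline.config_qa_farm_train import QAInferConfig

qa_infer_config = QAInferConfig("my_project", "my_model")  # placeholder names
inferencer = QAInferencer.load(
    qa_infer_config.load_dir["Text"],          # per-data-type model directory
    batch_size=qa_infer_config.batch_size,
    gpu=torch.cuda.is_available(),
)
# num_processes and no_ans_boost (set just below) are forwarded to the inferencer
# in the same way, where the installed FARM version supports those arguments.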
self.num_processes = None - self.no_ans_boost = -15 # If increased, this will boost "No Answer" as prediction. + self.no_ans_boost = -15 # If increased, this will boost "No Answer" as prediction. # use large negative values (like -100) to disable giving "No answer" option. diff --git a/data_extractor/code/model_pipeline/model_pipeline/farm_trainer.py b/data_extractor/code/model_pipeline/model_pipeline/farm_trainer.py index 99f576f..bf2ffb7 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/farm_trainer.py +++ b/data_extractor/code/model_pipeline/model_pipeline/farm_trainer.py @@ -54,9 +54,7 @@ def prepare_data(self): Split data between training set and development set according to split ratio and save sets to .csv files """ - if os.path.exists(self.file_config.train_filename) and os.path.exists( - self.file_config.dev_filename - ): + if os.path.exists(self.file_config.train_filename) and os.path.exists(self.file_config.dev_filename): pass data = pd.read_csv(self.file_config.curated_data) @@ -67,9 +65,7 @@ def prepare_data(self): data.dropna(how="any", inplace=True) data.drop_duplicates(inplace=True) data = shuffle(data) - data_train, data_dev = train_test_split( - data, test_size=self.processor_config.dev_split - ) + data_train, data_dev = train_test_split(data, test_size=self.processor_config.dev_split) data_train.to_csv(self.file_config.train_filename) data_dev.to_csv(self.file_config.dev_filename) @@ -114,7 +110,7 @@ def create_silo(self, processor): processor=processor, batch_size=self.training_config.batch_size, distributed=self.training_config.distributed, - max_processes=self.training_config.max_processes + max_processes=self.training_config.max_processes, ) n_batches = len(data_silo.loaders["train"]) return data_silo, n_batches @@ -163,7 +159,7 @@ def create_model(self, prediction_head, n_batches, device): device=device, n_batches=n_batches, n_epochs=self.training_config.n_epochs, - grad_acc_steps=self.training_config.grad_acc_steps + grad_acc_steps=self.training_config.grad_acc_steps, ) return model, optimizer, lr_schedule @@ -206,7 +202,7 @@ def create_trainer(self, model, optimizer, lr_schedule, data_silo, device, n_gpu lr_schedule=lr_schedule, evaluate_every=self.training_config.evaluate_every, device=device, - grad_acc_steps=self.training_config.grad_acc_steps + grad_acc_steps=self.training_config.grad_acc_steps, ) return trainer @@ -226,12 +222,8 @@ def _train_on_split(self, data_silo, silo_to_use, num_fold, device, n_gpu): model, optimizer, lr_schedule = self.create_model( prediction_head, n_batches=len(silo_to_use.loaders["train"]), device=device ) - model.connect_heads_with_processor( - data_silo.processor.tasks, require_labels=True - ) - trainer = self.create_trainer( - model, optimizer, lr_schedule, silo_to_use, device, n_gpu - ) + model.connect_heads_with_processor(data_silo.processor.tasks, require_labels=True) + trainer = self.create_trainer(model, optimizer, lr_schedule, silo_to_use, device, n_gpu) trainer.train() return trainer.model @@ -260,9 +252,7 @@ def run_cv(self, data_silo, xval_folds, device, n_gpu): all_recall = [] all_accuracy = [] all_precision = [] - silos = DataSiloForCrossVal.make( - data_silo, sets=["train", "dev"], n_splits=xval_folds - ) + silos = DataSiloForCrossVal.make(data_silo, sets=["train", "dev"], n_splits=xval_folds) for num_fold, silo in enumerate(silos): model = self._train_on_split(data_silo, silo, num_fold, device, n_gpu) @@ -281,21 +271,11 @@ def run_cv(self, data_silo, xval_folds, device, n_gpu): 
all_recall.append(recall_score(preds, labels)) all_accuracy.append(result[0]["acc"]) all_precision.append(precision_score(preds, labels)) - _logger.info( - f"############ RESULT_CV -- {self.training_config.xval_folds} folds ############" - ) - _logger.info( - f"Mean F1: {np.mean(all_f1)*100:.1f}, std F1: {np.std(all_f1):.3f}" - ) - _logger.info( - f"Mean recall: {np.mean(all_recall)*100:.1f}, std recall: {np.std(all_recall):.3f}" - ) - _logger.info( - f"Mean accuracy: {np.mean(all_accuracy)*100:.1f}, std accuracy; {np.std(all_accuracy):.3f}" - ) - _logger.info( - f"Mean precision: {np.mean(all_precision)*100:.1f}, std precision: {np.std(all_precision):.3f}" - ) + _logger.info(f"############ RESULT_CV -- {self.training_config.xval_folds} folds ############") + _logger.info(f"Mean F1: {np.mean(all_f1)*100:.1f}, std F1: {np.std(all_f1):.3f}") + _logger.info(f"Mean recall: {np.mean(all_recall)*100:.1f}, std recall: {np.std(all_recall):.3f}") + _logger.info(f"Mean accuracy: {np.mean(all_accuracy)*100:.1f}, std accuracy; {np.std(all_accuracy):.3f}") + _logger.info(f"Mean precision: {np.mean(all_precision)*100:.1f}, std precision: {np.std(all_precision):.3f}") def run(self, trial=None): """ @@ -330,12 +310,8 @@ def run(self, trial=None): if self.training_config.run_hyp_tuning: prediction_head = self.create_head() - model, optimizer, lr_schedule = self.create_model( - prediction_head, n_batches, device - ) - trainer = self.create_trainer( - model, optimizer, lr_schedule, data_silo, device, n_gpu - ) + model, optimizer, lr_schedule = self.create_model(prediction_head, n_batches, device) + trainer = self.create_trainer(model, optimizer, lr_schedule, data_silo, device, n_gpu) trainer.train(trial) evaluator_dev = Evaluator( data_loader=data_silo.get_data_loader("dev"), @@ -343,21 +319,15 @@ def run(self, trial=None): device=device, ) result = evaluator_dev.eval(model, return_preds_and_labels=True) - evaluator_dev.log_results( - result, "DEV", logging=True, steps=len(data_silo.get_data_loader("dev")) - ) + evaluator_dev.log_results(result, "DEV", logging=True, steps=len(data_silo.get_data_loader("dev"))) elif self.training_config.run_cv: self.run_cv(data_silo, self.training_config.xval_folds, device, n_gpu) else: prediction_head = self.create_head() - model, optimizer, lr_schedule = self.create_model( - prediction_head, n_batches, device - ) - trainer = self.create_trainer( - model, optimizer, lr_schedule, data_silo, device, n_gpu - ) + model, optimizer, lr_schedule = self.create_model(prediction_head, n_batches, device) + trainer = self.create_trainer(model, optimizer, lr_schedule, data_silo, device, n_gpu) trainer.train() evaluator_dev = Evaluator( @@ -366,16 +336,12 @@ def run(self, trial=None): device=device, ) result = evaluator_dev.eval(model, return_preds_and_labels=True) - evaluator_dev.log_results( - result, "DEV", logging=True, steps=len(data_silo.get_data_loader("dev")) - ) + evaluator_dev.log_results(result, "DEV", logging=True, steps=len(data_silo.get_data_loader("dev"))) result = self.post_process_dev_results(result) model.save(self.file_config.saved_models_dir) processor.save(self.file_config.saved_models_dir) _logger.info(f"Trained model saved to {self.file_config.saved_models_dir}") - _logger.info( - f"Processor vocabulary saved to {self.file_config.saved_models_dir}" - ) + _logger.info(f"Processor vocabulary saved to {self.file_config.saved_models_dir}") return result diff --git a/data_extractor/code/model_pipeline/model_pipeline/inference_server.py 
b/data_extractor/code/model_pipeline/model_pipeline/inference_server.py index 025f3fe..f0115a5 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/inference_server.py +++ b/data_extractor/code/model_pipeline/model_pipeline/inference_server.py @@ -16,10 +16,22 @@ from model_pipeline.relevance_infer import TextRelevanceInfer from model_pipeline.text_kpi_infer import TextKPIInfer -from model_pipeline.config_farm_train import ModelConfig, TrainingConfig, FileConfig, MLFlowConfig, TokenizerConfig, \ - ProcessorConfig -from model_pipeline.config_qa_farm_train import QAModelConfig, QATrainingConfig, QAFileConfig, QAMLFlowConfig, \ - QATokenizerConfig, QAProcessorConfig +from model_pipeline.config_farm_train import ( + ModelConfig, + TrainingConfig, + FileConfig, + MLFlowConfig, + TokenizerConfig, + ProcessorConfig, +) +from model_pipeline.config_qa_farm_train import ( + QAModelConfig, + QATrainingConfig, + QAFileConfig, + QAMLFlowConfig, + QATokenizerConfig, + QAProcessorConfig, +) from model_pipeline.farm_trainer import FARMTrainer from model_pipeline.qa_farm_trainer import QAFARMTrainer from kpi_inference_data_pipeline import TextKPIInferenceCurator @@ -50,16 +62,14 @@ def create_directory(directory_name): if os.path.isfile(file_path): os.unlink(file_path) except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) + print("Failed to delete %s. Reason: %s" % (file_path, e)) def zipdir(path, ziph): # ziph is zipfile handle for root, dirs, files in os.walk(path): for file in files: - ziph.write(os.path.join(root, file), - os.path.relpath(os.path.join(root, file), - os.path.join(path, '..'))) + ziph.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(path, ".."))) @app.route("/liveness") @@ -67,13 +77,13 @@ def liveness(): return Response(response={}, status=200) -@app.route('/train_relevance/') +@app.route("/train_relevance/") def run_train_relevance(): - args = json.loads(request.args['payload']) + args = json.loads(request.args["payload"]) project_name = args["project_name"] relevance_training_settings = args["train_relevance"] - relevance_training_ouput_model_name = relevance_training_settings['output_model_name'] + relevance_training_ouput_model_name = relevance_training_settings["output_model_name"] free_memory() try: t1 = time.time() @@ -85,45 +95,51 @@ def run_train_relevance(): processor_config = ProcessorConfig(project_name) tokenizer_config = TokenizerConfig(project_name) - #Change the default settings + # Change the default settings processor_config.max_seq_len = relevance_training_settings["processor"]["proc_max_seq_len"] processor_config.dev_split = relevance_training_settings["processor"]["proc_dev_split"] processor_config.label_list = relevance_training_settings["processor"]["proc_label_list"] processor_config.label_column_name = relevance_training_settings["processor"]["proc_label_column_name"] processor_config.delimiter = relevance_training_settings["processor"]["proc_delimiter"] processor_config.metric = relevance_training_settings["processor"]["proc_metric"] - + model_config.layer_dims = relevance_training_settings["model"]["model_layer_dims"] model_config.layer_dims = relevance_training_settings["model"]["model_lm_output_types"] - + s3_usage = args["s3_usage"] if s3_usage: s3_settings = args["s3_settings"] - project_prefix_data = s3_settings['prefix'] + "/" + project_name + '/data' - project_prefix_project_models = s3_settings['prefix'] + "/" + project_name + '/models' + project_prefix_data = 
s3_settings["prefix"] + "/" + project_name + "/data" + project_prefix_project_models = s3_settings["prefix"] + "/" + project_name + "/models" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) create_directory(os.path.dirname(file_config.curated_data)) - s3c_interim.download_files_in_prefix_to_dir(project_prefix_data + '/interim/ml/curation', - os.path.dirname(file_config.curated_data)) - - training_folder = DATA_FOLDER / project_name / 'interim/ml/training' + s3c_interim.download_files_in_prefix_to_dir( + project_prefix_data + "/interim/ml/curation", os.path.dirname(file_config.curated_data) + ) + + training_folder = DATA_FOLDER / project_name / "interim/ml/training" create_directory(training_folder) - output_model_folder = os.path.join(str(MODEL_FOLDER), file_config.experiment_name, file_config.experiment_type, - file_config.data_type, file_config.output_model_name) + output_model_folder = os.path.join( + str(MODEL_FOLDER), + file_config.experiment_name, + file_config.experiment_type, + file_config.data_type, + file_config.output_model_name, + ) create_directory(output_model_folder) - + if relevance_training_settings["input_model_name"] is not None: base_model_dir = os.path.join(str(MODEL_FOLDER), project_name, "RELEVANCE", "Text") model_dir = os.path.join(base_model_dir, relevance_training_settings["input_model_name"]) @@ -131,12 +147,16 @@ def run_train_relevance(): tokenizer_config.pretrained_model_name_or_path = model_dir model_config.lang_model = model_dir processor_config.load_dir = model_dir - if args['s3_usage']: + if args["s3_usage"]: # Download model - model_rel_prefix = str(pathlib.Path(s3_settings['prefix']) / project_name / 'models' / 'RELEVANCE' / 'Text') + model_rel_prefix = str( + pathlib.Path(s3_settings["prefix"]) / project_name / "models" / "RELEVANCE" / "Text" + ) output_model_zip = model_dir + ".zip" - s3c_main.download_file_from_s3(output_model_zip, model_rel_prefix, relevance_training_settings["input_model_name"] + ".zip") - with zipfile.ZipFile(output_model_zip, 'r') as zip_ref: + s3c_main.download_file_from_s3( + output_model_zip, model_rel_prefix, relevance_training_settings["input_model_name"] + ".zip" + ) + with zipfile.ZipFile(output_model_zip, "r") as zip_ref: zip_ref.extractall(base_model_dir) os.remove(output_model_zip) else: @@ -144,7 
+164,7 @@ def run_train_relevance(): model_config.load_dir = None tokenizer_config.pretrained_model_name_or_path = relevance_training_settings["base_model"] processor_config.load_dir = None - + train_config.n_epochs = relevance_training_settings["training"]["n_epochs"] train_config.run_hyp_tuning = relevance_training_settings["training"]["run_hyp_tuning"] train_config.use_amp = relevance_training_settings["training"]["use_amp"] @@ -154,7 +174,9 @@ def run_train_relevance(): train_config.dropout = relevance_training_settings["training"]["dropout"] train_config.batch_size = relevance_training_settings["training"]["batch_size"] train_config.grad_acc_steps = relevance_training_settings["training"]["grad_acc_steps"] - train_config.run_cv = relevance_training_settings["training"]["run_cv"] # running cross-validation won't save a model + train_config.run_cv = relevance_training_settings["training"][ + "run_cv" + ] # running cross-validation won't save a model train_config.xval_folds = relevance_training_settings["training"]["xval_folds"] train_config.max_processes = relevance_training_settings["training"]["max_processes"] @@ -165,31 +187,42 @@ def run_train_relevance(): model_config=model_config, processor_config=processor_config, training_config=train_config, - mlflow_config=mlflow_config + mlflow_config=mlflow_config, ) result = farm_trainer.run() - + # save results to json file - result.pop('preds', None) - result.pop('labels', None) + result.pop("preds", None) + result.pop("labels", None) for key in result: result[key] = str(result[key]) - name_out = os.path.join(str(MODEL_FOLDER), project_name, "result_rel_" + relevance_training_settings['output_model_name'] + ".json") - with open(name_out, 'w') as f: + name_out = os.path.join( + str(MODEL_FOLDER), project_name, "result_rel_" + relevance_training_settings["output_model_name"] + ".json" + ) + with open(name_out, "w") as f: json.dump(result, f) - + if s3_usage: - train_rel_prefix = os.path.join(project_prefix_project_models, file_config.experiment_type, file_config.data_type) - output_model_zip = os.path.join(str(MODEL_FOLDER), file_config.experiment_name, file_config.experiment_type, - file_config.data_type, file_config.output_model_name + ".zip") - with zipfile.ZipFile(output_model_zip, 'w', zipfile.ZIP_DEFLATED) as zipf: - zipdir(output_model_folder, zipf) - response = s3c_main.upload_file_to_s3(filepath=output_model_zip, - s3_prefix=train_rel_prefix, - s3_key=file_config.output_model_name + ".zip") - response_2 = s3c_main.upload_file_to_s3(filepath=name_out, - s3_prefix=train_rel_prefix, - s3_key="result_rel_" + file_config.output_model_name + ".json") + train_rel_prefix = os.path.join( + project_prefix_project_models, file_config.experiment_type, file_config.data_type + ) + output_model_zip = os.path.join( + str(MODEL_FOLDER), + file_config.experiment_name, + file_config.experiment_type, + file_config.data_type, + file_config.output_model_name + ".zip", + ) + with zipfile.ZipFile(output_model_zip, "w", zipfile.ZIP_DEFLATED) as zipf: + zipdir(output_model_folder, zipf) + response = s3c_main.upload_file_to_s3( + filepath=output_model_zip, s3_prefix=train_rel_prefix, s3_key=file_config.output_model_name + ".zip" + ) + response_2 = s3c_main.upload_file_to_s3( + filepath=name_out, + s3_prefix=train_rel_prefix, + s3_key="result_rel_" + file_config.output_model_name + ".json", + ) create_directory(output_model_folder) create_directory(training_folder) create_directory(os.path.dirname(file_config.curated_data)) @@ -200,78 +233,83 @@ def 
run_train_relevance(): msg = "Error during kpi infer stage\nException:" + str(repr(e) + traceback.format_exc()) return Response(msg, status=500) - time_elapsed = str(timedelta(seconds=t2-t1)) + time_elapsed = str(timedelta(seconds=t2 - t1)) msg = "Training for the relevance stage finished successfully!\nTime elapsed:{}".format(time_elapsed) return Response(msg, status=200) -@app.route('/infer_relevance/') +@app.route("/infer_relevance/") def run_infer_relevance(): - args = json.loads(request.args['payload']) + args = json.loads(request.args["payload"]) project_name = args["project_name"] infer_relevance_settings = args["infer_relevance"] - - relevance_infer_config = InferConfig(project_name, args["train_relevance"]['output_model_name']) - relevance_infer_config.skip_processed_files = infer_relevance_settings['skip_processed_files'] - relevance_infer_config.batch_size = infer_relevance_settings['batch_size'] - relevance_infer_config.gpu = infer_relevance_settings['gpu'] - relevance_infer_config.num_processes = infer_relevance_settings['num_processes'] - relevance_infer_config.disable_tqdm = infer_relevance_settings['disable_tqdm'] - relevance_infer_config.kpi_questions = infer_relevance_settings['kpi_questions'] - relevance_infer_config.sectors = infer_relevance_settings['sectors'] - relevance_infer_config.return_class_probs = infer_relevance_settings['return_class_probs'] - - BASE_DATA_PROJECT_FOLDER = DATA_FOLDER / project_name - BASE_INTERIM_FOLDER = BASE_DATA_PROJECT_FOLDER / 'interim' / 'ml' - BASE_OUTPUT_FOLDER = BASE_DATA_PROJECT_FOLDER / 'output' / relevance_infer_config.experiment_type / relevance_infer_config.data_type - ANNOTATION_FOLDER = BASE_INTERIM_FOLDER / 'annotations' - EXTRACTION_FOLDER = BASE_INTERIM_FOLDER / 'extraction' + + relevance_infer_config = InferConfig(project_name, args["train_relevance"]["output_model_name"]) + relevance_infer_config.skip_processed_files = infer_relevance_settings["skip_processed_files"] + relevance_infer_config.batch_size = infer_relevance_settings["batch_size"] + relevance_infer_config.gpu = infer_relevance_settings["gpu"] + relevance_infer_config.num_processes = infer_relevance_settings["num_processes"] + relevance_infer_config.disable_tqdm = infer_relevance_settings["disable_tqdm"] + relevance_infer_config.kpi_questions = infer_relevance_settings["kpi_questions"] + relevance_infer_config.sectors = infer_relevance_settings["sectors"] + relevance_infer_config.return_class_probs = infer_relevance_settings["return_class_probs"] + + BASE_DATA_PROJECT_FOLDER = DATA_FOLDER / project_name + BASE_INTERIM_FOLDER = BASE_DATA_PROJECT_FOLDER / "interim" / "ml" + BASE_OUTPUT_FOLDER = ( + BASE_DATA_PROJECT_FOLDER / "output" / relevance_infer_config.experiment_type / relevance_infer_config.data_type + ) + ANNOTATION_FOLDER = BASE_INTERIM_FOLDER / "annotations" + EXTRACTION_FOLDER = BASE_INTERIM_FOLDER / "extraction" kpi_folder = os.path.join(DATA_FOLDER, project_name, "input", "kpi_mapping") - output_model_folder = os.path.join(str(MODEL_FOLDER), project_name, relevance_infer_config.experiment_type, - relevance_infer_config.data_type) + output_model_folder = os.path.join( + str(MODEL_FOLDER), project_name, relevance_infer_config.experiment_type, relevance_infer_config.data_type + ) create_directory(kpi_folder) create_directory(output_model_folder) create_directory(BASE_OUTPUT_FOLDER) create_directory(ANNOTATION_FOLDER) create_directory(EXTRACTION_FOLDER) - + s3_usage = args["s3_usage"] if s3_usage: s3_settings = args["s3_settings"] - project_prefix_data 
= s3_settings['prefix'] + "/" + project_name + '/data' - project_prefix_project_models = s3_settings['prefix'] + "/" + project_name + '/models' + project_prefix_data = s3_settings["prefix"] + "/" + project_name + "/data" + project_prefix_project_models = s3_settings["prefix"] + "/" + project_name + "/models" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) # Download kpi file - s3c_main.download_files_in_prefix_to_dir(project_prefix_data + '/input/kpi_mapping', kpi_folder) + s3c_main.download_files_in_prefix_to_dir(project_prefix_data + "/input/kpi_mapping", kpi_folder) # Download model - train_rel_prefix = os.path.join(project_prefix_project_models, relevance_infer_config.experiment_type, relevance_infer_config.data_type) - output_model_zip = os.path.join(output_model_folder, args["train_relevance"]['output_model_name'] + ".zip") - s3c_main.download_file_from_s3(output_model_zip, train_rel_prefix, args["train_relevance"]['output_model_name'] + ".zip") - with zipfile.ZipFile(output_model_zip, 'r') as zip_ref: + train_rel_prefix = os.path.join( + project_prefix_project_models, relevance_infer_config.experiment_type, relevance_infer_config.data_type + ) + output_model_zip = os.path.join(output_model_folder, args["train_relevance"]["output_model_name"] + ".zip") + s3c_main.download_file_from_s3( + output_model_zip, train_rel_prefix, args["train_relevance"]["output_model_name"] + ".zip" + ) + with zipfile.ZipFile(output_model_zip, "r") as zip_ref: zip_ref.extractall(output_model_folder) os.remove(output_model_zip) # Download extraction files - s3c_interim.download_files_in_prefix_to_dir(project_prefix_data + '/interim/ml/extraction', - EXTRACTION_FOLDER) - # Download annotation files - s3c_interim.download_files_in_prefix_to_dir(project_prefix_data + '/interim/ml/annotations', - ANNOTATION_FOLDER) - + s3c_interim.download_files_in_prefix_to_dir(project_prefix_data + "/interim/ml/extraction", EXTRACTION_FOLDER) + # Download annotation files + s3c_interim.download_files_in_prefix_to_dir(project_prefix_data + "/interim/ml/annotations", ANNOTATION_FOLDER) + shutil.copyfile(os.path.join(kpi_folder, "kpi_mapping.csv"), "/app/code/kpi_mapping.csv") free_memory() @@ -285,93 +323,111 @@ def run_infer_relevance(): except Exception as e: msg = "Error during kpi infer 
stage\nException:" + str(repr(e) + traceback.format_exc()) return Response(msg, status=500) - + if s3_usage: - project_prefix_project_output = pathlib.Path(s3_settings['prefix'] + "/" + project_name + '/data/output') \ - / relevance_infer_config.experiment_type / relevance_infer_config.data_type - s3c_main.upload_files_in_dir_to_prefix(BASE_OUTPUT_FOLDER, project_prefix_project_output) + project_prefix_project_output = ( + pathlib.Path(s3_settings["prefix"] + "/" + project_name + "/data/output") + / relevance_infer_config.experiment_type + / relevance_infer_config.data_type + ) + s3c_main.upload_files_in_dir_to_prefix(BASE_OUTPUT_FOLDER, project_prefix_project_output) create_directory(kpi_folder) - create_directory(str(pathlib.Path(output_model_folder) / args["train_relevance"]['output_model_name'])) + create_directory(str(pathlib.Path(output_model_folder) / args["train_relevance"]["output_model_name"])) create_directory(BASE_OUTPUT_FOLDER) create_directory(ANNOTATION_FOLDER) create_directory(EXTRACTION_FOLDER) - - time_elapsed = str(timedelta(seconds=t2-t1)) + + time_elapsed = str(timedelta(seconds=t2 - t1)) msg = "Inference for the relevance stage finished successfully!\nTime elapsed:{}".format(time_elapsed) return Response(msg, status=200) -@app.route('/train_kpi/') +@app.route("/train_kpi/") def run_train_kpi(): - args = json.loads(request.args['payload']) + args = json.loads(request.args["payload"]) project_name = args["project_name"] kpi_inference_training_settings = args["train_kpi"] - file_config = QAFileConfig(project_name, kpi_inference_training_settings['output_model_name']) + file_config = QAFileConfig(project_name, kpi_inference_training_settings["output_model_name"]) config.TextKPIInferenceCurator_kwargs = { "annotation_folder": DATA_FOLDER / project_name / "interim" / "ml" / "annotations", "agg_annotation": DATA_FOLDER / project_name / "interim" / "ml" / "annotations" / "aggregated_annotation.csv", "extracted_text_json_folder": DATA_FOLDER / project_name / "interim" / "ml" / "extraction", - "output_squad_folder": DATA_FOLDER / project_name / "interim" / "ml" / "training", - "relevant_text_path": DATA_FOLDER / project_name / "interim" / "ml" / "text_3434.csv" + "output_squad_folder": DATA_FOLDER / project_name / "interim" / "ml" / "training", + "relevant_text_path": DATA_FOLDER / project_name / "interim" / "ml" / "text_3434.csv", } - config.seed = kpi_inference_training_settings['seed'] - + config.seed = kpi_inference_training_settings["seed"] + free_memory() try: t1 = time.time() - + tkpi = TextKPIInferenceCurator(**config.TextKPIInferenceCurator_kwargs) tkpi.annotation_folder = file_config.annotation_dir tkpi.agg_annotation = os.path.join(file_config.annotation_dir, "aggregated_annotation.csv") tkpi.output_squad_folder = file_config.training_dir tkpi.relevant_text_path = os.path.join(file_config.data_dir, project_name, "interim", "ml", "text_3434.csv") curation_input = kpi_inference_training_settings["curation"] - + create_directory(tkpi.annotation_folder) create_directory(str(pathlib.Path(file_config.data_dir) / project_name / "interim" / "ml")) create_directory(tkpi.output_squad_folder) - - if args['s3_usage']: + + if args["s3_usage"]: s3_settings = args["s3_settings"] s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - 
s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), - ) + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), + ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), - ) - project_prefix_data = pathlib.Path(s3_settings['prefix']) / project_name / 'data' - project_prefix_agg_annotations = str(project_prefix_data / 'interim' / 'ml' / 'annotations') - + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), + ) + project_prefix_data = pathlib.Path(s3_settings["prefix"]) / project_name / "data" + project_prefix_agg_annotations = str(project_prefix_data / "interim" / "ml" / "annotations") + # Download aggregated annotations file - s3c_interim.download_file_from_s3(filepath=tkpi.agg_annotation, - s3_prefix=project_prefix_agg_annotations, - s3_key='aggregated_annotation.csv') - + s3c_interim.download_file_from_s3( + filepath=tkpi.agg_annotation, + s3_prefix=project_prefix_agg_annotations, + s3_key="aggregated_annotation.csv", + ) + # Download kpi file - s3c_main.download_file_from_s3('/app/code/kpi_mapping.csv', s3_prefix=str(project_prefix_data / 'input' / 'kpi_mapping'), s3_key='kpi_mapping.csv') - + s3c_main.download_file_from_s3( + "/app/code/kpi_mapping.csv", + s3_prefix=str(project_prefix_data / "input" / "kpi_mapping"), + s3_key="kpi_mapping.csv", + ) + # Download text_3434 file - s3c_interim.download_file_from_s3(tkpi.relevant_text_path, s3_prefix=str(project_prefix_data / 'interim' / 'ml'), s3_key='text_3434.csv') - + s3c_interim.download_file_from_s3( + tkpi.relevant_text_path, s3_prefix=str(project_prefix_data / "interim" / "ml"), s3_key="text_3434.csv" + ) + # Download extractions - s3c_interim.download_files_in_prefix_to_dir(str(project_prefix_data / 'interim' / 'ml' / 'extraction'), - str(config.TextKPIInferenceCurator_kwargs['extracted_text_json_folder'])) - - _, _ = tkpi.curate(curation_input["val_ratio"], curation_input["seed"], curation_input["find_new_answerable"], curation_input["create_unanswerable"]) + s3c_interim.download_files_in_prefix_to_dir( + str(project_prefix_data / "interim" / "ml" / "extraction"), + str(config.TextKPIInferenceCurator_kwargs["extracted_text_json_folder"]), + ) + + _, _ = tkpi.curate( + curation_input["val_ratio"], + curation_input["seed"], + curation_input["find_new_answerable"], + curation_input["create_unanswerable"], + ) + + print("Curation step in train_kpi done.") - print('Curation step in train_kpi done.') - free_memory() - train_config = QATrainingConfig(project_name, kpi_inference_training_settings['seed']) + train_config = QATrainingConfig(project_name, kpi_inference_training_settings["seed"]) model_config = QAModelConfig(project_name) mlflow_config = QAMLFlowConfig(project_name) processor_config = QAProcessorConfig(project_name) @@ -398,19 +454,23 @@ def 
run_train_kpi(): processor_config.max_seq_len = kpi_inference_training_settings["processor"]["max_seq_len"] processor_config.label_list = kpi_inference_training_settings["processor"]["label_list"] processor_config.metric = kpi_inference_training_settings["processor"]["metric"] - + if kpi_inference_training_settings["input_model_name"] is not None: base_model_dir = os.path.join(str(MODEL_FOLDER), project_name, "KPI_EXTRACTION", "Text") model_dir = os.path.join(base_model_dir, kpi_inference_training_settings["input_model_name"]) model_config.load_dir = model_dir tokenizer_config.pretrained_model_name_or_path = model_dir model_config.lang_model = model_dir - if args['s3_usage']: + if args["s3_usage"]: # Download model - model_inf_prefix = str(pathlib.Path(s3_settings['prefix']) / project_name / 'models' / 'KPI_EXTRACTION' / 'Text') + model_inf_prefix = str( + pathlib.Path(s3_settings["prefix"]) / project_name / "models" / "KPI_EXTRACTION" / "Text" + ) output_model_zip = model_dir + ".zip" - s3c_main.download_file_from_s3(output_model_zip, model_inf_prefix, kpi_inference_training_settings["input_model_name"] + ".zip") - with zipfile.ZipFile(output_model_zip, 'r') as zip_ref: + s3c_main.download_file_from_s3( + output_model_zip, model_inf_prefix, kpi_inference_training_settings["input_model_name"] + ".zip" + ) + with zipfile.ZipFile(output_model_zip, "r") as zip_ref: zip_ref.extractall(base_model_dir) os.remove(output_model_zip) else: @@ -420,105 +480,127 @@ def run_train_kpi(): farm_trainer_class = QAFARMTrainer farm_trainer = farm_trainer_class( - file_config =file_config, + file_config=file_config, tokenizer_config=tokenizer_config, model_config=model_config, processor_config=processor_config, training_config=train_config, - mlflow_config=mlflow_config + mlflow_config=mlflow_config, ) - + result = farm_trainer.run() - + # save results to json file - result.pop('preds', None) - result.pop('labels', None) + result.pop("preds", None) + result.pop("labels", None) for key in result: result[key] = str(result[key]) - name_out = os.path.join(str(MODEL_FOLDER), project_name, "result_kpi_" + kpi_inference_training_settings['output_model_name'] + ".json") - with open(name_out, 'w') as f: + name_out = os.path.join( + str(MODEL_FOLDER), + project_name, + "result_kpi_" + kpi_inference_training_settings["output_model_name"] + ".json", + ) + with open(name_out, "w") as f: json.dump(result, f) - - if args['s3_usage']: - print('Next we store the files to S3.') - train_inf_prefix = str(pathlib.Path(s3_settings['prefix']) / project_name / 'models' / 'KPI_EXTRACTION' / 'Text') - output_model_folder = str(pathlib.Path(MODEL_FOLDER) / project_name / "KPI_EXTRACTION" / "Text" / kpi_inference_training_settings['output_model_name']) - print('First we zip the model. This can take some time.') - with zipfile.ZipFile(output_model_folder + ".zip", 'w', zipfile.ZIP_DEFLATED) as zipf: - zipdir(output_model_folder, zipf) - print('Next we upload the model to S3. 
This can take some time.') - response = s3c_main.upload_file_to_s3(filepath=output_model_folder + ".zip", - s3_prefix=train_inf_prefix, - s3_key=kpi_inference_training_settings['output_model_name'] + ".zip") + + if args["s3_usage"]: + print("Next we store the files to S3.") + train_inf_prefix = str( + pathlib.Path(s3_settings["prefix"]) / project_name / "models" / "KPI_EXTRACTION" / "Text" + ) + output_model_folder = str( + pathlib.Path(MODEL_FOLDER) + / project_name + / "KPI_EXTRACTION" + / "Text" + / kpi_inference_training_settings["output_model_name"] + ) + print("First we zip the model. This can take some time.") + with zipfile.ZipFile(output_model_folder + ".zip", "w", zipfile.ZIP_DEFLATED) as zipf: + zipdir(output_model_folder, zipf) + print("Next we upload the model to S3. This can take some time.") + response = s3c_main.upload_file_to_s3( + filepath=output_model_folder + ".zip", + s3_prefix=train_inf_prefix, + s3_key=kpi_inference_training_settings["output_model_name"] + ".zip", + ) os.remove(output_model_folder + ".zip") - print('Next we upload the interim training files and the statistics to S3.') - response_2 = s3c_main.upload_file_to_s3(filepath=name_out, - s3_prefix=train_inf_prefix, - s3_key="result_kpi_" + kpi_inference_training_settings['output_model_name'] + ".json") - s3c_interim.upload_files_in_dir_to_prefix(source_dir=str(DATA_FOLDER / project_name / 'interim' / 'ml' / 'training'), - s3_prefix=str(pathlib.Path(s3_settings['prefix']) / project_name / 'data' / 'interim' / 'ml' / 'training')) + print("Next we upload the interim training files and the statistics to S3.") + response_2 = s3c_main.upload_file_to_s3( + filepath=name_out, + s3_prefix=train_inf_prefix, + s3_key="result_kpi_" + kpi_inference_training_settings["output_model_name"] + ".json", + ) + s3c_interim.upload_files_in_dir_to_prefix( + source_dir=str(DATA_FOLDER / project_name / "interim" / "ml" / "training"), + s3_prefix=str( + pathlib.Path(s3_settings["prefix"]) / project_name / "data" / "interim" / "ml" / "training" + ), + ) create_directory(output_model_folder) create_directory(output_model_folder) create_directory(os.path.join(str(MODEL_FOLDER), project_name)) create_directory(tkpi.annotation_folder) create_directory(str(pathlib.Path(file_config.data_dir) / project_name / "interim" / "ml")) - create_directory(str(project_prefix_data / 'input')) + create_directory(str(project_prefix_data / "input")) create_directory(tkpi.output_squad_folder) - create_directory(str(config.TextKPIInferenceCurator_kwargs['extracted_text_json_folder'])) + create_directory(str(config.TextKPIInferenceCurator_kwargs["extracted_text_json_folder"])) if kpi_inference_training_settings["input_model_name"] is not None: create_directory(model_dir) - + t2 = time.time() except Exception as e: msg = "Error during kpi infer stage\nException:" + str(repr(e) + traceback.format_exc()) return Response(msg, status=500) - time_elapsed = str(timedelta(seconds=t2-t1)) + time_elapsed = str(timedelta(seconds=t2 - t1)) msg = "Training for the kpi extraction stage finished successfully!\nTime elapsed:{}".format(time_elapsed) return Response(msg, status=200) -@app.route('/infer_kpi/') +@app.route("/infer_kpi/") def run_infer_kpi(): - args = json.loads(request.args['payload']) + args = json.loads(request.args["payload"]) project_name = args["project_name"] - relevance_infer_config = InferConfig(project_name, args["train_relevance"]['output_model_name']) - qa_infer_config = QAInferConfig(project_name, args["train_kpi"]['output_model_name']) + 
relevance_infer_config = InferConfig(project_name, args["train_relevance"]["output_model_name"]) + qa_infer_config = QAInferConfig(project_name, args["train_kpi"]["output_model_name"]) + + qa_infer_config.skip_processed_files = args["infer_kpi"]["skip_processed_files"] + qa_infer_config.top_k = args["infer_kpi"]["top_k"] + qa_infer_config.batch_size = args["infer_kpi"]["batch_size"] + qa_infer_config.num_processes = args["infer_kpi"]["num_processes"] + qa_infer_config.no_ans_boost = args["infer_kpi"]["no_ans_boost"] - qa_infer_config.skip_processed_files = args["infer_kpi"]['skip_processed_files'] - qa_infer_config.top_k = args["infer_kpi"]['top_k'] - qa_infer_config.batch_size = args["infer_kpi"]['batch_size'] - qa_infer_config.num_processes = args["infer_kpi"]['num_processes'] - qa_infer_config.no_ans_boost = args["infer_kpi"]['no_ans_boost'] - s3_usage = args["s3_usage"] if s3_usage: s3_settings = args["s3_settings"] # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) # Download model - project_prefix_project_models = s3_settings['prefix'] + "/" + project_name + '/models' - train_inf_prefix = os.path.join(project_prefix_project_models, 'KPI_EXTRACTION', qa_infer_config.data_type) - output_model_folder = str(MODEL_FOLDER / project_name / 'KPI_EXTRACTION' / qa_infer_config.data_type) + project_prefix_project_models = s3_settings["prefix"] + "/" + project_name + "/models" + train_inf_prefix = os.path.join(project_prefix_project_models, "KPI_EXTRACTION", qa_infer_config.data_type) + output_model_folder = str(MODEL_FOLDER / project_name / "KPI_EXTRACTION" / qa_infer_config.data_type) create_directory(output_model_folder) - output_model_zip = os.path.join(output_model_folder, args["train_kpi"]['output_model_name'] + ".zip") - s3c_main.download_file_from_s3(output_model_zip, train_inf_prefix, args["train_kpi"]['output_model_name'] + ".zip") - with zipfile.ZipFile(output_model_zip, 'r') as zip_ref: + output_model_zip = os.path.join(output_model_folder, args["train_kpi"]["output_model_name"] + ".zip") + s3c_main.download_file_from_s3( + output_model_zip, train_inf_prefix, args["train_kpi"]["output_model_name"] + ".zip" + ) + with zipfile.ZipFile(output_model_zip, "r") as zip_ref: zip_ref.extractall(output_model_folder) os.remove(output_model_zip) - + free_memory() try: t1 = 
time.time() @@ -527,38 +609,39 @@ def run_infer_kpi(): create_directory(relevance_result_dir) if s3_usage: # Download relevance output - project_prefix_output = pathlib.Path(s3_settings['prefix']) / project_name / 'data' / 'output' - s3c_main.download_files_in_prefix_to_dir(str(project_prefix_output / 'RELEVANCE' / data_type), relevance_result_dir) + project_prefix_output = pathlib.Path(s3_settings["prefix"]) / project_name / "data" / "output" + s3c_main.download_files_in_prefix_to_dir( + str(project_prefix_output / "RELEVANCE" / data_type), relevance_result_dir + ) kpi_infer_component_class = CLASS_DATA_TYPE_KPI[data_type] kpi_infer_component_obj = kpi_infer_component_class(qa_infer_config) result_kpi = kpi_infer_component_obj.infer_on_relevance_results(relevance_result_dir) if s3_usage: # Upload kpi inference output - output_results_folder = str(DATA_FOLDER / project_name / 'output' / 'KPI_EXTRACTION' / 'ml' / data_type) - project_prefix_output = pathlib.Path(s3_settings['prefix']) / project_name / 'data' / 'output' - s3c_main.upload_files_in_dir_to_prefix(output_results_folder, str(project_prefix_output / 'KPI_EXTRACTION' / 'ml' / data_type)) + output_results_folder = str(DATA_FOLDER / project_name / "output" / "KPI_EXTRACTION" / "ml" / data_type) + project_prefix_output = pathlib.Path(s3_settings["prefix"]) / project_name / "data" / "output" + s3c_main.upload_files_in_dir_to_prefix( + output_results_folder, str(project_prefix_output / "KPI_EXTRACTION" / "ml" / data_type) + ) create_directory(relevance_result_dir) create_directory(output_results_folder) t2 = time.time() except Exception as e: msg = "Error during kpi infer stage\nException:" + str(repr(e) + traceback.format_exc()) return Response(msg, status=500) - + if s3_usage: - create_directory(output_model_folder + "/" + args["train_kpi"]['output_model_name']) - - time_elapsed = str(timedelta(seconds=t2-t1)) + create_directory(output_model_folder + "/" + args["train_kpi"]["output_model_name"]) + + time_elapsed = str(timedelta(seconds=t2 - t1)) msg = "Inference for the kpi extraction stage finished successfully!\nTime elapsed:{}".format(time_elapsed) return Response(msg, status=200) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='inference server') + parser = argparse.ArgumentParser(description="inference server") # Add the arguments - parser.add_argument('--port', - type=int, - default=6000, - help='port to use for the infer server') + parser.add_argument("--port", type=int, default=6000, help="port to use for the infer server") args = parser.parse_args() port = args.port app.run(host="0.0.0.0", port=port) diff --git a/data_extractor/code/model_pipeline/model_pipeline/optuna_hyp.py b/data_extractor/code/model_pipeline/model_pipeline/optuna_hyp.py index da24baa..48187c5 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/optuna_hyp.py +++ b/data_extractor/code/model_pipeline/model_pipeline/optuna_hyp.py @@ -1,16 +1,25 @@ import optuna import model_pipeline -from model_pipeline import FARMTrainer, ModelConfig, FileConfig, TokenizerConfig, MLFlowConfig, ProcessorConfig, TrainingConfig +from model_pipeline import ( + FARMTrainer, + ModelConfig, + FileConfig, + TokenizerConfig, + MLFlowConfig, + ProcessorConfig, + TrainingConfig, +) + def objective(trial): # Uniform parameter - dropout_rate = trial.suggest_uniform('dropout_rate', 0.0, 1.0) + dropout_rate = trial.suggest_uniform("dropout_rate", 0.0, 1.0) - num_epochs = trial.suggest_int('num_epochs', 1, 5, 1) - batch_size = 
trial.suggest_int('batch_size', 4, 32, 4) + num_epochs = trial.suggest_int("num_epochs", 1, 5, 1) + batch_size = trial.suggest_int("batch_size", 4, 32, 4) # Loguniform parameter - learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-2) + learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-2) file_config = FileConfig() train_config = TrainingConfig() @@ -27,17 +36,18 @@ def objective(trial): tokenizer_config = TokenizerConfig() farm_trainer = FARMTrainer( - file_config =file_config, - tokenizer_config=tokenizer_config, - model_config=model_config, - processor_config=processor_config, - training_config=train_config, - mlflow_config=mlflow_config - ) + file_config=file_config, + tokenizer_config=tokenizer_config, + model_config=model_config, + processor_config=processor_config, + training_config=train_config, + mlflow_config=mlflow_config, + ) acc = farm_trainer.run(trial) return acc + if __name__ == "__main__": study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner()) study.optimize(objective, n_trials=100) diff --git a/data_extractor/code/model_pipeline/model_pipeline/qa_farm_trainer.py b/data_extractor/code/model_pipeline/model_pipeline/qa_farm_trainer.py index 1a213c7..8c6186b 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/qa_farm_trainer.py +++ b/data_extractor/code/model_pipeline/model_pipeline/qa_farm_trainer.py @@ -37,14 +37,19 @@ def __init__( :param mlflow_config: config object which sets MLflow parameters to monitor training :param model_config: config object which sets FARM AdaptiveModel parameters """ - super().__init__(file_config=file_config, tokenizer_config=tokenizer_config, processor_config=processor_config, - model_config=model_config, training_config=training_config, mlflow_config=mlflow_config) + super().__init__( + file_config=file_config, + tokenizer_config=tokenizer_config, + processor_config=processor_config, + model_config=model_config, + training_config=training_config, + mlflow_config=mlflow_config, + ) if self.file_config.data_type != "Text": raise ValueError("only `Text` is supported for QA.") def prepare_data(self): - if self.file_config.perform_splitting: _logger.info("Loading the {} data and splitting to train and val...".format(self.file_config.curated_data)) @@ -52,8 +57,9 @@ def prepare_data(self): dataset = json.load(read_file) # Splitting # create a list of all paragraphs, The splitting will happen at the paragraph level - paragraphs_list = [{'title': pdf['title'], 'single_paragraph': par} for pdf in dataset['data'] for par in - pdf["paragraphs"]] + paragraphs_list = [ + {"title": pdf["title"], "single_paragraph": par} for pdf in dataset["data"] for par in pdf["paragraphs"] + ] random.seed(self.training_config.seed) random.shuffle(paragraphs_list) train_paragraphs, dev_paragraphs = train_test_split(paragraphs_list, test_size=self.file_config.dev_split) @@ -63,27 +69,29 @@ def reformat_paragraphs(paragraphs): paragraphs_dict = defaultdict(list) for par in paragraphs: - paragraphs_dict[par['title']].append(par['single_paragraph']) + paragraphs_dict[par["title"]].append(par["single_paragraph"]) - squad_like_dataset = [{'title': key, 'paragraphs': value} for key, value in paragraphs_dict.items()] + squad_like_dataset = [{"title": key, "paragraphs": value} for key, value in paragraphs_dict.items()] return squad_like_dataset - train_data = {'version': 'v2.0', 'data': reformat_paragraphs(train_paragraphs)} - dev_data = {'version': 'v2.0', 'data': reformat_paragraphs(dev_paragraphs)} + 
train_data = {"version": "v2.0", "data": reformat_paragraphs(train_paragraphs)} + dev_data = {"version": "v2.0", "data": reformat_paragraphs(dev_paragraphs)} - with open(self.file_config.train_filename, 'w') as outfile: + with open(self.file_config.train_filename, "w") as outfile: json.dump(train_data, outfile) - with open(self.file_config.dev_filename, 'w') as outfile: + with open(self.file_config.dev_filename, "w") as outfile: json.dump(dev_data, outfile) else: - _logger.info("Loading the train from {} \n Loading validation data from {}". - format(self.file_config.train_filename, self.file_config.dev_filename)) + _logger.info( + "Loading the train from {} \n Loading validation data from {}".format( + self.file_config.train_filename, self.file_config.dev_filename + ) + ) for filename in [self.file_config.train_filename, self.file_config.dev_filename]: assert os.path.exists(filename), "File `{}` does not exist.".format(filename) - def create_head(self): if "squad" in self.model_config.lang_model: return QuestionAnsweringHead.load(self.model_config.lang_model) @@ -118,9 +126,7 @@ def run_cv(self, data_silo, xval_folds, device, n_gpu): all_em_answerable = [] all_f1_answerable = [] - silos = DataSiloForCrossVal.make( - data_silo, sets=["train", "dev"], n_splits=xval_folds - ) + silos = DataSiloForCrossVal.make(data_silo, sets=["train", "dev"], n_splits=xval_folds) for num_fold, silo in enumerate(silos): model = self._train_on_split(data_silo, silo, num_fold, device, n_gpu) @@ -141,15 +147,9 @@ def run_cv(self, data_silo, xval_folds, device, n_gpu): all_em_answerable.append(result["em_answerable"]) all_f1_answerable.append(result["f1_answerable"]) - _logger.info( - f"############ RESULT_CV -- {self.training_config.xval_folds} folds ############" - ) - _logger.info( - f"EM\nMean: {np.mean(all_em) * 100:.1f}, std: {np.std(all_em) * 100:.3f}" - ) - _logger.info( - f"F1\nMean: {np.mean(all_f1) * 100:.1f}, std F1: {np.std(all_f1) * 100:.3f}" - ) + _logger.info(f"############ RESULT_CV -- {self.training_config.xval_folds} folds ############") + _logger.info(f"EM\nMean: {np.mean(all_em) * 100:.1f}, std: {np.std(all_em) * 100:.3f}") + _logger.info(f"F1\nMean: {np.mean(all_f1) * 100:.1f}, std F1: {np.std(all_f1) * 100:.3f}") _logger.info( f"EM_Answerable\nMean: {np.mean(all_em_answerable) * 100:.1f}, std recall: {np.std(all_em_answerable) * 100:.3f}" ) diff --git a/data_extractor/code/model_pipeline/model_pipeline/relevance_infer.py b/data_extractor/code/model_pipeline/model_pipeline/relevance_infer.py index f557aa6..ff892ba 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/relevance_infer.py +++ b/data_extractor/code/model_pipeline/model_pipeline/relevance_infer.py @@ -16,7 +16,7 @@ class BaseRelevanceInfer(ABC): - """ An abstract base class for predicting relevant data for given question(s). + """An abstract base class for predicting relevant data for given question(s). The `run_folder` is the main method for this class and its children. @@ -35,16 +35,17 @@ def __init__(self, infer_config): # Filter KPIs based on section and whether they can be found in text or table. 
self.infer_config.sectors = kpi_mapping.KPI_SECTORS self.questions = [ - q_text + q_text for q_id, (q_text, sect) in kpi_mapping.KPI_MAPPING.items() - if len(set(sect).intersection(set(self.infer_config.sectors))) > 0 and self.data_type.upper() in kpi_mapping.KPI_CATEGORY[q_id] + if len(set(sect).intersection(set(self.infer_config.sectors))) > 0 + and self.data_type.upper() in kpi_mapping.KPI_CATEGORY[q_id] ] self.result_dir = self.infer_config.result_dir[self.data_type] if not os.path.exists(self.result_dir): os.makedirs(self.result_dir) - farm_logger = logging.getLogger('farm') + farm_logger = logging.getLogger("farm") farm_logger.setLevel(self.infer_config.farm_infer_logging_level) self.model = Inferencer.load( self.infer_config.load_dir[self.data_type], @@ -52,7 +53,7 @@ def __init__(self, infer_config): gpu=self.infer_config.gpu, num_processes=self.infer_config.num_processes, disable_tqdm=self.infer_config.disable_tqdm, - return_class_probs=self.infer_config.return_class_probs + return_class_probs=self.infer_config.return_class_probs, ) def run_folder(self): @@ -63,10 +64,13 @@ def run_folder(self): all_text_path_dict = self._gather_extracted_files() df_list = [] num_pdfs = len(all_text_path_dict) - _logger.info("{} Starting Relevence Inference for the following extracted pdf files found in {}:\n{} ". - format("#" * 20, self.result_dir, [pdf for pdf in all_text_path_dict.keys()])) + _logger.info( + "{} Starting Relevence Inference for the following extracted pdf files found in {}:\n{} ".format( + "#" * 20, self.result_dir, [pdf for pdf in all_text_path_dict.keys()] + ) + ) for i, (pdf_name, file_path) in enumerate(all_text_path_dict.items()): - _logger.info("{} {}/{} PDFs".format("#" * 20, i+1, num_pdfs)) + _logger.info("{} {}/{} PDFs".format("#" * 20, i + 1, num_pdfs)) predictions_file_name = "{}_{}".format(pdf_name, "predictions_relevant.csv") if self.infer_config.skip_processed_files and predictions_file_name in os.listdir(self.result_dir): _logger.info("The relevance infer results for {} already exists. Skipping.".format(pdf_name)) @@ -84,15 +88,13 @@ def run_folder(self): chunk_size = 1000 chunk_idx = 0 while chunk_idx * chunk_size < num_data_points: - data_chunk = data[chunk_idx * chunk_size: (chunk_idx + 1) * chunk_size] + data_chunk = data[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] predictions_chunk = self.model.inference_from_dicts(dicts=data_chunk) predictions.extend(predictions_chunk) chunk_idx += 1 - flat_predictions = [ - example for batch in predictions for example in batch["predictions"] - ] + flat_predictions = [example for batch in predictions for example in batch["predictions"]] positive_examples = [ - {**data[index], **{'paragraph_relevance_score': pred_example['probability']}} + {**data[index], **{"paragraph_relevance_score": pred_example["probability"]}} for index, pred_example in enumerate(flat_predictions) if pred_example["label"] == "1" ] @@ -100,10 +102,7 @@ def run_folder(self): df["source"] = self.data_type df_list.append(df) - predictions_file_path = os.path.join( - self.result_dir, predictions_file_name - - ) + predictions_file_path = os.path.join(self.result_dir, predictions_file_name) df.to_csv(predictions_file_path) _logger.info( "Saved {} relevant {} examples for {} in {}".format( @@ -140,8 +139,8 @@ def _gather_extracted_files(self): class TextRelevanceInfer(BaseRelevanceInfer): """This class is responsible for finding relevant texts to given questions. 
- Args: - infer_config (obj of model_pipeline.config.InferConfig) + Args: + infer_config (obj of model_pipeline.config.InferConfig) """ def __init__(self, infer_config): @@ -166,7 +165,7 @@ def _gather_extracted_files(self): } def _gather_data(self, pdf_name, pdf_path): - """ Gathers all the text data inside the given pdf and prepares it to be passed to text model + """Gathers all the text data inside the given pdf and prepares it to be passed to text model Args: pdf_name (str): Name of the pdf pdf_path (str): Path to the pdf @@ -208,7 +207,7 @@ def read_text_from_json(file): return text def run_text(self, input_text, input_question): - """ A method to make prediction on relevancy of a input_text and input_questions""" + """A method to make prediction on relevancy of a input_text and input_questions""" basic_texts = [ {"text": input_question, "text_b": input_text}, ] diff --git a/data_extractor/code/model_pipeline/model_pipeline/text_kpi_infer.py b/data_extractor/code/model_pipeline/model_pipeline/text_kpi_infer.py index 427f369..21109a2 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/text_kpi_infer.py +++ b/data_extractor/code/model_pipeline/model_pipeline/text_kpi_infer.py @@ -37,16 +37,19 @@ class TextKPIInfer: n_best_per_sample (int): num candidate answer spans to consider from each passage. Each passage also returns "no answer" info. This is the parameter for farm qa model. """ + def __init__(self, infer_config, n_best_per_sample=1): self.infer_config = infer_config - farm_logger = logging.getLogger('farm') + farm_logger = logging.getLogger("farm") farm_logger.setLevel(self.infer_config.farm_infer_logging_level) - self.model = QAInferencer.load(self.infer_config.load_dir["Text"], - batch_size=self.infer_config.batch_size, - gpu=self.infer_config.gpu, - num_processes=self.infer_config.num_processes) + self.model = QAInferencer.load( + self.infer_config.load_dir["Text"], + batch_size=self.infer_config.batch_size, + gpu=self.infer_config.gpu, + num_processes=self.infer_config.num_processes, + ) # num span-based candidate answer spans to consider from each passage self.model.model.prediction_heads[0].n_best_per_sample = n_best_per_sample # If positive, this will boost "No Answer" as prediction. @@ -84,12 +87,11 @@ def infer_on_file(self, squad_format_file, out_filename="predictions_of_file.jso write_squad_predictions( predictions=result_squad, predictions_filename=squad_format_file, - out_filename=os.path.join(self.result_dir, out_filename) + out_filename=os.path.join(self.result_dir, out_filename), ) self.model.close_multiprocessing_pool() return results - def infer_on_relevance_results(self, relevance_results_dir): """Make inference using the qa model on the relevant paragraphs. Args: @@ -110,9 +112,13 @@ def infer_on_relevance_results(self, relevance_results_dir): all_relevance_results_paths = glob.glob(os.path.join(relevance_results_dir, "*.csv")) all_span_dfs = [] num_csvs = len(all_relevance_results_paths) - _logger.info("{} Starting KPI Inference for the following relevance CSV files found in {}:\n{} ". 
- format("#" * 20, self.result_dir, [os.path.basename(relevance_results_path) for - relevance_results_path in all_relevance_results_paths])) + _logger.info( + "{} Starting KPI Inference for the following relevance CSV files found in {}:\n{} ".format( + "#" * 20, + self.result_dir, + [os.path.basename(relevance_results_path) for relevance_results_path in all_relevance_results_paths], + ) + ) for i, relevance_results_path in enumerate(all_relevance_results_paths): _logger.info("{} {}/{}".format("#" * 20, i + 1, num_csvs)) pdf_name = os.path.basename(relevance_results_path).split("_predictions_relevant")[0] @@ -126,23 +132,30 @@ def infer_on_relevance_results(self, relevance_results_dir): continue _logger.info("Starting KPI Extraction for {}".format(pdf_name)) input_df = pd.read_csv(relevance_results_path) - column_names = ['text_b', 'text', 'page', 'pdf_name', 'source', 'paragraph_relevance_score'] + column_names = ["text_b", "text", "page", "pdf_name", "source", "paragraph_relevance_score"] if len(input_df) == 0: _logger.info("The received relevance file is empty for {}".format(pdf_name)) df_empty = pd.DataFrame([]) df_empty.to_csv(os.path.join(self.result_dir, predictions_file_name)) continue - assert set(column_names).issubset(set(input_df.columns)), """The result of relevance detector has {} columns, - while expected {}""".format(input_df.columns, column_names) - - qa_dict = [{"qas": [question], "context": context} for question, context in zip(input_df["text"], input_df["text_b"])] + assert set(column_names).issubset( + set(input_df.columns) + ), """The result of relevance detector has {} columns, + while expected {}""".format( + input_df.columns, column_names + ) + + qa_dict = [ + {"qas": [question], "context": context} + for question, context in zip(input_df["text"], input_df["text_b"]) + ] num_data_points = len(qa_dict) result = [] chunk_size = 1000 chunk_idx = 0 while chunk_idx * chunk_size < num_data_points: - data_chunk = qa_dict[chunk_idx * chunk_size: (chunk_idx + 1) * chunk_size] + data_chunk = qa_dict[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size] predictions_chunk = self.model.inference_from_dicts(dicts=data_chunk) result.extend(predictions_chunk) chunk_idx += 1 @@ -153,11 +166,11 @@ def infer_on_relevance_results(self, relevance_results_dir): answers_dict = defaultdict(list) for exp in result: - preds = exp['predictions'][head_num]['answers'] + preds = exp["predictions"][head_num]["answers"] # Get the no_answer_score - no_answer_score = [p['score'] for p in preds if p['answer'] == "no_answer"] - if len(no_answer_score) == 0: # Happens if no answer is not among the n_best predictions. - no_answer_score = preds[0]['score'] - exp['predictions'][head_num]["no_ans_gap"] + no_answer_score = [p["score"] for p in preds if p["answer"] == "no_answer"] + if len(no_answer_score) == 0: # Happens if no answer is not among the n_best predictions. + no_answer_score = preds[0]["score"] - exp["predictions"][head_num]["no_ans_gap"] else: no_answer_score = no_answer_score[0] @@ -165,37 +178,48 @@ def infer_on_relevance_results(self, relevance_results_dir): # https://github.com/deepset-ai/FARM/blob/978da5d7600c48be458688996538770e9334e71b/farm/modeling/prediction_head.py#L1348 pure_no_ans_score = no_answer_score - self.infer_config.no_ans_boost - for i in range(num_answers): # This param is not exactly representative, n_best mostly defines num answers. 
- answers_dict[f"rank_{i+1}"].append((preds[i]['answer'], preds[i]['score'], - pure_no_ans_score, no_answer_score)) + for i in range( + num_answers + ): # This param is not exactly representative, n_best mostly defines num answers. + answers_dict[f"rank_{i+1}"].append( + (preds[i]["answer"], preds[i]["score"], pure_no_ans_score, no_answer_score) + ) for i in range(num_answers): input_df[f"rank_{i+1}"] = answers_dict[f"rank_{i+1}"] # Let's put different kpi predictions and their scores into one column so we can sort them. var_cols = [i for i in list(input_df.columns) if i.startswith("rank_")] id_vars = [i for i in list(input_df.columns) if not i.startswith("rank_")] - input_df = pd.melt(input_df, id_vars=id_vars, value_vars=var_cols, var_name='rank', value_name='answer_score') + input_df = pd.melt( + input_df, id_vars=id_vars, value_vars=var_cols, var_name="rank", value_name="answer_score" + ) # Separate a column with tuple value into two columns - input_df[['answer', 'score', 'no_ans_score', "no_answer_score_plus_boost"]] = pd.DataFrame( - input_df['answer_score'].tolist(), index=input_df.index) - input_df = input_df.drop(columns=['answer_score'], axis=1) - - no_answerables = input_df.groupby(['pdf_name', 'text']).apply(lambda grp: aggregate_result(grp)).dropna(how="all") - no_answerables = pd.DataFrame(no_answerables, columns=['score']).reset_index() + input_df[["answer", "score", "no_ans_score", "no_answer_score_plus_boost"]] = pd.DataFrame( + input_df["answer_score"].tolist(), index=input_df.index + ) + input_df = input_df.drop(columns=["answer_score"], axis=1) + + no_answerables = ( + input_df.groupby(["pdf_name", "text"]).apply(lambda grp: aggregate_result(grp)).dropna(how="all") + ) + no_answerables = pd.DataFrame(no_answerables, columns=["score"]).reset_index() no_answerables["answer"] = "no_answer" no_answerables["source"] = "Text" # Filter to span-based answers - span_df = input_df[input_df['answer'] != 'no_answer'] + span_df = input_df[input_df["answer"] != "no_answer"] # Concatenate the result of span answers with non answerable examples. span_df = pd.concat([span_df, no_answerables], ignore_index=True) # Get the predictions with n highest score for each pdf and question. # If the question is considered unanswerable, the best prediction is "no_answer", but the best span-based answer # is also returned. if the question is answerable, the best span-based answers are returned. - span_df = span_df.groupby(['pdf_name', 'text']).apply( - lambda grp: grp.nlargest(self.infer_config.top_k, 'score')).reset_index(drop=True) + span_df = ( + span_df.groupby(["pdf_name", "text"]) + .apply(lambda grp: grp.nlargest(self.infer_config.top_k, "score")) + .reset_index(drop=True) + ) # Final cleaning on the dataframe, removing unnecessary columns and renaming `text` and `text_b` columns. 
unnecessary_cols = ["rank"] + [i for i in list(span_df.columns) if i.startswith("Unnamed")] diff --git a/data_extractor/code/model_pipeline/model_pipeline/trainer_optuna.py b/data_extractor/code/model_pipeline/model_pipeline/trainer_optuna.py index b0b8d4e..f23b988 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/trainer_optuna.py +++ b/data_extractor/code/model_pipeline/model_pipeline/trainer_optuna.py @@ -8,6 +8,7 @@ logger = logging.getLogger(__name__) + class TrainerOptuna(Trainer): def train(self, trial): """ @@ -98,5 +99,3 @@ def train(self, trial): result = evaluator_test.eval(self.model) evaluator_test.log_results(result, "Test", self.global_step) return self.model - - diff --git a/data_extractor/code/model_pipeline/model_pipeline/utils/kpi_mapping.py b/data_extractor/code/model_pipeline/model_pipeline/utils/kpi_mapping.py index 928f0ef..7762243 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/utils/kpi_mapping.py +++ b/data_extractor/code/model_pipeline/model_pipeline/utils/kpi_mapping.py @@ -3,23 +3,20 @@ import os from pathlib import Path -#config = Config() +# config = Config() -#df = pd.read_csv(Path(config.root).parent.parent / "kpi_mapping.csv", header=0) +# df = pd.read_csv(Path(config.root).parent.parent / "kpi_mapping.csv", header=0) try: df = pd.read_csv("/app/code/kpi_mapping.csv", header=0) _KPI_MAPPING = { - str(i[0]): (i[1], [j.strip() for j in i[2].split(',')]) \ - for i in df[['kpi_id', 'question', 'sectors']].values + str(i[0]): (i[1], [j.strip() for j in i[2].split(",")]) for i in df[["kpi_id", "question", "sectors"]].values } KPI_MAPPING = {(float(key)): value for key, value in _KPI_MAPPING.items()} # Category where the answer to the question should originate from - KPI_CATEGORY = { - i[0]: [j.strip() for j in i[1].split(', ')] for i in df[['kpi_id', 'kpi_category']].values - } + KPI_CATEGORY = {i[0]: [j.strip() for j in i[1].split(", ")] for i in df[["kpi_id", "kpi_category"]].values} - KPI_SECTORS = list(set(df['sectors'].values)) + KPI_SECTORS = list(set(df["sectors"].values)) except: KPI_MAPPING = {} KPI_CATEGORY = {} diff --git a/data_extractor/code/model_pipeline/model_pipeline/utils/qa_metrics.py b/data_extractor/code/model_pipeline/model_pipeline/utils/qa_metrics.py index edd9f2c..98da6c8 100644 --- a/data_extractor/code/model_pipeline/model_pipeline/utils/qa_metrics.py +++ b/data_extractor/code/model_pipeline/model_pipeline/utils/qa_metrics.py @@ -38,9 +38,9 @@ def relaxed_squad_f1_single(pred, label, pred_idx=0): def compute_extra_metrics(eval_results): metric_dict = {} head_num = 0 - preds = eval_results[head_num]['preds'] - labels = eval_results[head_num]['labels'] - is_preds_answerable = [pred_doc[0][0].answer_type == 'span' for pred_doc in preds] + preds = eval_results[head_num]["preds"] + labels = eval_results[head_num]["labels"] + is_preds_answerable = [pred_doc[0][0].answer_type == "span" for pred_doc in preds] is_labels_answerable = [label_doc != [(-1, -1)] for label_doc in labels] # tn: label : unanswerable, predcited: unanswerable # fp: label : unanswerable, predcited: answerable @@ -48,7 +48,7 @@ def compute_extra_metrics(eval_results): # fp: label : answerable, predcited: answerable tn, fp, fn, tp = confusion_matrix(is_labels_answerable, is_preds_answerable, labels=[True, False]).ravel() - metric_dict.update({'TN': tn, 'FP': fp, 'FN': fn, 'TP': tp}) + metric_dict.update({"TN": tn, "FP": fp, "FN": fn, "TP": tp}) prediction_answerable_examples = [p for doc_idx, p in enumerate(preds) if 
is_labels_answerable[doc_idx]] label_answerable_examples = [l for doc_idx, l in enumerate(labels) if is_labels_answerable[doc_idx]] @@ -57,8 +57,8 @@ def compute_extra_metrics(eval_results): em_answerable = squad_EM(prediction_answerable_examples, label_answerable_examples) f1_answerable = squad_f1(prediction_answerable_examples, label_answerable_examples) - metric_dict.update({'relaxed_f1_answerable': relaxed_f1_answerable, - 'em_answerable': em_answerable, - 'f1_answerable': f1_answerable}) + metric_dict.update( + {"relaxed_f1_answerable": relaxed_f1_answerable, "em_answerable": em_answerable, "f1_answerable": f1_answerable} + ) return metric_dict diff --git a/data_extractor/code/model_pipeline/setup.py b/data_extractor/code/model_pipeline/setup.py index 4ffbeaf..8d3fdcc 100644 --- a/data_extractor/code/model_pipeline/setup.py +++ b/data_extractor/code/model_pipeline/setup.py @@ -4,34 +4,36 @@ from setuptools import find_packages, setup -NAME = 'model_pipeline' -DESCRIPTION = 'Train and infer for ESG project' -AUTHOR = '1QBit NLP' -REQUIRES_PYTHON = '>=3.6.0' +NAME = "model_pipeline" +DESCRIPTION = "Train and infer for ESG project" +AUTHOR = "1QBit NLP" +REQUIRES_PYTHON = ">=3.6.0" -def list_reqs(fname='requirements.txt'): + +def list_reqs(fname="requirements.txt"): with open(fname) as fd: return fd.read().splitlines() + here = os.path.abspath(os.path.dirname(__file__)) # Load the package's __version__.py module as a dictionary. ROOT_DIR = Path(__file__).resolve().parent PACKAGE_DIR = ROOT_DIR / NAME about = {} -with open(PACKAGE_DIR / 'VERSION') as f: +with open(PACKAGE_DIR / "VERSION") as f: _version = f.read().strip() - about['__version__'] = _version + about["__version__"] = _version setup( name=NAME, - version=about['__version__'], + version=about["__version__"], description=DESCRIPTION, author=AUTHOR, python_requires=REQUIRES_PYTHON, - packages=find_packages(exclude=('tests', 'notebooks')), - package_data={'model_pipeline': ['VERSION']}, + packages=find_packages(exclude=("tests", "notebooks")), + package_data={"model_pipeline": ["VERSION"]}, install_requires=list_reqs(), extras_require={}, - include_package_data=True + include_package_data=True, ) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerCluster.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerCluster.py index bf7a274..76a9da8 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerCluster.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerCluster.py @@ -15,261 +15,240 @@ from Format_Analyzer import * - class AnalyzerCluster: - - - htmlcluster = None - htmlpage = None - items = None - default_year = None - bad_page = None - - - - - def find_kpis_single_node(self, kpispecs, cluster): - - def get_txt_by_idx_list(idx_list): - res = '' - for idx in idx_list: - if(res != ''): - res += ', ' - res += self.items[idx].txt - return res - - - def get_rect_by_idx_list(idx_list): - rect = Rect(9999999, 9999999, -1, -1) - for idx in idx_list: - rect.grow(self.items[idx].get_rect()) - return rect - - - - def refine_txt_items(idx_list, base_score): - needed = [] - for i in range(len(idx_list)): - sub_idx_list = idx_list[0:i] + idx_list[i+1:len(idx_list)] - sub_txt = get_txt_by_idx_list(sub_idx_list) - txt_match, score = kpispecs.match_nodes([sub_txt]) - if(score >= base_score and txt_match): - needed.append(False) - else: - needed.append(True) - res = [] - for i in range(len(idx_list)): - if(needed[i]): - res.append(idx_list[i]) - 
return res - - - def find_nearest_matching_str(idx_list, ref_point_x, ref_point_y, matching_fun, exclude_years): - best = -1 - best_word = -1 - best_dist = 9999999 - best_rect = None - - for i in range(len(idx_list)): - it = self.items[idx_list[i]] - txt = it.txt if not exclude_years else Format_Analyzer.exclude_all_years(it.txt) - print_verbose(7, '-------->Looking for '+str(matching_fun)+' in: "'+txt+'"') - if(matching_fun(txt)): - #whole string - print_verbose(7, '----------> FOUND!') - cur_x, cur_y = it.get_rect().get_center() - cur_dist = dist(ref_point_x, ref_point_y, cur_x, cur_y) - if(cur_dist < best_dist): - best_dist = cur_dist - best = i - best_word = -1 - best_rect = it.get_rect() - else: - #each word - for j in range(len(it.words)): - wtxt =it.words[j].txt if not exclude_years else Format_Analyzer.exclude_all_years(it.words[j].txt) - print_verbose(7, '-------->Looking for '+str(matching_fun)+' in: "'+wtxt+'"') - if(matching_fun(wtxt)): - print_verbose(7, '----------> FOUND!') - cur_x, cur_y = it.words[j].rect.get_center() - cur_dist = dist(ref_point_x, ref_point_y, cur_x, cur_y) - if(cur_dist < best_dist): - best_dist = cur_dist - best = i - best_word = j - best_rect = it.words[j].rect - - if(best == -1): - return None, None, best_rect - - if(best_word == -1): - return idx_list[best], self.items[idx_list[best]].txt, best_rect - - return idx_list[best], self.items[idx_list[best]].words[best_word].txt, best_rect - - - - print_verbose(5, 'ANALYZING CLUSTER NODE ===>>> ' + cluster.flat_text) - - txt = cluster.flat_text - - idx_list = cluster.get_idx_list() - - - # get text - - txt_match, score = kpispecs.match_nodes([txt]) - print_verbose(5, '---> txt base_score='+str(score)) - if(not txt_match): - print_verbose(5, '---> No match') - return None - - idx_list_refined_txt = refine_txt_items(idx_list, score) - txt_refined = get_txt_by_idx_list(idx_list_refined_txt) - - txt_match, score = kpispecs.match_nodes([txt_refined]) - print_verbose(5, '------> After refinement: ' + txt_refined) - print_verbose(5, '------> txt score='+str(score)) - - - base_point_x, base_point_y = get_rect_by_idx_list(idx_list_refined_txt).get_center() - - # get value - - raw_value_idx, raw_value, value_rect = find_nearest_matching_str(idx_list, base_point_x, base_point_y, kpispecs.match_value, not kpispecs.value_must_be_year) # TODO: Maybe not always exclude years ? - if(raw_value is None): - print_verbose(5, '---> Value missmatch') - return None # value missmatch - - print_verbose(5, '------> raw_value: '+str(raw_value)) - - # get unit - txt_unit_matched = kpispecs.match_unit(txt) - if(not txt_unit_matched): - print_verbose(5, '---> Unit not matched') - return None #unit not matched - - unit_str = '' - unit_idx = None - if(kpispecs.has_unit()): - unit_idx, unit_str, foo = find_nearest_matching_str(idx_list, base_point_x, base_point_y, kpispecs.match_unit, True) # TODO: Maybe not always exclude years ? 
- if(unit_idx is None): - print_verbose(5, '---> Unit not matched in individual item') - return None - - print_verbose(5, '------> unit_str: '+str(unit_str)) - - - # get year - year = -1 - year_idx, year_str, foo = find_nearest_matching_str(idx_list, base_point_x, base_point_y, Format_Analyzer.looks_year, False) - if(year_str is not None): - year = Format_Analyzer.to_year(year_str) - - - print_verbose(5, '------> year_str: '+str(year_str)) - - - # get new idx list - base_new_idx_list = idx_list_refined_txt - base_new_idx_list.append(raw_value_idx) - if(unit_idx is not None): - base_new_idx_list.append(unit_idx) - if(year_idx is not None): - base_new_idx_list.append(year_idx) - - base_rect = Rect(9999999, 9999999, -1, -1) - for idx in base_new_idx_list: - print_verbose(7,'................----> base_item='+str(self.items[idx])) - base_rect.grow(self.items[idx].get_rect()) - - print_verbose(5, '----> base_rect='+str(base_rect)) - - - new_idx_list_in_rect = self.htmlpage.find_items_within_rect_all_categories(base_rect) - new_idx_list = list(set.intersection(set(new_idx_list_in_rect),set(idx_list))) - - - final_txt = get_txt_by_idx_list(new_idx_list) - print_verbose(5, '------> Final text: "'+str(final_txt)+'"') - - txt_match, final_txt_score = kpispecs.match_nodes([final_txt]) - print_verbose(5, '---> txt final_score='+str(final_txt_score)) - if(not txt_match): - print_verbose(5, '---> No match') - return None - - - rect = Rect(9999999, 9999999, -1, -1) - anywhere_match_score = 9999999 - for idx in new_idx_list: - rect.grow(self.items[idx].get_rect()) - anywhere_match, anywhere_match_score_cur = kpispecs.match_anywhere_on_page(self.htmlpage, idx) - anywhere_match_score = min(anywhere_match_score, anywhere_match_score_cur) - if(not anywhere_match): - print_verbose(5, '---> anywhere-match was not matched on this page. 
No other match possible.') - self.bad_page = True - return None - - #anywhere_match_score /= len(new_idx_list) , only do this when avg is used - - - kpi_measure = KPIMeasure() - kpi_measure.kpi_id = kpispecs.kpi_id - kpi_measure.kpi_name = kpispecs.kpi_name - kpi_measure.src_file = 'TODO' - kpi_measure.page_num = self.htmlpage.items[raw_value_idx].page_num - kpi_measure.item_ids = idx_list - kpi_measure.pos_x = value_rect.x0 # (rect.x0+rect.x1)*0.5 - kpi_measure.pos_y = value_rect.y0 # (rect.y0+rect.y1)*0.5 - kpi_measure.raw_txt = raw_value - kpi_measure.year = year - kpi_measure.value = kpispecs.extract_value(raw_value) - kpi_measure.score = final_txt_score + anywhere_match_score - kpi_measure.unit = unit_str - kpi_measure.match_type= 'AC.default' - print_verbose(5, '---> Match: ' + str(kpi_measure) + ': final_txt_score='+str(final_txt_score)+',anywhere_match_score='+str(anywhere_match_score)) - - return kpi_measure - - - - - - - def find_kpis_rec(self, kpispecs, cluster): - - - cur_kpi = self.find_kpis_single_node(kpispecs, cluster) - - - res=[] - if(cur_kpi is not None): - res = [cur_kpi] - - for c in cluster.children: - res.extend(self.find_kpis_rec(kpispecs, c)) - - return res - - - - - - def find_kpis(self, kpispecs): - - if(self.htmlcluster is None): - return [] - - res = self.find_kpis_rec(kpispecs, self.htmlcluster) - - return res - - - def __init__(self, htmlcluster, htmlpage, default_year): - self.htmlcluster = htmlcluster - self.htmlpage = htmlpage - self.items = htmlpage.items - self.default_year = default_year - self.bad_page = False - + htmlcluster = None + htmlpage = None + items = None + default_year = None + bad_page = None + + def find_kpis_single_node(self, kpispecs, cluster): + def get_txt_by_idx_list(idx_list): + res = "" + for idx in idx_list: + if res != "": + res += ", " + res += self.items[idx].txt + return res + + def get_rect_by_idx_list(idx_list): + rect = Rect(9999999, 9999999, -1, -1) + for idx in idx_list: + rect.grow(self.items[idx].get_rect()) + return rect + + def refine_txt_items(idx_list, base_score): + needed = [] + for i in range(len(idx_list)): + sub_idx_list = idx_list[0:i] + idx_list[i + 1 : len(idx_list)] + sub_txt = get_txt_by_idx_list(sub_idx_list) + txt_match, score = kpispecs.match_nodes([sub_txt]) + if score >= base_score and txt_match: + needed.append(False) + else: + needed.append(True) + res = [] + for i in range(len(idx_list)): + if needed[i]: + res.append(idx_list[i]) + return res + + def find_nearest_matching_str(idx_list, ref_point_x, ref_point_y, matching_fun, exclude_years): + best = -1 + best_word = -1 + best_dist = 9999999 + best_rect = None + + for i in range(len(idx_list)): + it = self.items[idx_list[i]] + txt = it.txt if not exclude_years else Format_Analyzer.exclude_all_years(it.txt) + print_verbose(7, "-------->Looking for " + str(matching_fun) + ' in: "' + txt + '"') + if matching_fun(txt): + # whole string + print_verbose(7, "----------> FOUND!") + cur_x, cur_y = it.get_rect().get_center() + cur_dist = dist(ref_point_x, ref_point_y, cur_x, cur_y) + if cur_dist < best_dist: + best_dist = cur_dist + best = i + best_word = -1 + best_rect = it.get_rect() + else: + # each word + for j in range(len(it.words)): + wtxt = ( + it.words[j].txt if not exclude_years else Format_Analyzer.exclude_all_years(it.words[j].txt) + ) + print_verbose(7, "-------->Looking for " + str(matching_fun) + ' in: "' + wtxt + '"') + if matching_fun(wtxt): + print_verbose(7, "----------> FOUND!") + cur_x, cur_y = it.words[j].rect.get_center() + cur_dist = 
dist(ref_point_x, ref_point_y, cur_x, cur_y) + if cur_dist < best_dist: + best_dist = cur_dist + best = i + best_word = j + best_rect = it.words[j].rect + + if best == -1: + return None, None, best_rect + + if best_word == -1: + return idx_list[best], self.items[idx_list[best]].txt, best_rect + + return idx_list[best], self.items[idx_list[best]].words[best_word].txt, best_rect + + print_verbose(5, "ANALYZING CLUSTER NODE ===>>> " + cluster.flat_text) + + txt = cluster.flat_text + + idx_list = cluster.get_idx_list() + + # get text + + txt_match, score = kpispecs.match_nodes([txt]) + print_verbose(5, "---> txt base_score=" + str(score)) + if not txt_match: + print_verbose(5, "---> No match") + return None + + idx_list_refined_txt = refine_txt_items(idx_list, score) + txt_refined = get_txt_by_idx_list(idx_list_refined_txt) + + txt_match, score = kpispecs.match_nodes([txt_refined]) + print_verbose(5, "------> After refinement: " + txt_refined) + print_verbose(5, "------> txt score=" + str(score)) + + base_point_x, base_point_y = get_rect_by_idx_list(idx_list_refined_txt).get_center() + + # get value + + raw_value_idx, raw_value, value_rect = find_nearest_matching_str( + idx_list, base_point_x, base_point_y, kpispecs.match_value, not kpispecs.value_must_be_year + ) # TODO: Maybe not always exclude years ? + if raw_value is None: + print_verbose(5, "---> Value missmatch") + return None # value missmatch + + print_verbose(5, "------> raw_value: " + str(raw_value)) + + # get unit + txt_unit_matched = kpispecs.match_unit(txt) + if not txt_unit_matched: + print_verbose(5, "---> Unit not matched") + return None # unit not matched + + unit_str = "" + unit_idx = None + if kpispecs.has_unit(): + unit_idx, unit_str, foo = find_nearest_matching_str( + idx_list, base_point_x, base_point_y, kpispecs.match_unit, True + ) # TODO: Maybe not always exclude years ? 
+ if unit_idx is None: + print_verbose(5, "---> Unit not matched in individual item") + return None + + print_verbose(5, "------> unit_str: " + str(unit_str)) + + # get year + year = -1 + year_idx, year_str, foo = find_nearest_matching_str( + idx_list, base_point_x, base_point_y, Format_Analyzer.looks_year, False + ) + if year_str is not None: + year = Format_Analyzer.to_year(year_str) + + print_verbose(5, "------> year_str: " + str(year_str)) + + # get new idx list + base_new_idx_list = idx_list_refined_txt + base_new_idx_list.append(raw_value_idx) + if unit_idx is not None: + base_new_idx_list.append(unit_idx) + if year_idx is not None: + base_new_idx_list.append(year_idx) + + base_rect = Rect(9999999, 9999999, -1, -1) + for idx in base_new_idx_list: + print_verbose(7, "................----> base_item=" + str(self.items[idx])) + base_rect.grow(self.items[idx].get_rect()) + + print_verbose(5, "----> base_rect=" + str(base_rect)) + + new_idx_list_in_rect = self.htmlpage.find_items_within_rect_all_categories(base_rect) + new_idx_list = list(set.intersection(set(new_idx_list_in_rect), set(idx_list))) + + final_txt = get_txt_by_idx_list(new_idx_list) + print_verbose(5, '------> Final text: "' + str(final_txt) + '"') + + txt_match, final_txt_score = kpispecs.match_nodes([final_txt]) + print_verbose(5, "---> txt final_score=" + str(final_txt_score)) + if not txt_match: + print_verbose(5, "---> No match") + return None + + rect = Rect(9999999, 9999999, -1, -1) + anywhere_match_score = 9999999 + for idx in new_idx_list: + rect.grow(self.items[idx].get_rect()) + anywhere_match, anywhere_match_score_cur = kpispecs.match_anywhere_on_page(self.htmlpage, idx) + anywhere_match_score = min(anywhere_match_score, anywhere_match_score_cur) + if not anywhere_match: + print_verbose(5, "---> anywhere-match was not matched on this page. 
No other match possible.") + self.bad_page = True + return None + + # anywhere_match_score /= len(new_idx_list) , only do this when avg is used + + kpi_measure = KPIMeasure() + kpi_measure.kpi_id = kpispecs.kpi_id + kpi_measure.kpi_name = kpispecs.kpi_name + kpi_measure.src_file = "TODO" + kpi_measure.page_num = self.htmlpage.items[raw_value_idx].page_num + kpi_measure.item_ids = idx_list + kpi_measure.pos_x = value_rect.x0 # (rect.x0+rect.x1)*0.5 + kpi_measure.pos_y = value_rect.y0 # (rect.y0+rect.y1)*0.5 + kpi_measure.raw_txt = raw_value + kpi_measure.year = year + kpi_measure.value = kpispecs.extract_value(raw_value) + kpi_measure.score = final_txt_score + anywhere_match_score + kpi_measure.unit = unit_str + kpi_measure.match_type = "AC.default" + print_verbose( + 5, + "---> Match: " + + str(kpi_measure) + + ": final_txt_score=" + + str(final_txt_score) + + ",anywhere_match_score=" + + str(anywhere_match_score), + ) + + return kpi_measure + + def find_kpis_rec(self, kpispecs, cluster): + cur_kpi = self.find_kpis_single_node(kpispecs, cluster) + + res = [] + if cur_kpi is not None: + res = [cur_kpi] + + for c in cluster.children: + res.extend(self.find_kpis_rec(kpispecs, c)) + + return res + + def find_kpis(self, kpispecs): + if self.htmlcluster is None: + return [] + + res = self.find_kpis_rec(kpispecs, self.htmlcluster) + + return res + + def __init__(self, htmlcluster, htmlpage, default_year): + self.htmlcluster = htmlcluster + self.htmlpage = htmlpage + self.items = htmlpage.items + self.default_year = default_year + self.bad_page = False diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerDirectory.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerDirectory.py index ac7529e..97bb2ba 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerDirectory.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerDirectory.py @@ -13,24 +13,22 @@ class AnalyzerDirectory: - htmldirectory = None - analyzer_page = None - default_year = None - - - def __init__(self, htmldirectory, default_year): - self.htmldirectory = htmldirectory - self.analyzer_page = [] - for i in range(len(self.htmldirectory.htmlpages)): - p = htmldirectory.htmlpages[i] - self.analyzer_page.append(AnalyzerPage(p, default_year)) - if(config.global_analyze_multiple_pages_at_one and i < len(self.htmldirectory.htmlpages) - 1): - p_mult = HTMLPage.merge(p, htmldirectory.htmlpages[i+1]) - self.analyzer_page.append(AnalyzerPage(p_mult, default_year)) - self.default_year = default_year - + htmldirectory = None + analyzer_page = None + default_year = None - """ + def __init__(self, htmldirectory, default_year): + self.htmldirectory = htmldirectory + self.analyzer_page = [] + for i in range(len(self.htmldirectory.htmlpages)): + p = htmldirectory.htmlpages[i] + self.analyzer_page.append(AnalyzerPage(p, default_year)) + if config.global_analyze_multiple_pages_at_one and i < len(self.htmldirectory.htmlpages) - 1: + p_mult = HTMLPage.merge(p, htmldirectory.htmlpages[i + 1]) + self.analyzer_page.append(AnalyzerPage(p_mult, default_year)) + self.default_year = default_year + + """ TODO : Probably implementation not neccessary def adjust_scores_by_value_preference(self, lst, kpispecs): if(kpispecs.value_preference == 1.0 or len(lst) < 2): @@ -40,58 +38,46 @@ def adjust_scores_by_value_preference(self, lst, kpispecs): for k in lst: ... 
""" - - - - def fix_src_name(self, kpi_measures): - print_verbose(3, "self.htmldirectory.src_pdf_filename="+self.htmldirectory.src_pdf_filename) - res = [] - for k in kpi_measures: - k.set_file_path(self.htmldirectory.src_pdf_filename) - res.append(k) - return res - - - - def find_kpis(self, kpispecs): - # find all possible occurenes of kpi on all pages - - - res = [] - - for a in self.analyzer_page: - res.extend(a.find_kpis(kpispecs)) - - if(config.global_ignore_all_years): - res = KPIMeasure.remove_all_years(res) - #print("\n\n\n1:"+str(res)) - - res = KPIMeasure.remove_duplicates(res) + def fix_src_name(self, kpi_measures): + print_verbose(3, "self.htmldirectory.src_pdf_filename=" + self.htmldirectory.src_pdf_filename) + res = [] + for k in kpi_measures: + k.set_file_path(self.htmldirectory.src_pdf_filename) + res.append(k) + return res - #print("\n\n\n2:"+str(res)) + def find_kpis(self, kpispecs): + # find all possible occurenes of kpi on all pages - res = KPIMeasure.remove_bad_scores(res, kpispecs.minimum_score) - + res = [] + for a in self.analyzer_page: + res.extend(a.find_kpis(kpispecs)) - return res - - - def find_multiple_kpis(self, kpispecs_lst): - res = [] - - for k in kpispecs_lst: - res.extend(self.find_kpis(k)) - + if config.global_ignore_all_years: + res = KPIMeasure.remove_all_years(res) + + # print("\n\n\n1:"+str(res)) + + res = KPIMeasure.remove_duplicates(res) + + # print("\n\n\n2:"+str(res)) + + res = KPIMeasure.remove_bad_scores(res, kpispecs.minimum_score) + + return res + + def find_multiple_kpis(self, kpispecs_lst): + res = [] + + for k in kpispecs_lst: + res.extend(self.find_kpis(k)) + + res = KPIMeasure.remove_bad_years(res, self.default_year) + + res = KPIMeasure.remove_duplicates(res) + + res = self.fix_src_name(res) - - res = KPIMeasure.remove_bad_years(res, self.default_year) - - res = KPIMeasure.remove_duplicates(res) - - res = self.fix_src_name(res) - - - - return res \ No newline at end of file + return res diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerPage.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerPage.py index daedaa6..c39165f 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerPage.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerPage.py @@ -13,57 +13,55 @@ from AnalyzerTable import * from AnalyzerCluster import * + class AnalyzerPage: - htmlpage = None - analyzer_table = None - analyzer_cluster = None - default_year = None + htmlpage = None + analyzer_table = None + analyzer_cluster = None + default_year = None + + def __init__(self, htmlpage, default_year): + self.htmlpage = htmlpage + self.analyzer_table = [] + for t in self.htmlpage.tables: + self.analyzer_table.append(AnalyzerTable(t, self.htmlpage, default_year)) + sub_tabs = t.generate_sub_tables() + for s in sub_tabs: + self.analyzer_table.append(AnalyzerTable(s, self.htmlpage, default_year)) + + self.analyzer_cluster = [] + # self.analyzer_cluster.append(AnalyzerCluster(htmlpage.clusters, htmlpage, default_year)) + self.analyzer_cluster.append(AnalyzerCluster(htmlpage.clusters_text, htmlpage, default_year)) + + self.default_year = default_year + + def find_kpis(self, kpispecs): + # find all possible occurenes of kpi on that page + + print_verbose( + 1, " ==>>>> FIND KPIS '" + kpispecs.kpi_name + "' ON PAGE: " + str(self.htmlpage.page_num) + " <<<<<=====" + ) + print_verbose(9, self.htmlpage) + + res = [] + # 1. 
Tables + for a in self.analyzer_table: + res.extend(a.find_kpis(kpispecs)) + + # 2. Figures and Text (used for CDP reports) + # for a in self.analyzer_cluster: + # res.extend(a.find_kpis(kpispecs)) + + # 3. Regular text + # TODO + # 4. Remove dups + res = KPIMeasure.remove_duplicates(res) - def __init__(self, htmlpage, default_year): - self.htmlpage = htmlpage - self.analyzer_table = [] - for t in self.htmlpage.tables: - self.analyzer_table.append(AnalyzerTable(t, self.htmlpage, default_year)) - sub_tabs = t.generate_sub_tables() - for s in sub_tabs: - self.analyzer_table.append(AnalyzerTable(s, self.htmlpage, default_year)) - - self.analyzer_cluster = [] - #self.analyzer_cluster.append(AnalyzerCluster(htmlpage.clusters, htmlpage, default_year)) - self.analyzer_cluster.append(AnalyzerCluster(htmlpage.clusters_text, htmlpage, default_year)) - - - self.default_year = default_year - - def find_kpis(self, kpispecs): - # find all possible occurenes of kpi on that page - - print_verbose(1, " ==>>>> FIND KPIS '" + kpispecs.kpi_name + "' ON PAGE: "+str(self.htmlpage.page_num) + " <<<<<=====") - print_verbose(9, self.htmlpage) - - res = [] - # 1. Tables - for a in self.analyzer_table: - res.extend(a.find_kpis(kpispecs)) - - # 2. Figures and Text (used for CDP reports) - #for a in self.analyzer_cluster: - # res.extend(a.find_kpis(kpispecs)) - - # 3. Regular text - # TODO - - # 4. Remove dups - res = KPIMeasure.remove_duplicates(res) - - #5. Adjust coords - for k in res: - px, py = self.htmlpage.transform_coords(k.pos_x, k.pos_y) - k.pos_x = px - k.pos_y = py + # 5. Adjust coords + for k in res: + px, py = self.htmlpage.transform_coords(k.pos_x, k.pos_y) + k.pos_x = px + k.pos_y = py - - - return res - \ No newline at end of file + return res diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerTable.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerTable.py index f0f2194..231501b 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerTable.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/AnalyzerTable.py @@ -21,287 +21,265 @@ class AnalyzerTable: - - class YearRow: # a row that is assumed to be a headline with years - row_num = None - years = None # a mapping: year -> row, col - - def __init__(self, row_num, years): - self.row_num = row_num - self.years = years - - def __repr__(self): - return '' - - - - htmltable = None - htmlpage = None - items = None - default_year = None - - table_hierarchy = None # for each ix, a refernece to the parent ix (or -1, if root) - year_rows = None # all rows containing years, each will be a YearRow - - - def get_num_cols(self): - return self.htmltable.num_cols - - def get_num_rows(self): - return self.htmltable.num_rows - - def get_ix(self, i, j): - return self.htmltable.get_ix(i, j) - - def get_row_and_col_by_ix(self, ix): - return self.htmltable.get_row_and_col_by_ix(ix) - - def has_item_at_ix(self, i): # i=ix - return self.htmltable.has_item_at_ix(i) - - def has_item_at(self, i, j): # i=row, j=col - return self.htmltable.has_item_at(i, j) - - def get_item(self, i, j): - return self.htmltable.get_item(i, j) - - def get_item_by_ix(self, ix): - return self.htmltable.get_item_by_ix(ix) - - def find_next_non_empty_cell_return_row_only(self, i, j, dir): - while(i>0 and i> find_next_parent_cell: "+ str(r0) + ',' + str(c0) ) - d0 = self.get_depth(r0, c0, dir) - if(d0 == 999999999): - return -1 # empty cell - - if(dir==HIERARCHY_DIR_UP): - for i in range(r0 - 1, -1, -1): - d = 
self.get_depth(i, c0, dir) - if(d < d0): - #print("---------------->> " + str(i) + "," + str(c0)) - return self.get_ix(i, c0) - elif(dir==HIERARCHY_DIR_LEFT): - for j in range(c0 - 1, -1, -1): - d = self.get_depth(r0, j, dir) - if(d < d0): - return self.get_ix(r0, j) - - #print("---------------->> root") - return -1 # not found / root - - - - def calculate_hierarchy(self, dir): - for i in range(self.get_num_rows()): - for j in range(self.get_num_cols()): - self.table_hierarchy[dir][self.get_ix(i,j)] = self.find_next_parent_cell(i, j, dir) - - - def get_aligned_multirow_txt_with_rect(self, r0, c0): - def go(dir, init_depth): - res = [] - rect = Rect(9999999, 9999999, -1, -1) - r = r0 + dir - while(r > 0 and r < self.get_num_rows()): - if(not self.has_item_at(r, c0)): - break # nothing here (TODO: maybe just skip these?) - if(self.get_depth(r, c0, HIERARCHY_DIR_UP) != init_depth): - break # other depth => unrelated - # do we have other numeric values in that row? - has_num_values = False - for j in range(self.get_num_cols()): - if(j==c0): - continue - if(self.has_item_at(r, j) and Format_Analyzer.looks_weak_numeric(self.get_item(r, j).txt)): - has_num_values = True - break - if(has_num_values): - break - res.append(self.get_item(r, c0).txt) - rect.grow(self.get_item(r, c0).get_rect()) - r += dir - return res, rect - - - - - #sometimes a cell contains of multiple rows. we want to match the whole cell, - #but to avoid overmatching, we make sure that in such cases in no other connected row - #any values can be stored - - res = [] - rect = Rect(9999999, 9999999, -1, -1) - - if(not self.has_item_at(r0, c0)): - return res, rect - - res.append(self.get_item(r0, c0).txt ) - rect.grow(self.get_item(r0, c0).get_rect()) - - init_depth = self.get_depth(r0, c0, HIERARCHY_DIR_UP) - - res_up, rect_up = go(DIR_UPWARDS, init_depth) - res_down, rect_down = go(DIR_DOWNWARDS, init_depth) - - for i in range(max(len(res_up),len(res_down))): - if(i 0): - r -= 1 - if(not self.has_item_at(r, c0)): - continue # empty - txt = self.get_item(r, c0).txt - if(Format_Analyzer.looks_weak_non_numeric(txt)): - res.append(txt) - else: - if(break_at_number): - break #a number or soemthing that is no text - - - if(include_special_items and self.has_item_at(r0, c0)): - rect = self.get_item(r0, c0).get_rect() - rect.y0 = 0 - sp_idx = self.htmltable.find_special_item_idx_in_rect(rect) - for i in sp_idx: - res.append(self.items[i].txt) - return res - - - def get_txt_headline(self): - res = [] - for h_idx in self.htmltable.headline_idx: - res.append(self.items[h_idx].txt) - return res - - def get_first_non_empty_row(self, r0, c0): #typically, set r0 = 0 to search from top. If no such row found, then "num_rows" is returned - r = r0 - while(r < self.get_num_rows()): - if(self.has_item_at(r, c0)): - break - r +=1 - return r - - - def get_multi_row_headline(self, r0, c0, include_special_items): - # get first non-empty row - r = self.get_first_non_empty_row(r0, c0) - if(r == self.get_num_rows() or not Format_Analyzer.looks_weak_words(self.get_item(r, c0))): - return '' - - - res = '' - # special items? 
- if(include_special_items): - rect = self.get_item(r, c0).get_rect() - rect.y0 = 0 - sp_idx = self.htmltable.find_special_item_idx_in_rect(rect) - res += HTMLItem.concat_txt(htmlpage.unfold_idx_to_items(sp_idx)) - - # start at r,c0 - font = self.get_item(r, c0).get_font_characteristics() - max_row_spacing = self.get_item(r, c0).font_size * 1.5 - res += self.get_item(r, c0).txt - last_y1 = self.get_item(r, c0).get_rect().y1 - - while(r < self.get_num_rows() - 1): - r += 1 - if(not self.has_item_at(r, c0)): - continue - cur_font = self.get_item(r, c0).get_font_characteristics() - if(cur_font != font): - break # different font => break here - cur_rect = self.get_item(r, c0).get_rect() - if(cur_rect.y0 > last_y1 + max_row_spacing): - break # rows are too far apart => break here - txt = self.get_item(r, c0).txt - if(not Format_Analyzer.looks_weak_words(txt)): - break # not a text anymore => break here - res += ' ' + txt - - - - return res - - - def row_looks_like_year_line(self, r0): #at least one column is a year? if yes, an dict with year -> row_num, col_num is returned. If not, None - num_years = 0 - res = {} - for j in range(self.get_num_cols()): - if(self.has_item_at(r0, j)): - txt = self.get_item(r0, j).txt - if(Format_Analyzer.looks_year(txt)): - num_years +=1 - y = Format_Analyzer.to_year(self.get_item(r0, j).txt) - if(y not in res): - res[y] = (r0, j) - elif(Format_Analyzer.looks_numeric(txt)): - # some other number occured => this is probably not a line of years - return None - - return res if num_years >= 1 else None - - - """ + class YearRow: # a row that is assumed to be a headline with years + row_num = None + years = None # a mapping: year -> row, col + + def __init__(self, row_num, years): + self.row_num = row_num + self.years = years + + def __repr__(self): + return "" + + htmltable = None + htmlpage = None + items = None + default_year = None + + table_hierarchy = None # for each ix, a refernece to the parent ix (or -1, if root) + year_rows = None # all rows containing years, each will be a YearRow + + def get_num_cols(self): + return self.htmltable.num_cols + + def get_num_rows(self): + return self.htmltable.num_rows + + def get_ix(self, i, j): + return self.htmltable.get_ix(i, j) + + def get_row_and_col_by_ix(self, ix): + return self.htmltable.get_row_and_col_by_ix(ix) + + def has_item_at_ix(self, i): # i=ix + return self.htmltable.has_item_at_ix(i) + + def has_item_at(self, i, j): # i=row, j=col + return self.htmltable.has_item_at(i, j) + + def get_item(self, i, j): + return self.htmltable.get_item(i, j) + + def get_item_by_ix(self, ix): + return self.htmltable.get_item_by_ix(ix) + + def find_next_non_empty_cell_return_row_only(self, i, j, dir): + while i > 0 and i < self.get_num_rows() - 1: + i += dir + if self.has_item_at(i, j): + return i + return -1 # not found + + def get_depth(self, i, j, dir): + ident_threshold = ( + (3.0 / 609.0) * self.htmlpage.page_width + if dir == HIERARCHY_DIR_UP + else (3.0 / 609.0) * self.htmlpage.page_height + ) + + if not self.has_item_at(i, j): + return 999999999 + it = self.get_item(i, j) + if dir == HIERARCHY_DIR_UP: + return it.get_depth() + ((int)(it.pos_x / ident_threshold) * 10000 if it.alignment == ALIGN_LEFT else 0) + return it.get_depth() + ((int)(it.pos_y / ident_threshold) * 10000 if it.alignment == ALIGN_LEFT else 0) + + def find_next_parent_cell(self, r0, c0, dir): # row=r0, col=c0 + # print("------------>> find_next_parent_cell: "+ str(r0) + ',' + str(c0) ) + d0 = self.get_depth(r0, c0, dir) + if d0 == 999999999: + return -1 
# empty cell + + if dir == HIERARCHY_DIR_UP: + for i in range(r0 - 1, -1, -1): + d = self.get_depth(i, c0, dir) + if d < d0: + # print("---------------->> " + str(i) + "," + str(c0)) + return self.get_ix(i, c0) + elif dir == HIERARCHY_DIR_LEFT: + for j in range(c0 - 1, -1, -1): + d = self.get_depth(r0, j, dir) + if d < d0: + return self.get_ix(r0, j) + + # print("---------------->> root") + return -1 # not found / root + + def calculate_hierarchy(self, dir): + for i in range(self.get_num_rows()): + for j in range(self.get_num_cols()): + self.table_hierarchy[dir][self.get_ix(i, j)] = self.find_next_parent_cell(i, j, dir) + + def get_aligned_multirow_txt_with_rect(self, r0, c0): + def go(dir, init_depth): + res = [] + rect = Rect(9999999, 9999999, -1, -1) + r = r0 + dir + while r > 0 and r < self.get_num_rows(): + if not self.has_item_at(r, c0): + break # nothing here (TODO: maybe just skip these?) + if self.get_depth(r, c0, HIERARCHY_DIR_UP) != init_depth: + break # other depth => unrelated + # do we have other numeric values in that row? + has_num_values = False + for j in range(self.get_num_cols()): + if j == c0: + continue + if self.has_item_at(r, j) and Format_Analyzer.looks_weak_numeric(self.get_item(r, j).txt): + has_num_values = True + break + if has_num_values: + break + res.append(self.get_item(r, c0).txt) + rect.grow(self.get_item(r, c0).get_rect()) + r += dir + return res, rect + + # sometimes a cell contains of multiple rows. we want to match the whole cell, + # but to avoid overmatching, we make sure that in such cases in no other connected row + # any values can be stored + + res = [] + rect = Rect(9999999, 9999999, -1, -1) + + if not self.has_item_at(r0, c0): + return res, rect + + res.append(self.get_item(r0, c0).txt) + rect.grow(self.get_item(r0, c0).get_rect()) + + init_depth = self.get_depth(r0, c0, HIERARCHY_DIR_UP) + + res_up, rect_up = go(DIR_UPWARDS, init_depth) + res_down, rect_down = go(DIR_DOWNWARDS, init_depth) + + for i in range(max(len(res_up), len(res_down))): + if i < len(res_up): + res.append(res_up[i]) + if i < len(res_down): + res.append(res_down[i]) + rect.grow(rect_up) + rect.grow(rect_down) + + return res, rect + + def get_txt_nodes(self, r0, c0, dir, include_special_items): + res = [] + ix = self.get_ix(r0, c0) + rect = Rect(9999999, 9999999, -1, -1) + while self.has_item_at_ix(ix): + r, c = self.get_row_and_col_by_ix(ix) + cur_res, cur_rect = self.get_aligned_multirow_txt_with_rect(r, c) + rect.grow(cur_rect) + res.extend(cur_res) + + # rect.grow(self.get_item_by_ix(ix).get_rect()) + # res.append(self.get_item_by_ix(ix).txt) + ix = self.table_hierarchy[dir][ix] + + if include_special_items: + sp_idx = self.htmltable.find_special_item_idx_in_rect(rect) + for i in sp_idx: + res.append(self.items[i].txt) + return res + + def get_txt_nodes_above(self, r0, c0, include_special_items, break_at_number): + # search for text items that are above the current cell + res = [] + r = r0 + while r > 0: + r -= 1 + if not self.has_item_at(r, c0): + continue # empty + txt = self.get_item(r, c0).txt + if Format_Analyzer.looks_weak_non_numeric(txt): + res.append(txt) + else: + if break_at_number: + break # a number or soemthing that is no text + + if include_special_items and self.has_item_at(r0, c0): + rect = self.get_item(r0, c0).get_rect() + rect.y0 = 0 + sp_idx = self.htmltable.find_special_item_idx_in_rect(rect) + for i in sp_idx: + res.append(self.items[i].txt) + return res + + def get_txt_headline(self): + res = [] + for h_idx in self.htmltable.headline_idx: + 
res.append(self.items[h_idx].txt) + return res + + def get_first_non_empty_row( + self, r0, c0 + ): # typically, set r0 = 0 to search from top. If no such row found, then "num_rows" is returned + r = r0 + while r < self.get_num_rows(): + if self.has_item_at(r, c0): + break + r += 1 + return r + + def get_multi_row_headline(self, r0, c0, include_special_items): + # get first non-empty row + r = self.get_first_non_empty_row(r0, c0) + if r == self.get_num_rows() or not Format_Analyzer.looks_weak_words(self.get_item(r, c0)): + return "" + + res = "" + # special items? + if include_special_items: + rect = self.get_item(r, c0).get_rect() + rect.y0 = 0 + sp_idx = self.htmltable.find_special_item_idx_in_rect(rect) + res += HTMLItem.concat_txt(htmlpage.unfold_idx_to_items(sp_idx)) + + # start at r,c0 + font = self.get_item(r, c0).get_font_characteristics() + max_row_spacing = self.get_item(r, c0).font_size * 1.5 + res += self.get_item(r, c0).txt + last_y1 = self.get_item(r, c0).get_rect().y1 + + while r < self.get_num_rows() - 1: + r += 1 + if not self.has_item_at(r, c0): + continue + cur_font = self.get_item(r, c0).get_font_characteristics() + if cur_font != font: + break # different font => break here + cur_rect = self.get_item(r, c0).get_rect() + if cur_rect.y0 > last_y1 + max_row_spacing: + break # rows are too far apart => break here + txt = self.get_item(r, c0).txt + if not Format_Analyzer.looks_weak_words(txt): + break # not a text anymore => break here + res += " " + txt + + return res + + def row_looks_like_year_line( + self, r0 + ): # at least one column is a year? if yes, an dict with year -> row_num, col_num is returned. If not, None + num_years = 0 + res = {} + for j in range(self.get_num_cols()): + if self.has_item_at(r0, j): + txt = self.get_item(r0, j).txt + if Format_Analyzer.looks_year(txt): + num_years += 1 + y = Format_Analyzer.to_year(self.get_item(r0, j).txt) + if y not in res: + res[y] = (r0, j) + elif Format_Analyzer.looks_numeric(txt): + # some other number occured => this is probably not a line of years + return None + + return res if num_years >= 1 else None + + """ old / not used anymore def is_table_with_years_at_top(self): #do we have a top row with years? if yes, a dict with year -> row_num, col_num is returned. If not, None for j in range(self.get_num_cols()): @@ -315,520 +293,583 @@ def is_table_with_years_at_top(self): #do we have a top row with years? 
if yes, return None """ - - def find_all_year_rows(self): - self.year_rows = [] - for i in range(self.get_num_rows()): - years = self.row_looks_like_year_line(i) - if(years is not None): - self.year_rows.append(AnalyzerTable.YearRow(i, years)) - - - def find_applicable_year_line(self, r0): - first_match = None - for i in range(len(self.year_rows)-1, -1, -1): - if(self.year_rows[i].row_num <= r0): - first_match = self.year_rows[i]#.years - break - - # now we look for a better match (that contain more year values) - if(first_match is not None): - - best_match = first_match - - for i in range(len(self.year_rows)-1, -1, -1): - if(self.year_rows[i].row_num < first_match.row_num and self.year_rows[i].row_num >= first_match.row_num - 7): #max 7 lines above - if(len(self.year_rows[i].years) > len(best_match.years)): - best_match = self.year_rows[i] - - return best_match.years - - - return None - - - def find_applicable_items_for_table_with_years(self, r0): - def return_items_for_row(r, years): - if(years is None or len(years) == 0): - return None - - res = {} - for y, cell in years.items(): - if(cell[0] != r): # never match items that are in same row as the year numbers - res[y] = self.get_item(r, cell[1]) - return res - - def contains_items(d): # check if dict d contains some items at all - for y, it in d.items(): - if(it is not None): - return True - return False - - def advance_row(init_depth, r1): - r = r1 + 1 - while(r < self.get_num_rows()): - if(not self.has_item_at(r, 0)): - return r - if(self.get_depth(r, 0, HIERARCHY_DIR_UP) <= init_depth): - break# another left item was found at same depth => stop process - r += 1 - return self.get_num_rows() #nothing found - - r = r0 - init_depth = self.get_depth(r0, 0, HIERARCHY_DIR_UP) - while(r < self.get_num_rows()): - cur_years = self.find_applicable_year_line(r) - print_verbose(8, '.........-> r='+str(r)+', cur_years='+str(cur_years)) - cur_items = return_items_for_row(r, cur_years) - if(cur_items is not None and contains_items(cur_items)): - #we found the applicable items - return r, cur_items - #advace to next row - r = advance_row(init_depth, r) - return self.get_num_rows(), None - - - def find_applicable_row_with_items_for_any_left_oriented_table(self, r0): - - def contains_items(r): # check if dict d contains some items at all - for j in range(self.get_num_cols()): - if(self.has_item_at(r, j) and Format_Analyzer.looks_numeric(self.get_item(r, j).txt)): - return True - return False - - def advance_row(init_depth, r1): - r = r1 + 1 - while(r < self.get_num_rows()): - if(not self.has_item_at(r, 0)): - return r - if(self.get_depth(r, 0, HIERARCHY_DIR_UP) <= init_depth): - break# another left item was found at same depth => stop process - r += 1 - return self.get_num_rows() #nothing found - - r = r0 - init_depth = self.get_depth(r0, 0, HIERARCHY_DIR_UP) - while(r < self.get_num_rows()): - if(contains_items(r)): - #we found the applicable items - return r - #advance to next row - r = advance_row(init_depth, r) - return None - - def find_applicable_unit_item(self, kpispecs, r0): - # returns the applicable item that contains the corresponding unit - sp_item = self.htmltable.find_applying_special_item(r0) - print_verbose(7, '....unit_item->sp_item='+str(sp_item)) - if(sp_item is not None and kpispecs.match_unit(sp_item.txt)): - return sp_item.txt - - # look for other unit items - search_rect = self.htmltable.rows[r0] - search_rect.y0 = 0 #self.htmltable.table_rect.y0 - self.htmlpage.page_height * 0.125 - items_idx = 
self.htmlpage.find_items_within_rect(search_rect, [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_TABLE_DATA, CAT_TABLE_HEADLINE, CAT_TABLE_SPECIAL, CAT_MISC, CAT_FOOTER]) - match_idx = -1 - for i in items_idx: - txt = self.htmlpage.explode_item(i) - print_verbose(10,'.......trying instead: ' + txt) - if(kpispecs.match_unit(txt)): - print_verbose(10,'...........===> match!') - if(match_idx == -1 or self.items[i].pos_y > self.items[match_idx].pos_y): - print_verbose(10,'...........===> better then previous match. new match_idx='+str(i)) - match_idx = i - if(match_idx != -1): - return self.htmlpage.explode_item(match_idx) - - return None - - - def search_year_agressive(self, search_rect, min_year, max_year, base_pos_x, base_pos_y, aggressive_year_pattern): - items_idx = self.htmlpage.find_items_within_rect(search_rect, [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_TABLE_DATA, CAT_TABLE_HEADLINE, CAT_TABLE_SPECIAL, CAT_MISC, CAT_FOOTER]) - best_year = -1 - best_dist = 9999999 - for i in items_idx: - it = self.items[i] - cur_x, cur_y = it.get_rect().get_center() - dist_x = abs(cur_x - base_pos_x) - dist_y = abs(cur_y - base_pos_y) - if(cur_x > base_pos_x): - dist_x *= 3.0 #prefer items to the left - dist_1 = min(dist_x, dist_y) - dist_2 = max(dist_x, dist_y) / 3.0 # prefer ortoghonally aligned items - cur_dist = (dist_1*dist_1 + dist_2*dist_2)**0.5 - if(cur_dist < best_dist): - for w in it.words: - cur_year = None - if(not aggressive_year_pattern): - if(Format_Analyzer.looks_year(w.txt)): - cur_year = Format_Analyzer.to_year(w.txt) #by Lei - #cur_year = int(w.txt) - else: - cur_year = Format_Analyzer.looks_year_extended(w.txt) - print_verbose(11, '..................... Analyzing possible year string: "'+w.txt+'" => ' +str(cur_year)) - - if(cur_year is not None and cur_year >= min_year and cur_year <= max_year): - best_year = cur_year - best_dist = cur_dist - - return best_year - - - - - def find_kpi_with_direct_years(self, kpispecs, bonus): - # find KPIs that are directly aligned with year headline - # Example: - # 2019 2018 2017 - # Hydrocarbons 123 456 789 - # - # ==> Result: 2019: 123 ; 2018: 456; 2017: 789 - - - if(self.year_rows is None or len(self.year_rows) == 0): - return [] # invalid table for this search - - - # base score from headlines: - print_verbose(5, '<<< ================================= >>>') - print_verbose(5, '<<== FIND_KPI_WITH_DIRECT_YEARS ==>>') - print_verbose(5, '<<< ================================= >>>') - print_verbose(5, ' ') - - print_verbose(5, 'year_rows = ' + str(self.year_rows)) - - print_verbose(5, 'Looking at headlines') - h_txt_nodes = self.get_txt_headline() - h_match_dummy, h_score = kpispecs.match_nodes(h_txt_nodes) - h_score *= 0.5 #decay factor for headline - print_verbose(5, 'Headline: ' + str(h_txt_nodes)+ ', score=' + str(h_score)) - - if(h_score < 0): - return [] # headline contains something that must be excluded - - - - # normal score from acutal items: - res = [] - - previous_txt_node_with_no_values = '' - - - for i in range(self.get_num_rows()): - txt_nodes = self.get_txt_nodes(i, 0, HIERARCHY_DIR_UP, True) - txt_nodes = txt_nodes + ([previous_txt_node_with_no_values] if previous_txt_node_with_no_values != '' and previous_txt_node_with_no_values not in txt_nodes else []) - print_verbose(5, 'Looking at row i='+str(i)+', txt_nodes='+str(txt_nodes)) - txt_match, score = kpispecs.match_nodes(txt_nodes) - print_verbose(5, '---> score='+str(score)) - if(not txt_match): - print_verbose(5, '---> No match') - continue #no match - value_row, value_items = 
self.find_applicable_items_for_table_with_years(i) - if(value_items is None): - print_verbose(5, '---> No values found') - if(self.has_item_at(i, 0)): - previous_txt_node_with_no_values = self.get_item(i, 0).txt - else: - previous_txt_node_with_no_values = '' - continue #no values found - missmatch_value = False - print_verbose(6, '-------> value_row / value_items= '+str(value_row)+' / '+str(value_items)) - for y, it in value_items.items(): - if(it is None): - continue - if(not kpispecs.match_value(it.txt)): - missmatch_value = True - break - if(missmatch_value): - print_verbose(5, '---> Value missmatch') - continue # value missmatch - txt_unit = self.find_applicable_unit_item(kpispecs, value_row) - if(txt_unit is None): - print_verbose(5, '---> Unit not matched') - continue #unit not matched - - # determine score multiplier - multiplier = 1.0 - - for y, it in value_items.items(): - if(it is None): - continue - multiplier *= 1.2 - - if(multiplier > 1.0): - multiplier /= 1.2 # we counted one item too much, so divide once - - for y, it in value_items.items(): - if(it is None): - continue - - anywhere_match, anywhere_match_score = kpispecs.match_anywhere_on_page(self.htmlpage, it.this_id) - if(not anywhere_match): - print_verbose(5, '---> anywhere-match was not matched on this page. No other match possible.') - return [] - - total_score = score + h_score + anywhere_match_score + bonus - if(total_score < kpispecs.minimum_score): - print_verbose(5, '---> Total score '+str(total_score)+' is less than min. score '+str(kpispecs.minimum_score)) - continue - - - kpi_measure = KPIMeasure() - kpi_measure.kpi_id = kpispecs.kpi_id - kpi_measure.kpi_name = kpispecs.kpi_name - kpi_measure.src_file = 'TODO' - kpi_measure.page_num = self.htmlpage.page_num - kpi_measure.item_ids = [it.this_id] - kpi_measure.pos_x = it.pos_x - kpi_measure.pos_y = it.pos_y - kpi_measure.raw_txt = it.txt - kpi_measure.year = y - kpi_measure.value = kpispecs.extract_value(it.txt) - kpi_measure.score = total_score * multiplier - kpi_measure.unit = txt_unit - kpi_measure.match_type= 'AT.direct' - res.append(kpi_measure) - print_verbose(4, '---> Match: ' + str(kpi_measure) + ': score='+str(score)+',h_score='+str(h_score)+',anywhere_match_score='+str(anywhere_match_score)+',bonus='+str(bonus)+', multiplier='+str(multiplier)) - - res = KPIMeasure.remove_duplicates(res) - return res - - - - - def find_kpi_with_indirect_years(self, kpispecs, bonus): - # find KPIs that are only indirectly connected with year headline, or not at all - # Example: - # Year 2019 - # Upstream Downstream Total - # Sales: 1 2 3 - # - # ==> Result (for KPI "Downstream Sales"): 2019: 2 - - - - # base score from headlines: - print_verbose(5, '<<< ================================= >>>') - print_verbose(5, '<<== FIND_KPI_WITH_INDIRECT_YEARS ==>>') - print_verbose(5, '<<< ================================= >>>') - print_verbose(5, ' ') - - print_verbose(5, 'Looking at headlines') - h_txt_nodes = self.get_txt_headline() - h_match_dummy, h_score = kpispecs.match_nodes(h_txt_nodes) - if(h_score < 0): - return [] # headline contains something that must be excluded - - - h_score *= 0.1 #decay factor for headline - print_verbose(5, 'Headline: ' + str(h_txt_nodes)+ ', score=' + str(h_score)) - - - - # normal score from acutal items: - res = [] - - - # find possible fixed left columns - fixed_left_cols = [0] - for j in range(1, self.get_num_cols()): - if(self.htmltable.col_looks_like_text_col(j)): - fixed_left_cols.append(j) - - print_verbose(6, 
'fixed_left_cols='+str(fixed_left_cols)) - - for fixed_left_column in fixed_left_cols: - #fixed_left_column = 6 - for i in range(self.get_num_rows()): - txt_nodes_row = self.get_txt_nodes(i, fixed_left_column, HIERARCHY_DIR_UP, True) - font_size_row_node = None - if(self.has_item_at(i, fixed_left_column)): - font_size_row_node = self.get_item(i, fixed_left_column).font_size - - print_verbose(5, 'Looking at row i='+str(i)+', txt_nodes_row='+str(txt_nodes_row)+',fonz_size='+str(font_size_row_node)) - - value_row = self.find_applicable_row_with_items_for_any_left_oriented_table(i) - if(value_row is None): - print_verbose(5, '---> No values found') - continue #no values found - - txt_unit = self.find_applicable_unit_item(kpispecs, value_row) - - print_verbose(6, '-------> txt_unit='+str(txt_unit)) - - if(txt_unit is None): - print_verbose(5, '---> Unit not matched') - continue #unit not matched - - years = self.find_applicable_year_line(i) - - print_verbose(5, '--> years= '+str(years)) - for j in range(fixed_left_column + 1, self.get_num_cols()): - if(not self.has_item_at(value_row, j)): - continue # empty cell - - it = self.get_item(value_row, j) - value_txt = it.txt - font_size_cell = it.font_size - print_verbose(5, '\n-> Looking at cell: row,col=' + str(value_row) +',' + str(j)+':' + str(value_txt) + ', font_size=' +str(font_size_cell)) - if(font_size_row_node is not None and (font_size_cell < font_size_row_node / 1.75 or font_size_cell > font_size_row_node * 1.75)): - print_verbose(5, '---> Fontsize missmatch') - continue # value missmatch - if(not kpispecs.match_value(value_txt)): - print_verbose(5, '---> Value missmatch') - continue # value missmatch - - txt_nodes_col = self.get_txt_nodes_above(value_row, j, True, False) # TODO: Really use False here? - print_verbose(5, '---> txt_nodes_col='+str(txt_nodes_col)) - print_verbose(6, '......... matching against: ' + str(txt_nodes_row + txt_nodes_col)) - txt_match, score = kpispecs.match_nodes(txt_nodes_row + txt_nodes_col) - print_verbose(5, '---> score='+str(score)) - if(not txt_match): - print_verbose(5, '---> No match') - continue #no match - - # find best year match - kpi_year = -1 - bad_year_match = False - if(years is not None): - min_diff = 9999999 - for y, cell in years.items(): - cur_diff = abs(cell[1] - j) - if(cur_diff < min_diff): - min_diff = cur_diff - kpi_year = Format_Analyzer.to_year(self.get_item(cell[0], cell[1]).txt) #by Lei - if(cell[0]==i and cell[1]==j): - # we matched the year as value itself => must be wrong - bad_year_match = True - - if(bad_year_match): - print_verbose(5, '-----> Bad year match: year is same as value') - continue - - # if still no year found, then search more - if(kpi_year == -1): - print_verbose(7, '......---> no year found. 
searching more agressively') - search_rect = self.htmltable.table_rect - search_rect.y1 = it.pos_y - next_non_empty_row = self.find_next_non_empty_cell_return_row_only(value_row, j, DIR_DOWNWARDS) - max_add = 999999 # we also want to to look a LITTLE bit downwards, in case of two-line-description cells that refer to this cell - if(next_non_empty_row != -1): - max_add = (self.get_item(next_non_empty_row, j).pos_y - (it.pos_y + it.height)) * 0.8 - - search_rect.y1 += min(it.height * 1.0, max_add) - print_verbose(8, '..............-> max_add='+str(max_add)+ ', y1(old)=' +str(it.pos_y) + ', y1(new)=' + str(search_rect.y1) ) - base_pos_x, base_pos_y = it.get_rect().get_center() - kpi_year = self.search_year_agressive(search_rect, self.default_year - 10, self.default_year, base_pos_x, base_pos_y, aggressive_year_pattern = False) - if(kpi_year == -1): - print_verbose(7, '........---> still no year found. searching even more agressively') - kpi_year = self.search_year_agressive(search_rect, self.default_year - 10, self.default_year, base_pos_x, base_pos_y, aggressive_year_pattern = True) - if(kpi_year == -1): - print_verbose(7, '........---> still no year found. searching even MORE agressively') - search_rect.y0 = 0 - search_rect.x0 = 0 - search_rect.x1 = 9999999 - kpi_year = self.search_year_agressive(search_rect, self.default_year - 10, self.default_year, base_pos_x, base_pos_y, aggressive_year_pattern = True) - print_verbose(7, '.........-> year found='+str(kpi_year) if kpi_year != -1 else '..........-> still nothing found. give up.') - - - anywhere_match, anywhere_match_score = kpispecs.match_anywhere_on_page(self.htmlpage, it.this_id) - if(not anywhere_match): - print_verbose(5, '---> anywhere-match was not matched on this page. No other match possible.') - return [] - - total_score = score + h_score + anywhere_match_score + bonus - - if(total_score < kpispecs.minimum_score): - print_verbose(5, '---> Total score '+str(total_score)+' is less than min. 
score '+str(kpispecs.minimum_score)) - continue - - kpi_measure = KPIMeasure() - kpi_measure.kpi_id = kpispecs.kpi_id - kpi_measure.kpi_name = kpispecs.kpi_name - kpi_measure.src_file = 'TODO' - kpi_measure.page_num = self.htmlpage.page_num - kpi_measure.item_ids = [it.this_id] - kpi_measure.pos_x = it.pos_x - kpi_measure.pos_y = it.pos_y - kpi_measure.raw_txt = it.txt - kpi_measure.year = kpi_year if kpi_year != -1 else self.default_year - kpi_measure.value = kpispecs.extract_value(it.txt) - kpi_measure.score = total_score - kpi_measure.unit = txt_unit - kpi_measure.match_type= 'AT.indirect' - kpi_measure.tmp = i # we use this to determine score multiplier - res.append(kpi_measure) - print_verbose(4, '---> Match: ' + str(kpi_measure)) - - res = KPIMeasure.remove_duplicates(res) - - - row_multiplier = {} - row_years_taken = {} - for kpi in res: - row = kpi.tmp - if(row in row_multiplier): - if(kpi.year not in row_years_taken[row]): - row_multiplier[row] *= 1.2 - row_years_taken[row].append(kpi.year) - else: - row_multiplier[row] = 1.0 - row_years_taken[row] = [kpi.year] - - for kpi in res: - kpi.score *= row_multiplier[kpi.tmp] - - - print_verbose(5, "===> found AT.indirect KPIs on Page " + str(self.htmlpage.page_num) + ": " + str(res) + "\n================================") - - return res - - - - - - - def find_kpis(self, kpispecs): - # Find all possible occurences of KPIs in that table - print_verbose(6, "Analyzing Table :\n" +str(self.htmltable)) - res = [] - res.extend(self.find_kpi_with_direct_years(kpispecs, 100)) # with years - res.extend(self.find_kpi_with_indirect_years(kpispecs, 0)) # without years - - - # ... add others here - - res = KPIMeasure.remove_duplicates(res) - - if(len(res) > 0): - print_verbose(2, "Found KPIs on Page " + str(self.htmlpage.page_num) +", Table : \n" +str(self.htmltable.get_printed_repr()) + "\n" + str(res) + "\n================================") - - return res - - - - - - - - - def __init__(self, htmltable, htmlpage, default_year): - self.htmltable = htmltable - self.htmlpage = htmlpage - self.items = htmlpage.items - self.default_year = default_year - self.table_hierarchy = [] - for i in range(2): - self.table_hierarchy.append([-2] * len(self.htmltable.idx)) - self.calculate_hierarchy(HIERARCHY_DIR_UP) - self.calculate_hierarchy(HIERARCHY_DIR_LEFT) - self.years = [] - self.find_all_year_rows() - - - - - - - \ No newline at end of file + def find_all_year_rows(self): + self.year_rows = [] + for i in range(self.get_num_rows()): + years = self.row_looks_like_year_line(i) + if years is not None: + self.year_rows.append(AnalyzerTable.YearRow(i, years)) + + def find_applicable_year_line(self, r0): + first_match = None + for i in range(len(self.year_rows) - 1, -1, -1): + if self.year_rows[i].row_num <= r0: + first_match = self.year_rows[i] # .years + break + + # now we look for a better match (that contain more year values) + if first_match is not None: + best_match = first_match + + for i in range(len(self.year_rows) - 1, -1, -1): + if ( + self.year_rows[i].row_num < first_match.row_num + and self.year_rows[i].row_num >= first_match.row_num - 7 + ): # max 7 lines above + if len(self.year_rows[i].years) > len(best_match.years): + best_match = self.year_rows[i] + + return best_match.years + + return None + + def find_applicable_items_for_table_with_years(self, r0): + def return_items_for_row(r, years): + if years is None or len(years) == 0: + return None + + res = {} + for y, cell in years.items(): + if cell[0] != r: # never match items that are in same row as 
the year numbers + res[y] = self.get_item(r, cell[1]) + return res + + def contains_items(d): # check if dict d contains some items at all + for y, it in d.items(): + if it is not None: + return True + return False + + def advance_row(init_depth, r1): + r = r1 + 1 + while r < self.get_num_rows(): + if not self.has_item_at(r, 0): + return r + if self.get_depth(r, 0, HIERARCHY_DIR_UP) <= init_depth: + break # another left item was found at same depth => stop process + r += 1 + return self.get_num_rows() # nothing found + + r = r0 + init_depth = self.get_depth(r0, 0, HIERARCHY_DIR_UP) + while r < self.get_num_rows(): + cur_years = self.find_applicable_year_line(r) + print_verbose(8, ".........-> r=" + str(r) + ", cur_years=" + str(cur_years)) + cur_items = return_items_for_row(r, cur_years) + if cur_items is not None and contains_items(cur_items): + # we found the applicable items + return r, cur_items + # advace to next row + r = advance_row(init_depth, r) + return self.get_num_rows(), None + + def find_applicable_row_with_items_for_any_left_oriented_table(self, r0): + def contains_items(r): # check if dict d contains some items at all + for j in range(self.get_num_cols()): + if self.has_item_at(r, j) and Format_Analyzer.looks_numeric(self.get_item(r, j).txt): + return True + return False + + def advance_row(init_depth, r1): + r = r1 + 1 + while r < self.get_num_rows(): + if not self.has_item_at(r, 0): + return r + if self.get_depth(r, 0, HIERARCHY_DIR_UP) <= init_depth: + break # another left item was found at same depth => stop process + r += 1 + return self.get_num_rows() # nothing found + + r = r0 + init_depth = self.get_depth(r0, 0, HIERARCHY_DIR_UP) + while r < self.get_num_rows(): + if contains_items(r): + # we found the applicable items + return r + # advance to next row + r = advance_row(init_depth, r) + return None + + def find_applicable_unit_item(self, kpispecs, r0): + # returns the applicable item that contains the corresponding unit + sp_item = self.htmltable.find_applying_special_item(r0) + print_verbose(7, "....unit_item->sp_item=" + str(sp_item)) + if sp_item is not None and kpispecs.match_unit(sp_item.txt): + return sp_item.txt + + # look for other unit items + search_rect = self.htmltable.rows[r0] + search_rect.y0 = 0 # self.htmltable.table_rect.y0 - self.htmlpage.page_height * 0.125 + items_idx = self.htmlpage.find_items_within_rect( + search_rect, + [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_TABLE_DATA, CAT_TABLE_HEADLINE, CAT_TABLE_SPECIAL, CAT_MISC, CAT_FOOTER], + ) + match_idx = -1 + for i in items_idx: + txt = self.htmlpage.explode_item(i) + print_verbose(10, ".......trying instead: " + txt) + if kpispecs.match_unit(txt): + print_verbose(10, "...........===> match!") + if match_idx == -1 or self.items[i].pos_y > self.items[match_idx].pos_y: + print_verbose(10, "...........===> better then previous match. 
new match_idx=" + str(i)) + match_idx = i + if match_idx != -1: + return self.htmlpage.explode_item(match_idx) + + return None + + def search_year_agressive(self, search_rect, min_year, max_year, base_pos_x, base_pos_y, aggressive_year_pattern): + items_idx = self.htmlpage.find_items_within_rect( + search_rect, + [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_TABLE_DATA, CAT_TABLE_HEADLINE, CAT_TABLE_SPECIAL, CAT_MISC, CAT_FOOTER], + ) + best_year = -1 + best_dist = 9999999 + for i in items_idx: + it = self.items[i] + cur_x, cur_y = it.get_rect().get_center() + dist_x = abs(cur_x - base_pos_x) + dist_y = abs(cur_y - base_pos_y) + if cur_x > base_pos_x: + dist_x *= 3.0 # prefer items to the left + dist_1 = min(dist_x, dist_y) + dist_2 = max(dist_x, dist_y) / 3.0 # prefer ortoghonally aligned items + cur_dist = (dist_1 * dist_1 + dist_2 * dist_2) ** 0.5 + if cur_dist < best_dist: + for w in it.words: + cur_year = None + if not aggressive_year_pattern: + if Format_Analyzer.looks_year(w.txt): + cur_year = Format_Analyzer.to_year(w.txt) # by Lei + # cur_year = int(w.txt) + else: + cur_year = Format_Analyzer.looks_year_extended(w.txt) + print_verbose( + 11, '..................... Analyzing possible year string: "' + w.txt + '" => ' + str(cur_year) + ) + + if cur_year is not None and cur_year >= min_year and cur_year <= max_year: + best_year = cur_year + best_dist = cur_dist + + return best_year + + def find_kpi_with_direct_years(self, kpispecs, bonus): + # find KPIs that are directly aligned with year headline + # Example: + # 2019 2018 2017 + # Hydrocarbons 123 456 789 + # + # ==> Result: 2019: 123 ; 2018: 456; 2017: 789 + + if self.year_rows is None or len(self.year_rows) == 0: + return [] # invalid table for this search + + # base score from headlines: + print_verbose(5, "<<< ================================= >>>") + print_verbose(5, "<<== FIND_KPI_WITH_DIRECT_YEARS ==>>") + print_verbose(5, "<<< ================================= >>>") + print_verbose(5, " ") + + print_verbose(5, "year_rows = " + str(self.year_rows)) + + print_verbose(5, "Looking at headlines") + h_txt_nodes = self.get_txt_headline() + h_match_dummy, h_score = kpispecs.match_nodes(h_txt_nodes) + h_score *= 0.5 # decay factor for headline + print_verbose(5, "Headline: " + str(h_txt_nodes) + ", score=" + str(h_score)) + + if h_score < 0: + return [] # headline contains something that must be excluded + + # normal score from acutal items: + res = [] + + previous_txt_node_with_no_values = "" + + for i in range(self.get_num_rows()): + txt_nodes = self.get_txt_nodes(i, 0, HIERARCHY_DIR_UP, True) + txt_nodes = txt_nodes + ( + [previous_txt_node_with_no_values] + if previous_txt_node_with_no_values != "" and previous_txt_node_with_no_values not in txt_nodes + else [] + ) + print_verbose(5, "Looking at row i=" + str(i) + ", txt_nodes=" + str(txt_nodes)) + txt_match, score = kpispecs.match_nodes(txt_nodes) + print_verbose(5, "---> score=" + str(score)) + if not txt_match: + print_verbose(5, "---> No match") + continue # no match + value_row, value_items = self.find_applicable_items_for_table_with_years(i) + if value_items is None: + print_verbose(5, "---> No values found") + if self.has_item_at(i, 0): + previous_txt_node_with_no_values = self.get_item(i, 0).txt + else: + previous_txt_node_with_no_values = "" + continue # no values found + missmatch_value = False + print_verbose(6, "-------> value_row / value_items= " + str(value_row) + " / " + str(value_items)) + for y, it in value_items.items(): + if it is None: + continue + if not 
kpispecs.match_value(it.txt): + missmatch_value = True + break + if missmatch_value: + print_verbose(5, "---> Value missmatch") + continue # value missmatch + txt_unit = self.find_applicable_unit_item(kpispecs, value_row) + if txt_unit is None: + print_verbose(5, "---> Unit not matched") + continue # unit not matched + + # determine score multiplier + multiplier = 1.0 + + for y, it in value_items.items(): + if it is None: + continue + multiplier *= 1.2 + + if multiplier > 1.0: + multiplier /= 1.2 # we counted one item too much, so divide once + + for y, it in value_items.items(): + if it is None: + continue + + anywhere_match, anywhere_match_score = kpispecs.match_anywhere_on_page(self.htmlpage, it.this_id) + if not anywhere_match: + print_verbose(5, "---> anywhere-match was not matched on this page. No other match possible.") + return [] + + total_score = score + h_score + anywhere_match_score + bonus + if total_score < kpispecs.minimum_score: + print_verbose( + 5, + "---> Total score " + + str(total_score) + + " is less than min. score " + + str(kpispecs.minimum_score), + ) + continue + + kpi_measure = KPIMeasure() + kpi_measure.kpi_id = kpispecs.kpi_id + kpi_measure.kpi_name = kpispecs.kpi_name + kpi_measure.src_file = "TODO" + kpi_measure.page_num = self.htmlpage.page_num + kpi_measure.item_ids = [it.this_id] + kpi_measure.pos_x = it.pos_x + kpi_measure.pos_y = it.pos_y + kpi_measure.raw_txt = it.txt + kpi_measure.year = y + kpi_measure.value = kpispecs.extract_value(it.txt) + kpi_measure.score = total_score * multiplier + kpi_measure.unit = txt_unit + kpi_measure.match_type = "AT.direct" + res.append(kpi_measure) + print_verbose( + 4, + "---> Match: " + + str(kpi_measure) + + ": score=" + + str(score) + + ",h_score=" + + str(h_score) + + ",anywhere_match_score=" + + str(anywhere_match_score) + + ",bonus=" + + str(bonus) + + ", multiplier=" + + str(multiplier), + ) + + res = KPIMeasure.remove_duplicates(res) + return res + + def find_kpi_with_indirect_years(self, kpispecs, bonus): + # find KPIs that are only indirectly connected with year headline, or not at all + # Example: + # Year 2019 + # Upstream Downstream Total + # Sales: 1 2 3 + # + # ==> Result (for KPI "Downstream Sales"): 2019: 2 + + # base score from headlines: + print_verbose(5, "<<< ================================= >>>") + print_verbose(5, "<<== FIND_KPI_WITH_INDIRECT_YEARS ==>>") + print_verbose(5, "<<< ================================= >>>") + print_verbose(5, " ") + + print_verbose(5, "Looking at headlines") + h_txt_nodes = self.get_txt_headline() + h_match_dummy, h_score = kpispecs.match_nodes(h_txt_nodes) + if h_score < 0: + return [] # headline contains something that must be excluded + + h_score *= 0.1 # decay factor for headline + print_verbose(5, "Headline: " + str(h_txt_nodes) + ", score=" + str(h_score)) + + # normal score from acutal items: + res = [] + + # find possible fixed left columns + fixed_left_cols = [0] + for j in range(1, self.get_num_cols()): + if self.htmltable.col_looks_like_text_col(j): + fixed_left_cols.append(j) + + print_verbose(6, "fixed_left_cols=" + str(fixed_left_cols)) + + for fixed_left_column in fixed_left_cols: + # fixed_left_column = 6 + for i in range(self.get_num_rows()): + txt_nodes_row = self.get_txt_nodes(i, fixed_left_column, HIERARCHY_DIR_UP, True) + font_size_row_node = None + if self.has_item_at(i, fixed_left_column): + font_size_row_node = self.get_item(i, fixed_left_column).font_size + + print_verbose( + 5, + "Looking at row i=" + + str(i) + + ", txt_nodes_row=" + + 
str(txt_nodes_row) + + ",fonz_size=" + + str(font_size_row_node), + ) + + value_row = self.find_applicable_row_with_items_for_any_left_oriented_table(i) + + if value_row is None: + print_verbose(5, "---> No values found") + continue # no values found + + txt_unit = self.find_applicable_unit_item(kpispecs, value_row) + + print_verbose(6, "-------> txt_unit=" + str(txt_unit)) + + if txt_unit is None: + print_verbose(5, "---> Unit not matched") + continue # unit not matched + + years = self.find_applicable_year_line(i) + + print_verbose(5, "--> years= " + str(years)) + for j in range(fixed_left_column + 1, self.get_num_cols()): + if not self.has_item_at(value_row, j): + continue # empty cell + + it = self.get_item(value_row, j) + value_txt = it.txt + font_size_cell = it.font_size + print_verbose( + 5, + "\n-> Looking at cell: row,col=" + + str(value_row) + + "," + + str(j) + + ":" + + str(value_txt) + + ", font_size=" + + str(font_size_cell), + ) + if font_size_row_node is not None and ( + font_size_cell < font_size_row_node / 1.75 or font_size_cell > font_size_row_node * 1.75 + ): + print_verbose(5, "---> Fontsize missmatch") + continue # value missmatch + if not kpispecs.match_value(value_txt): + print_verbose(5, "---> Value missmatch") + continue # value missmatch + + txt_nodes_col = self.get_txt_nodes_above(value_row, j, True, False) # TODO: Really use False here? + print_verbose(5, "---> txt_nodes_col=" + str(txt_nodes_col)) + print_verbose(6, "......... matching against: " + str(txt_nodes_row + txt_nodes_col)) + txt_match, score = kpispecs.match_nodes(txt_nodes_row + txt_nodes_col) + print_verbose(5, "---> score=" + str(score)) + if not txt_match: + print_verbose(5, "---> No match") + continue # no match + + # find best year match + kpi_year = -1 + bad_year_match = False + if years is not None: + min_diff = 9999999 + for y, cell in years.items(): + cur_diff = abs(cell[1] - j) + if cur_diff < min_diff: + min_diff = cur_diff + kpi_year = Format_Analyzer.to_year(self.get_item(cell[0], cell[1]).txt) # by Lei + if cell[0] == i and cell[1] == j: + # we matched the year as value itself => must be wrong + bad_year_match = True + + if bad_year_match: + print_verbose(5, "-----> Bad year match: year is same as value") + continue + + # if still no year found, then search more + if kpi_year == -1: + print_verbose(7, "......---> no year found. searching more agressively") + search_rect = self.htmltable.table_rect + search_rect.y1 = it.pos_y + next_non_empty_row = self.find_next_non_empty_cell_return_row_only(value_row, j, DIR_DOWNWARDS) + max_add = 999999 # we also want to to look a LITTLE bit downwards, in case of two-line-description cells that refer to this cell + if next_non_empty_row != -1: + max_add = (self.get_item(next_non_empty_row, j).pos_y - (it.pos_y + it.height)) * 0.8 + + search_rect.y1 += min(it.height * 1.0, max_add) + print_verbose( + 8, + "..............-> max_add=" + + str(max_add) + + ", y1(old)=" + + str(it.pos_y) + + ", y1(new)=" + + str(search_rect.y1), + ) + base_pos_x, base_pos_y = it.get_rect().get_center() + kpi_year = self.search_year_agressive( + search_rect, + self.default_year - 10, + self.default_year, + base_pos_x, + base_pos_y, + aggressive_year_pattern=False, + ) + if kpi_year == -1: + print_verbose(7, "........---> still no year found. 
searching even more agressively") + kpi_year = self.search_year_agressive( + search_rect, + self.default_year - 10, + self.default_year, + base_pos_x, + base_pos_y, + aggressive_year_pattern=True, + ) + if kpi_year == -1: + print_verbose(7, "........---> still no year found. searching even MORE agressively") + search_rect.y0 = 0 + search_rect.x0 = 0 + search_rect.x1 = 9999999 + kpi_year = self.search_year_agressive( + search_rect, + self.default_year - 10, + self.default_year, + base_pos_x, + base_pos_y, + aggressive_year_pattern=True, + ) + print_verbose( + 7, + ".........-> year found=" + str(kpi_year) + if kpi_year != -1 + else "..........-> still nothing found. give up.", + ) + + anywhere_match, anywhere_match_score = kpispecs.match_anywhere_on_page(self.htmlpage, it.this_id) + if not anywhere_match: + print_verbose(5, "---> anywhere-match was not matched on this page. No other match possible.") + return [] + + total_score = score + h_score + anywhere_match_score + bonus + + if total_score < kpispecs.minimum_score: + print_verbose( + 5, + "---> Total score " + + str(total_score) + + " is less than min. score " + + str(kpispecs.minimum_score), + ) + continue + + kpi_measure = KPIMeasure() + kpi_measure.kpi_id = kpispecs.kpi_id + kpi_measure.kpi_name = kpispecs.kpi_name + kpi_measure.src_file = "TODO" + kpi_measure.page_num = self.htmlpage.page_num + kpi_measure.item_ids = [it.this_id] + kpi_measure.pos_x = it.pos_x + kpi_measure.pos_y = it.pos_y + kpi_measure.raw_txt = it.txt + kpi_measure.year = kpi_year if kpi_year != -1 else self.default_year + kpi_measure.value = kpispecs.extract_value(it.txt) + kpi_measure.score = total_score + kpi_measure.unit = txt_unit + kpi_measure.match_type = "AT.indirect" + kpi_measure.tmp = i # we use this to determine score multiplier + res.append(kpi_measure) + print_verbose(4, "---> Match: " + str(kpi_measure)) + + res = KPIMeasure.remove_duplicates(res) + + row_multiplier = {} + row_years_taken = {} + for kpi in res: + row = kpi.tmp + if row in row_multiplier: + if kpi.year not in row_years_taken[row]: + row_multiplier[row] *= 1.2 + row_years_taken[row].append(kpi.year) + else: + row_multiplier[row] = 1.0 + row_years_taken[row] = [kpi.year] + + for kpi in res: + kpi.score *= row_multiplier[kpi.tmp] + + print_verbose( + 5, + "===> found AT.indirect KPIs on Page " + + str(self.htmlpage.page_num) + + ": " + + str(res) + + "\n================================", + ) + + return res + + def find_kpis(self, kpispecs): + # Find all possible occurences of KPIs in that table + print_verbose(6, "Analyzing Table :\n" + str(self.htmltable)) + res = [] + res.extend(self.find_kpi_with_direct_years(kpispecs, 100)) # with years + res.extend(self.find_kpi_with_indirect_years(kpispecs, 0)) # without years + + # ... 
add others here + + res = KPIMeasure.remove_duplicates(res) + + if len(res) > 0: + print_verbose( + 2, + "Found KPIs on Page " + + str(self.htmlpage.page_num) + + ", Table : \n" + + str(self.htmltable.get_printed_repr()) + + "\n" + + str(res) + + "\n================================", + ) + + return res + + def __init__(self, htmltable, htmlpage, default_year): + self.htmltable = htmltable + self.htmlpage = htmlpage + self.items = htmlpage.items + self.default_year = default_year + self.table_hierarchy = [] + for i in range(2): + self.table_hierarchy.append([-2] * len(self.htmltable.idx)) + self.calculate_hierarchy(HIERARCHY_DIR_UP) + self.calculate_hierarchy(HIERARCHY_DIR_LEFT) + self.years = [] + self.find_all_year_rows() diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/ConsoleTable.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/ConsoleTable.py index 097fcf2..414391c 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/ConsoleTable.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/ConsoleTable.py @@ -6,112 +6,97 @@ # ============================================================================================================================ - class ConsoleTable: + FORMAT_NICE_CONSOLE = 0 + FORMAT_CSV = 1 + + num_cols = None + cells = None + + def __init__(self, num_cols): + self.num_cols = num_cols + self.cells = [] + + def get_num_rows(self): + return int(len(self.cells) / self.num_cols) + + def get(self, row, col): + return self.cells[col + row * self.num_cols] + + def get_native_col_width(self, col): + res = 0 + for i in range(self.get_num_rows()): + res = max(res, len(self.get(i, col))) + return res + + def to_string(self, max_width=None, min_col_width=None, use_format=FORMAT_NICE_CONSOLE): + if use_format == ConsoleTable.FORMAT_NICE_CONSOLE: + cols = [] + for j in range(self.num_cols): + cols.append(self.get_native_col_width(j)) + + max_col_width = max_width + while True: + total_width = 1 + for j in range(self.num_cols): + total_width += 1 + min(max_col_width, cols[j]) + if total_width <= max_width: + break + if max_col_width <= min_col_width: + break + max_col_width -= 1 + + for j in range(self.num_cols): + cols[j] = min(max_col_width, cols[j]) + + res = "" + + # headline + res += "\u2554" + for j in range(self.num_cols): + res += "\u2550" * cols[j] + res += "\u2566" if j < self.num_cols - 1 else "\u2557\n" + + # content + for i in range(self.get_num_rows()): + # frame line + if i > 0: + res += "\u2560" + for j in range(self.num_cols): + res += "\u2550" * cols[j] + res += "\u256c" if j < self.num_cols - 1 else "\u2563" + res += "\n" + + # content line + res += "\u2551" + for j in range(self.num_cols): + txt = self.get(i, j).replace("\n", " ") + res += str(txt)[: cols[j]].ljust(cols[j], " ") + res += "\u2551" + + res += "\n" + + # footer line + res += "\u255a" + for j in range(self.num_cols): + res += "\u2550" * cols[j] + res += "\u2569" if j < self.num_cols - 1 else "\u255d\n" + + return res + + if use_format == ConsoleTable.FORMAT_CSV: + res = "" + + # content + for i in range(self.get_num_rows()): + # content line + for j in range(self.num_cols): + if j > 0: + res += "," + res += '"' + self.get(i, j).replace("\n", " ").replace('"', "") + '"' + + res += "\n" + + return res - FORMAT_NICE_CONSOLE = 0 - FORMAT_CSV = 1 - - num_cols = None - cells = None - - - - def __init__(self, num_cols): - self.num_cols = num_cols - self.cells = [] - - def get_num_rows(self): - return int(len(self.cells) / 
self.num_cols) - - def get(self, row, col): - return self.cells[col + row * self.num_cols] - - def get_native_col_width(self, col): - res = 0 - for i in range(self.get_num_rows()): - res = max(res, len(self.get(i, col))) - return res - - - def to_string(self, max_width = None, min_col_width = None, use_format = FORMAT_NICE_CONSOLE): - - if(use_format == ConsoleTable.FORMAT_NICE_CONSOLE): - - cols = [] - for j in range(self.num_cols): - cols.append(self.get_native_col_width(j)) - - - - max_col_width = max_width - while(True): - total_width = 1 - for j in range(self.num_cols): - total_width += 1 + min(max_col_width,cols[j]) - if(total_width <= max_width): - break - if(max_col_width <= min_col_width): - break - max_col_width -= 1 - - for j in range(self.num_cols): - cols[j] = min(max_col_width, cols[j]) - - - res = '' - - # headline - res += '\u2554' - for j in range(self.num_cols): - res += '\u2550' * cols[j] - res += '\u2566' if j < self.num_cols - 1 else '\u2557\n' - - # content - for i in range(self.get_num_rows()): - # frame line - if(i>0): - res += '\u2560' - for j in range(self.num_cols): - res += '\u2550'*cols[j] - res += '\u256c' if j < self.num_cols -1 else '\u2563' - res += '\n' - - # content line - res += '\u2551' - for j in range(self.num_cols): - txt = self.get(i, j).replace('\n', ' ') - res += str(txt)[:cols[j]].ljust(cols[j], ' ') - res += '\u2551' - - - res += '\n' - - # footer line - res += '\u255a' - for j in range(self.num_cols): - res += '\u2550' * cols[j] - res += '\u2569' if j < self.num_cols - 1 else '\u255d\n' - - return res - - if(use_format == ConsoleTable.FORMAT_CSV): - - - res = '' - - # content - for i in range(self.get_num_rows()): - - # content line - for j in range(self.num_cols): - if(j>0): - res += ',' - res += '"' + self.get(i, j).replace('\n', ' ').replace('"', '') + '"' - - res += '\n' - - return res - - - return 'Unknown format for ConsoleTable\n' + return "Unknown format for ConsoleTable\n" diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/DataImportExport.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/DataImportExport.py index 37223b9..51f7a06 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/DataImportExport.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/DataImportExport.py @@ -8,105 +8,92 @@ from globals import * from Format_Analyzer import * + class DataImportExport: + @staticmethod + def import_files(src_folder, dst_folder, file_list, file_type): + def ext(f): + res = [f] + res.extend(Format_Analyzer.extract_file_path(f)) + return res + + # print(src_folder) + file_paths = glob.glob(src_folder + "/**/*." + file_type, recursive=True) + file_paths = [ext(f.replace("\\", "/")) for f in file_paths] # unixize all file paths + + res = [] + + info_file_contents = {} + + for fname in file_list: + fname_clean = fname.lower().strip() # (new) + if fname_clean[-4:] == "." + file_type: + fname_clean = fname_clean[:-4] + + # fname_clean = fname_clean.strip() (old) + fpath = None + + # look case-sensitive + for f in file_paths: + if f[2] + "." + f[3] == fname: + # match! + fpath = f + break + + # look case-insensitive + if fpath is None: + for f in file_paths: + # if(f[2].lower().strip() == fname_clean): (old) + if f[2].lower() == fname_clean: + # match! 
+ fpath = f + break + + # print('SRC: "' + fname + '" -> ' + str(fpath)) + new_file_name = None + if fpath is None: + print_verbose(0, 'Warning: "' + fname + '" not found.') + else: + new_file_name = Format_Analyzer.cleanup_filename(fpath[2] + "." + fpath[3]) + new_file_path = remove_trailing_slash(dst_folder) + "/" + new_file_name + info_file_contents[new_file_path] = fpath[0] + + if not file_exists(new_file_path): + print_verbose(1, 'Copy "' + fpath[0] + '" to "' + new_file_path + '"') + shutil.copyfile(fpath[0], new_file_path) + + res.append((fname, new_file_name)) + + # print(info_file_contents) + + # save info file contents: + jsonpickle.set_preferred_backend("json") + jsonpickle.set_encoder_options("json", sort_keys=True, indent=4) + data = jsonpickle.encode(info_file_contents) + f_info = open(remove_trailing_slash(config.global_working_folder) + "/info.json", "w") + f_info.write(data) + f_info.close() + + return res + + @staticmethod + def save_info_file_contents(file_paths): + info_file_contents = {} + + for f in file_paths: + info_file_contents[f[0]] = f[0] + + jsonpickle.set_preferred_backend("json") + jsonpickle.set_encoder_options("json", sort_keys=True, indent=4) + data = jsonpickle.encode(info_file_contents) + f_info = open(remove_trailing_slash(config.global_working_folder) + "/info.json", "w") + f_info.write(data) + f_info.close() - @staticmethod - def import_files(src_folder, dst_folder, file_list, file_type): - def ext(f): - res = [f] - res.extend(Format_Analyzer.extract_file_path(f)) - return res - - #print(src_folder) - file_paths = glob.glob(src_folder + '/**/*.' + file_type, recursive=True) - file_paths = [ext(f.replace('\\','/')) for f in file_paths] #unixize all file paths - - res = [] - - info_file_contents = {} - - for fname in file_list: - fname_clean = fname.lower().strip() #(new) - if(fname_clean[-4:]=='.'+file_type): - fname_clean=fname_clean[:-4] - - #fname_clean = fname_clean.strip() (old) - fpath = None - - # look case-sensitive - for f in file_paths: - if(f[2] + '.' + f[3] == fname): - #match! - fpath = f - break - - - - # look case-insensitive - if(fpath is None): - for f in file_paths: - #if(f[2].lower().strip() == fname_clean): (old) - if(f[2].lower() == fname_clean): - # match! - fpath = f - break - - #print('SRC: "' + fname + '" -> ' + str(fpath)) - new_file_name = None - if(fpath is None): - print_verbose(0, 'Warning: "' + fname + '" not found.') - else: - new_file_name = Format_Analyzer.cleanup_filename(fpath[2] + '.' 
+ fpath[3]) - new_file_path = remove_trailing_slash(dst_folder) + '/' + new_file_name - info_file_contents[new_file_path] = fpath[0] - - if(not file_exists(new_file_path)): - print_verbose(1, 'Copy "' + fpath[0] + '" to "' + new_file_path + '"') - shutil.copyfile(fpath[0], new_file_path) - - - - res.append((fname, new_file_name)) - - #print(info_file_contents) - - #save info file contents: - jsonpickle.set_preferred_backend('json') - jsonpickle.set_encoder_options('json', sort_keys=True, indent=4) - data = jsonpickle.encode(info_file_contents) - f_info = open(remove_trailing_slash(config.global_working_folder) + '/info.json', "w") - f_info.write(data) - f_info.close() - - return res - - - @staticmethod - def save_info_file_contents(file_paths): - info_file_contents = {} - - for f in file_paths: - info_file_contents[f[0]] = f[0] - - jsonpickle.set_preferred_backend('json') - jsonpickle.set_encoder_options('json', sort_keys=True, indent=4) - data = jsonpickle.encode(info_file_contents) - f_info = open(remove_trailing_slash(config.global_working_folder) + '/info.json', "w") - f_info.write(data) - f_info.close() - - - @staticmethod - def load_info_file_contents(json_file): - f = open(json_file, "r") - data = f.read() - f.close() - obj = jsonpickle.decode(data) - return obj - - - - - - - \ No newline at end of file + @staticmethod + def load_info_file_contents(json_file): + f = open(json_file, "r") + data = f.read() + f.close() + obj = jsonpickle.decode(data) + return obj diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Format_Analyzer.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Format_Analyzer.py index ada344c..5459583 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Format_Analyzer.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Format_Analyzer.py @@ -8,204 +8,211 @@ from globals import * - class Format_Analyzer: - #pattern_numeric = re.compile(r'^(-?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)$') - #pattern_numeric = re.compile(r'^\(?(-?\(?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)\)?$') - pattern_numeric = re.compile(r'^\(?(-?\(?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)\)?(\*?)*$') #by Lei - pattern_numeric_multi = re.compile(r'^(\(?(-?\(?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)\)?)+$') - #pattern_year = re.compile(r'^(19[8-9][0-9]|20[0-9][0-9])$') #1980-2099 - pattern_year = re.compile(r'^[^0-9]*(19[8-9][0-9](/[0-9][0-9])?|20[0-9][0-9](/[0-9][0-9])?)[^0-9]*$') #1980-2099 #by Lei - pattern_year_extended_1 = re.compile(r'^.*[0-3][0-9](\.|\\|/)[0-3][0-9](\.|\\|/)(19[8-9][0-9]|20[0-9][0-9]).*$') #1980-2099 - pattern_year_extended_2 = re.compile(r'^.*(19[8-9][0-9]|20[0-9][0-9])(\.|\\|/)[0-3][0-9](\.|\\|/)[0-3][0-9].*$') #1980-2099 - - pattern_year_in_txt = re.compile(r'(19[8-9][0-9]|20[0-9][0-9])') #1980-2099 - pattern_null = re.compile(r'^(null|n/a|na|-*|\.*|,*|;*)$') - pattern_whitespace = re.compile("^\s+|\s+$") - pattern_ends_with_full_stop = re.compile(".*\.$") - pattern_pagenum = re.compile(r'^[0-9]{1,3}$') - pattern_non_numeric_char = re.compile(r'[^0-9\-\.]') - pattern_file_path = re.compile(r'(.*/)(.*)\.(.*)') - pattern_cleanup_text = re.compile(r'[^a-z ]') - - pattern_cleanup_filename = re.compile(r'(\[|\]|\(|\))') - - pattern_footnote = re.compile(r'[0-9]+\).*') - - - @staticmethod - def trim_whitespaces(val): - return 
re.sub(Format_Analyzer.pattern_whitespace, '', val) - - - @staticmethod - def looks_numeric(val): - #return Format_Analyzer.pattern_numeric.match(val.replace(' ', '').replace('$', '')) and len(val)>0 - val0 = remove_bad_chars(val, ' ()$%') - #return Format_Analyzer.pattern_numeric.match(val0) and len(val0)>0 - return Format_Analyzer.pattern_numeric.match(val0.replace('WLTP', '')) and len(val0)>0 #by Lei - - @staticmethod - def looks_numeric_multiple(val): - #return Format_Analyzer.pattern_numeric.match(val.replace(' ', '').replace('$', '')) and len(val)>0 - return Format_Analyzer.pattern_numeric_multi.match(remove_bad_chars(val, ' ()$%')) and len(val)>0 - - @staticmethod - def looks_weak_numeric(val): - num_numbers = sum(c.isnumeric() for c in val) - return num_numbers > 0 - - @staticmethod - def looks_percentage(val): - return looks_weak_numeric(val) and '%' in val - - @staticmethod - def to_year(val): -# val0 = remove_bad_chars(val, 'FY') #by Lei - val0 = re.sub(r'[^0-9]', '', val) #by Lei - return int(val0) #by Lei -# val0=remove_letter_before_year(val) -# return int(val0.replace(' ', '')) - - @staticmethod - def looks_year(val): - return Format_Analyzer.pattern_year.match(val.replace(' ', '')) - - @staticmethod - def looks_year_extended(val): #return year if found, otherwise None - if(Format_Analyzer.pattern_year_extended_1.match(val.replace(' ', ''))): - return int(Format_Analyzer.pattern_year_extended_1.match(val.replace(' ', '')).groups()[2]) - if(Format_Analyzer.pattern_year_extended_2.match(val.replace(' ', ''))): - return int(Format_Analyzer.pattern_year_extended_2.match(val.replace(' ', '')).groups()[0]) - if(Format_Analyzer.looks_year(val)): - return Format_Analyzer.to_year(val) - - return None - - - @staticmethod - def cleanup_number(val): - s = re.sub(Format_Analyzer.pattern_non_numeric_char, '' , val) - # filter out extra dots - first_dot = s.find('.') - if(first_dot == -1): - return s - - return s[0:first_dot+1] + s[first_dot+1:].replace('.', '') - - @staticmethod - def to_int_number(val, limit_chars=None): - s = Format_Analyzer.cleanup_number(val) - if(s == ''): - return None - return int(float(s if limit_chars is None else s[0:limit_chars])) - - @staticmethod - def to_float_number(val, limit_chars=None): - s = Format_Analyzer.cleanup_number(val) - if(s == ''): - return None - return float(s) - - @staticmethod - def cleanup_text(val): #remove all characters except letters and spaces - return re.sub(Format_Analyzer.pattern_cleanup_text, '', val) - - - @staticmethod - def looks_null(val): - return Format_Analyzer.pattern_null.match(val.replace(' ', '').lower()) - - @staticmethod - def looks_words(val): - num_letters = sum(c.isalpha() for c in val) - return num_letters > 5 - - @staticmethod - def looks_weak_words(val): - num_letters = sum(c.isalpha() for c in val) - num_numbers = sum(c.isnumeric() for c in val) - return num_letters > 2 and num_letters > num_numbers - - @staticmethod - def looks_weak_non_numeric(val): - num_letters = sum(c.isalpha() for c in val) - num_numbers = sum(c.isnumeric() for c in val) - num_others = len(val) - (num_letters + num_numbers) - #return (num_letters + num_others > 1) or (num_letters + num_others > num_numbers) - #return (num_letters + num_others > 1 and num_numbers < (num_letters + num_others) * 2 + 1) or (num_letters + num_others > num_numbers) - return num_letters > 0 and num_letters > num_numbers and ((num_letters + num_others > 1 and num_numbers < (num_letters + num_others) * 2 + 1) or (num_letters + num_others > num_numbers)) - - 
@staticmethod - def looks_other_special_item(val): - return len(val) < 4 and not Format_Analyzer.looks_words(val) and not Format_Analyzer.looks_numeric(val) - - - - @staticmethod - def looks_pagenum(val): - return Format_Analyzer.pattern_pagenum.match(val.replace(' ', '')) and len(val)>0 and val.replace(' ', '') != '0' - - @staticmethod - def looks_running_text(val): - txt = Format_Analyzer.trim_whitespaces(val) - num_full_stops = txt.count(".") - num_comma = txt.count(",") - num_space = txt.count(" ") - ends_with_full_stop = True if Format_Analyzer.pattern_ends_with_full_stop.match(txt) else False - txtlen = len(txt) - num_letters = sum(c.isalpha() for c in txt) - if(num_letters<20): - return False #too short - if(num_letters / txtlen < 0.5): - return False #strange: less than 50% are letters - if(num_space < 5): - return False #only 5 words or less - if(num_comma / txtlen < 0.004 and num_full_stops / txtlen < 0.002): - return False #too few commans / full stops - if(ends_with_full_stop): - #looks like a sentence - return True - - #does not end with full stop, so we require more conditons to hold - return ((num_full_stops > 2) or \ - (num_full_stops > 1 and num_comma > 1)) and \ - (num_letters > 30) and \ - (num_space > 10) - - - - @staticmethod - def looks_footnote(val): - return Format_Analyzer.pattern_footnote.match(val.replace(' ', '').lower()) - - @staticmethod - def exclude_all_years(val): - return re.sub(Format_Analyzer.pattern_year, '' , val) - - - @staticmethod - def extract_file_path(val): - return Format_Analyzer.pattern_file_path.match(val).groups() - - @staticmethod - def extract_file_name(val): - fp = Format_Analyzer.extract_file_path('/'+val.replace('\\','/')) - return fp[1] + '.' + fp[2] - - @staticmethod - def cleanup_filename(val): - return re.sub(Format_Analyzer.pattern_cleanup_filename, '_', val) - - @staticmethod - def extract_year_from_text(val): - lst = list(set(re.findall(Format_Analyzer.pattern_year_in_txt, val))) - if(len(lst) == 1): - return int(lst[0]) - return None # no or multiple results - - - @staticmethod - def cnt_overlapping_items(l0, l1): - return len(list(set(l0) & set(l1))) - \ No newline at end of file + # pattern_numeric = re.compile(r'^(-?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)$') + # pattern_numeric = re.compile(r'^\(?(-?\(?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)\)?$') + pattern_numeric = re.compile( + r"^\(?(-?\(?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)\)?(\*?)*$" + ) # by Lei + pattern_numeric_multi = re.compile( + r"^(\(?(-?\(?[ ]*[0-9]*(,[0-9][0-9][0-9])*(\.[0-9]+)?|-?[ ]*[0-9]*(\.[0-9][0-9][0-9])*(,[0-9]+)?)\)?)+$" + ) + # pattern_year = re.compile(r'^(19[8-9][0-9]|20[0-9][0-9])$') #1980-2099 + pattern_year = re.compile( + r"^[^0-9]*(19[8-9][0-9](/[0-9][0-9])?|20[0-9][0-9](/[0-9][0-9])?)[^0-9]*$" + ) # 1980-2099 #by Lei + pattern_year_extended_1 = re.compile( + r"^.*[0-3][0-9](\.|\\|/)[0-3][0-9](\.|\\|/)(19[8-9][0-9]|20[0-9][0-9]).*$" + ) # 1980-2099 + pattern_year_extended_2 = re.compile( + r"^.*(19[8-9][0-9]|20[0-9][0-9])(\.|\\|/)[0-3][0-9](\.|\\|/)[0-3][0-9].*$" + ) # 1980-2099 + + pattern_year_in_txt = re.compile(r"(19[8-9][0-9]|20[0-9][0-9])") # 1980-2099 + pattern_null = re.compile(r"^(null|n/a|na|-*|\.*|,*|;*)$") + pattern_whitespace = re.compile("^\s+|\s+$") + pattern_ends_with_full_stop = re.compile(".*\.$") + pattern_pagenum = re.compile(r"^[0-9]{1,3}$") + pattern_non_numeric_char = 
re.compile(r"[^0-9\-\.]") + pattern_file_path = re.compile(r"(.*/)(.*)\.(.*)") + pattern_cleanup_text = re.compile(r"[^a-z ]") + + pattern_cleanup_filename = re.compile(r"(\[|\]|\(|\))") + + pattern_footnote = re.compile(r"[0-9]+\).*") + + @staticmethod + def trim_whitespaces(val): + return re.sub(Format_Analyzer.pattern_whitespace, "", val) + + @staticmethod + def looks_numeric(val): + # return Format_Analyzer.pattern_numeric.match(val.replace(' ', '').replace('$', '')) and len(val)>0 + val0 = remove_bad_chars(val, " ()$%") + # return Format_Analyzer.pattern_numeric.match(val0) and len(val0)>0 + return Format_Analyzer.pattern_numeric.match(val0.replace("WLTP", "")) and len(val0) > 0 # by Lei + + @staticmethod + def looks_numeric_multiple(val): + # return Format_Analyzer.pattern_numeric.match(val.replace(' ', '').replace('$', '')) and len(val)>0 + return Format_Analyzer.pattern_numeric_multi.match(remove_bad_chars(val, " ()$%")) and len(val) > 0 + + @staticmethod + def looks_weak_numeric(val): + num_numbers = sum(c.isnumeric() for c in val) + return num_numbers > 0 + + @staticmethod + def looks_percentage(val): + return looks_weak_numeric(val) and "%" in val + + @staticmethod + def to_year(val): + # val0 = remove_bad_chars(val, 'FY') #by Lei + val0 = re.sub(r"[^0-9]", "", val) # by Lei + return int(val0) # by Lei + + # val0=remove_letter_before_year(val) + # return int(val0.replace(' ', '')) + + @staticmethod + def looks_year(val): + return Format_Analyzer.pattern_year.match(val.replace(" ", "")) + + @staticmethod + def looks_year_extended(val): # return year if found, otherwise None + if Format_Analyzer.pattern_year_extended_1.match(val.replace(" ", "")): + return int(Format_Analyzer.pattern_year_extended_1.match(val.replace(" ", "")).groups()[2]) + if Format_Analyzer.pattern_year_extended_2.match(val.replace(" ", "")): + return int(Format_Analyzer.pattern_year_extended_2.match(val.replace(" ", "")).groups()[0]) + if Format_Analyzer.looks_year(val): + return Format_Analyzer.to_year(val) + + return None + + @staticmethod + def cleanup_number(val): + s = re.sub(Format_Analyzer.pattern_non_numeric_char, "", val) + # filter out extra dots + first_dot = s.find(".") + if first_dot == -1: + return s + + return s[0 : first_dot + 1] + s[first_dot + 1 :].replace(".", "") + + @staticmethod + def to_int_number(val, limit_chars=None): + s = Format_Analyzer.cleanup_number(val) + if s == "": + return None + return int(float(s if limit_chars is None else s[0:limit_chars])) + + @staticmethod + def to_float_number(val, limit_chars=None): + s = Format_Analyzer.cleanup_number(val) + if s == "": + return None + return float(s) + + @staticmethod + def cleanup_text(val): # remove all characters except letters and spaces + return re.sub(Format_Analyzer.pattern_cleanup_text, "", val) + + @staticmethod + def looks_null(val): + return Format_Analyzer.pattern_null.match(val.replace(" ", "").lower()) + + @staticmethod + def looks_words(val): + num_letters = sum(c.isalpha() for c in val) + return num_letters > 5 + + @staticmethod + def looks_weak_words(val): + num_letters = sum(c.isalpha() for c in val) + num_numbers = sum(c.isnumeric() for c in val) + return num_letters > 2 and num_letters > num_numbers + + @staticmethod + def looks_weak_non_numeric(val): + num_letters = sum(c.isalpha() for c in val) + num_numbers = sum(c.isnumeric() for c in val) + num_others = len(val) - (num_letters + num_numbers) + # return (num_letters + num_others > 1) or (num_letters + num_others > num_numbers) + # return (num_letters + 
num_others > 1 and num_numbers < (num_letters + num_others) * 2 + 1) or (num_letters + num_others > num_numbers) + return ( + num_letters > 0 + and num_letters > num_numbers + and ( + (num_letters + num_others > 1 and num_numbers < (num_letters + num_others) * 2 + 1) + or (num_letters + num_others > num_numbers) + ) + ) + + @staticmethod + def looks_other_special_item(val): + return len(val) < 4 and not Format_Analyzer.looks_words(val) and not Format_Analyzer.looks_numeric(val) + + @staticmethod + def looks_pagenum(val): + return ( + Format_Analyzer.pattern_pagenum.match(val.replace(" ", "")) and len(val) > 0 and val.replace(" ", "") != "0" + ) + + @staticmethod + def looks_running_text(val): + txt = Format_Analyzer.trim_whitespaces(val) + num_full_stops = txt.count(".") + num_comma = txt.count(",") + num_space = txt.count(" ") + ends_with_full_stop = True if Format_Analyzer.pattern_ends_with_full_stop.match(txt) else False + txtlen = len(txt) + num_letters = sum(c.isalpha() for c in txt) + if num_letters < 20: + return False # too short + if num_letters / txtlen < 0.5: + return False # strange: less than 50% are letters + if num_space < 5: + return False # only 5 words or less + if num_comma / txtlen < 0.004 and num_full_stops / txtlen < 0.002: + return False # too few commans / full stops + if ends_with_full_stop: + # looks like a sentence + return True + + # does not end with full stop, so we require more conditons to hold + return ( + ((num_full_stops > 2) or (num_full_stops > 1 and num_comma > 1)) and (num_letters > 30) and (num_space > 10) + ) + + @staticmethod + def looks_footnote(val): + return Format_Analyzer.pattern_footnote.match(val.replace(" ", "").lower()) + + @staticmethod + def exclude_all_years(val): + return re.sub(Format_Analyzer.pattern_year, "", val) + + @staticmethod + def extract_file_path(val): + return Format_Analyzer.pattern_file_path.match(val).groups() + + @staticmethod + def extract_file_name(val): + fp = Format_Analyzer.extract_file_path("/" + val.replace("\\", "/")) + return fp[1] + "." 
+ fp[2] + + @staticmethod + def cleanup_filename(val): + return re.sub(Format_Analyzer.pattern_cleanup_filename, "_", val) + + @staticmethod + def extract_year_from_text(val): + lst = list(set(re.findall(Format_Analyzer.pattern_year_in_txt, val))) + if len(lst) == 1: + return int(lst[0]) + return None # no or multiple results + + @staticmethod + def cnt_overlapping_items(l0, l1): + return len(list(set(l0) & set(l1))) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLCluster.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLCluster.py index 7798954..758bc75 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLCluster.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLCluster.py @@ -16,219 +16,178 @@ import numpy CLUSTER_DISTANCE_MODE_EUCLIDIAN = 0 -CLUSTER_DISTANCE_MODE_RAW_TEXT = 1 +CLUSTER_DISTANCE_MODE_RAW_TEXT = 1 class HTMLCluster: + idx = None + # rect = None + children = None + items = None # dont export + flat_text = None # dont export - idx = None - #rect = None - children = None - items = None #dont export - flat_text = None #dont export - - - def __init__(self): - self.idx = -1 - self.children = [] - self.items = [] - #self.rect = Rect(99999,99999,-1,-1) - self.flat_text = "" - - - - def is_internal_node(self): - return len(self.children) > 0 - - def is_leaf(self): - return self.idx != -1 - - def set_idx(self, idx): - if(self.is_internal_node()): - raise ValueError('Node '+str(self) + ' is already an internal node') - self.idx = idx # now it's a child node - - def add_child(self, child): - if(self.is_leaf()): - raise ValueError('Node '+str(self) + ' is already a leaf node') - self.children.append(child) - - def calc_flat_text(self): - if(self.is_leaf()): - self.flat_text = str(self.items[self.idx].txt) - return - first = True - res = "" - for c in self.children: - if(not first): - res += ", " - c.calc_flat_text() - res += c.flat_text - first = False - self.flat_text = res - - def get_idx_list(self): - if(self.is_leaf()): - return [self.idx] - - res =[] - for c in self.children: - res.extend(c.get_idx_list()) - return res - - - def set_items_rec(self, items): - self.items = items - for c in self.children: - c.set_items_rec(items) - - def count_items(self): - if(self.is_leaf()): - return 1 - res = 0 - for c in self.children: - res += c.count_items() - return res - - def generate_rendering_colors_rec(self, h0=0.0, h1=0.75): # h = hue in [0,1] - if(self.is_leaf()): - self.items[self.idx].rendering_color = hsv_to_rgba((h0+h1)*0.5, 1, 1) - else: - num_items_per_child = [] - num_items_tot = 0 - for c in self.children: - cur_num = c.count_items() - num_items_per_child.append(cur_num) - num_items_tot += cur_num - num_items_acc = 0 - for i in range(len(self.children)): - self.children[i].generate_rendering_colors_rec(h0 + (h1-h0) * (num_items_acc/num_items_tot), h0 + (h1-h0) * ((num_items_acc + num_items_per_child[i]) /num_items_tot)) - num_items_acc += num_items_per_child[i] - - - - def regenerate_not_exported(self, items): - self.set_items_rec(items) - self.calc_flat_text() - - - def cleanup_for_export(self): - self.items = None - self.flat_text = None - for c in self.children: - c.cleanup_for_export() - - - def __repr__(self): - if(self.is_leaf()): - return str(self.items[self.idx]) - res = "<" - first = True - for c in self.children: - if(not first): - res += ", " - res += str(c) - first = False - res += ">" - return res - - - @staticmethod - def item_dist(it1, it2, mode): - if(mode == 
CLUSTER_DISTANCE_MODE_EUCLIDIAN): - it1_x, it1_y = it1.get_rect().get_center() - it2_x, it2_y = it2.get_rect().get_center() - return dist(it1_x, it1_y, it2_x, it2_y) - elif(mode == CLUSTER_DISTANCE_MODE_RAW_TEXT): - return dist(0, it1.pos_y, 0, it2.pos_y) - #return dist(it1.pos_x * 100, it1.pos_y, it2.pos_x * 100, it2.pos_y) #TODO: Add this a a new distance mode ! (20.09.2022) - - - - - raise ValueError('Invalid distance mode') - - - @staticmethod - def generate_clusters(items, mode): - print_verbose(3, "Regenerating clusters") - - if(len(items) < 2): - return None - - # generate a leaf for each items - nodes = [] - - - for it in items: - cur = HTMLCluster() - cur.items = items - cur.idx = it.this_id - nodes.append(cur) - - - print_verbose(3, 'Leaves: ' + str(nodes)) - - # generate distance matrix - l = len(items) - dmatrix = numpy.zeros((l, l)) - for i in range(l): - for j in range(i+1,l): - d = HTMLCluster.item_dist(items[i], items[j], mode) - dmatrix[i, j] = d - dmatrix[j, i] = d - - - print_verbose(5, dmatrix) - - # compute clusters - - sq = squareform(dmatrix) - - output_linkage = hcl.linkage(sq, method='average') - - # build up tree - - num_rows = numpy.size(output_linkage, 0) - - print_verbose(5, output_linkage) - - for i in range(num_rows): - cur_cluster = HTMLCluster() - cur_cluster.children.append(nodes[int(output_linkage[i,0])]) - cur_cluster.children.append(nodes[int(output_linkage[i,1])]) - nodes.append(cur_cluster) - - res = nodes[len(nodes)-1] - res.regenerate_not_exported(items) - - - print_verbose(3, 'Clustering result: ' +str(res)) - - return res - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file + def __init__(self): + self.idx = -1 + self.children = [] + self.items = [] + # self.rect = Rect(99999,99999,-1,-1) + self.flat_text = "" + + def is_internal_node(self): + return len(self.children) > 0 + + def is_leaf(self): + return self.idx != -1 + + def set_idx(self, idx): + if self.is_internal_node(): + raise ValueError("Node " + str(self) + " is already an internal node") + self.idx = idx # now it's a child node + + def add_child(self, child): + if self.is_leaf(): + raise ValueError("Node " + str(self) + " is already a leaf node") + self.children.append(child) + + def calc_flat_text(self): + if self.is_leaf(): + self.flat_text = str(self.items[self.idx].txt) + return + first = True + res = "" + for c in self.children: + if not first: + res += ", " + c.calc_flat_text() + res += c.flat_text + first = False + self.flat_text = res + + def get_idx_list(self): + if self.is_leaf(): + return [self.idx] + + res = [] + for c in self.children: + res.extend(c.get_idx_list()) + return res + + def set_items_rec(self, items): + self.items = items + for c in self.children: + c.set_items_rec(items) + + def count_items(self): + if self.is_leaf(): + return 1 + res = 0 + for c in self.children: + res += c.count_items() + return res + + def generate_rendering_colors_rec(self, h0=0.0, h1=0.75): # h = hue in [0,1] + if self.is_leaf(): + self.items[self.idx].rendering_color = hsv_to_rgba((h0 + h1) * 0.5, 1, 1) + else: + num_items_per_child = [] + num_items_tot = 0 + for c in self.children: + cur_num = c.count_items() + num_items_per_child.append(cur_num) + num_items_tot += cur_num + num_items_acc = 0 + for i in range(len(self.children)): + self.children[i].generate_rendering_colors_rec( + h0 + (h1 - h0) * (num_items_acc / num_items_tot), + h0 + (h1 - h0) * ((num_items_acc + num_items_per_child[i]) / num_items_tot), + ) + num_items_acc += num_items_per_child[i] + + def 
regenerate_not_exported(self, items): + self.set_items_rec(items) + self.calc_flat_text() + + def cleanup_for_export(self): + self.items = None + self.flat_text = None + for c in self.children: + c.cleanup_for_export() + + def __repr__(self): + if self.is_leaf(): + return str(self.items[self.idx]) + res = "<" + first = True + for c in self.children: + if not first: + res += ", " + res += str(c) + first = False + res += ">" + return res + + @staticmethod + def item_dist(it1, it2, mode): + if mode == CLUSTER_DISTANCE_MODE_EUCLIDIAN: + it1_x, it1_y = it1.get_rect().get_center() + it2_x, it2_y = it2.get_rect().get_center() + return dist(it1_x, it1_y, it2_x, it2_y) + elif mode == CLUSTER_DISTANCE_MODE_RAW_TEXT: + return dist(0, it1.pos_y, 0, it2.pos_y) + # return dist(it1.pos_x * 100, it1.pos_y, it2.pos_x * 100, it2.pos_y) #TODO: Add this a a new distance mode ! (20.09.2022) + + raise ValueError("Invalid distance mode") + + @staticmethod + def generate_clusters(items, mode): + print_verbose(3, "Regenerating clusters") + + if len(items) < 2: + return None + + # generate a leaf for each items + nodes = [] + + for it in items: + cur = HTMLCluster() + cur.items = items + cur.idx = it.this_id + nodes.append(cur) + + print_verbose(3, "Leaves: " + str(nodes)) + + # generate distance matrix + l = len(items) + dmatrix = numpy.zeros((l, l)) + for i in range(l): + for j in range(i + 1, l): + d = HTMLCluster.item_dist(items[i], items[j], mode) + dmatrix[i, j] = d + dmatrix[j, i] = d + + print_verbose(5, dmatrix) + + # compute clusters + + sq = squareform(dmatrix) + + output_linkage = hcl.linkage(sq, method="average") + + # build up tree + + num_rows = numpy.size(output_linkage, 0) + + print_verbose(5, output_linkage) + + for i in range(num_rows): + cur_cluster = HTMLCluster() + cur_cluster.children.append(nodes[int(output_linkage[i, 0])]) + cur_cluster.children.append(nodes[int(output_linkage[i, 1])]) + nodes.append(cur_cluster) + + res = nodes[len(nodes) - 1] + res.regenerate_not_exported(items) + + print_verbose(3, "Clustering result: " + str(res)) + + return res diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLDirectory.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLDirectory.py index 1164fcc..25aec6a 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLDirectory.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLDirectory.py @@ -11,116 +11,111 @@ from globals import * from HTMLPage import * from Format_Analyzer import * - + class HTMLDirectory: + htmlpages = None + src_pdf_filename = None + + def __init__(self): + self.htmlpages = [] + self.src_pdf_filename = None + + @staticmethod + def call_pdftohtml(infile, outdir): + print_verbose(2, "-> call pdftohtml_mod " + infile) + os.system( + config.global_exec_folder + + r'/pdftohtml_mod/pdftohtml_mod "' + + infile + + '" "' + + remove_trailing_slash(outdir) + + '"' + ) ## TODO: Specify correct path here! + + @staticmethod + def fix_strange_encryption(html_dir): + html_dir = remove_trailing_slash(html_dir) + + pathname = html_dir + "/page*.html" + print_verbose(2, "Fixing strange encryption = " + str(pathname)) + + for f in glob.glob(pathname): + print_verbose(3, "---> " + str(f)) + # HTMLPage.fix_strange_encryption(f) ### TODO: This might be needed, because there are some PDFs with some strange encryption in place (but so far not in the ESG context). 
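# --- Usage sketch (not part of the patch): the conversion entry point that follows,
# convert_pdf_to_html, shells out to the bundled pdftohtml_mod binary and writes the
# original PDF name into info.txt, after which the per-page HTML can be parsed and
# persisted. A minimal driver, assuming the rule_based_pipeline config
# (e.g. config.global_exec_folder) has already been initialised as the pipeline does
# elsewhere; the paths and the "page*.html" wildcard are illustrative assumptions.
from HTMLDirectory import HTMLDirectory

pdf_file = "/workdir/pdfs/example_report.pdf"          # hypothetical input PDF
out_dir = "/workdir/html/example_report"               # hypothetical output folder
info_file_contents = {pdf_file: "example_report.pdf"}  # original name, stored in info.txt

HTMLDirectory.convert_pdf_to_html(pdf_file, info_file_contents, out_dir)

directory = HTMLDirectory()
directory.parse_html_directory(out_dir, "page*.html")  # same wildcard as fix_strange_encryption
directory.save_to_dir(out_dir)                         # JSON pages plus CSV tables and footnote files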
+ + @staticmethod + def convert_pdf_to_html(pdf_file, info_file_contents, out_dir=None): + out_dir = get_html_out_dir(pdf_file) if out_dir is None else remove_trailing_slash(out_dir) + + try: + shutil.rmtree(out_dir) + except OSError: + pass + HTMLDirectory.call_pdftohtml(pdf_file, out_dir) + + # fix strange encryption + HTMLDirectory.fix_strange_encryption(out_dir) + + f = open(out_dir + "/info.txt", "w") + # f.write(Format_Analyzer.extract_file_name(pdf_file)) + f.write(info_file_contents[pdf_file]) + f.close() + + def read_pdf_filename(self, html_dir): + with open(remove_trailing_slash(html_dir) + "/info.txt") as f: + self.src_pdf_filename = f.read() + print_verbose(2, "PDF-Filename: " + self.src_pdf_filename) + + def parse_html_directory(self, html_dir, page_wildcard): + html_dir = remove_trailing_slash(html_dir) + + pathname = html_dir + "/" + page_wildcard + print_verbose(1, "PARSING DIR = " + str(pathname)) + + self.read_pdf_filename(html_dir) + + for f in glob.glob(pathname): + print_verbose(1, "ANALYZING HTML-FILE = " + str(f)) + + htmlpage = HTMLPage.parse_html_file(html_dir, f) + + print_verbose(1, "Discovered tables: ") + + print_verbose(1, htmlpage.repr_tables_only()) + + print_verbose(1, "Done with page = " + str(htmlpage.page_num)) + + self.htmlpages.append(htmlpage) + + def render_to_png(self, base_dir, out_dir): + for it in self.htmlpages: + print_verbose(1, "Render to png : page = " + str(it.page_num)) + it.render_to_png(remove_trailing_slash(base_dir), remove_trailing_slash(out_dir)) + + def print_all_tables(self): + for it in self.htmlpages: + print(it.repr_tables_only()) + + def save_to_dir(self, out_dir): + for it in self.htmlpages: + print_verbose(1, "Save to JSON and CSV: page = " + str(it.page_num)) + it.save_to_file(remove_trailing_slash(out_dir) + r"/jpage" + "{:05d}".format(it.page_num) + ".json") + it.save_all_tables_to_csv(out_dir) + it.save_all_footnotes_to_txt(out_dir) + + def load_from_dir(self, html_dir, page_wildcard): + html_dir = remove_trailing_slash(html_dir) + pathname = html_dir + "/" + page_wildcard + + self.read_pdf_filename(html_dir) + + for f in glob.glob(pathname): + # if not (f.endswith('0052.json') or f.endswith('0053.json')): # can be used for debugging, esp. multipage analyzing + # continue + + print_verbose(1, "LOADING JSON-FILE = " + str(f)) + + htmlpage = HTMLPage.load_from_file(f) - htmlpages = None - src_pdf_filename = None - - def __init__(self): - self.htmlpages = [] - self.src_pdf_filename = None - - @staticmethod - def call_pdftohtml(infile, outdir): - print_verbose(2, '-> call pdftohtml_mod '+infile) - os.system(config.global_exec_folder + r'/pdftohtml_mod/pdftohtml_mod "' + infile + '" "' + remove_trailing_slash(outdir) + '"') ## TODO: Specify correct path here! - - @staticmethod - def fix_strange_encryption(html_dir): - html_dir = remove_trailing_slash(html_dir) - - pathname = html_dir + '/page*.html' - print_verbose(2, "Fixing strange encryption = " + str(pathname)) - - for f in glob.glob(pathname): - print_verbose(3, "---> " + str(f)) - #HTMLPage.fix_strange_encryption(f) ### TODO: This might be needed, because there are some PDFs with some strange encryption in place (but so far not in the ESG context). 
- - - - @staticmethod - def convert_pdf_to_html(pdf_file, info_file_contents, out_dir=None): - out_dir = get_html_out_dir(pdf_file) if out_dir is None else remove_trailing_slash(out_dir) - - try: - shutil.rmtree(out_dir) - except OSError: - pass - HTMLDirectory.call_pdftohtml(pdf_file , out_dir) - - # fix strange encryption - HTMLDirectory.fix_strange_encryption(out_dir) - - f = open(out_dir + '/info.txt', 'w') - #f.write(Format_Analyzer.extract_file_name(pdf_file)) - f.write(info_file_contents[pdf_file]) - f.close() - - def read_pdf_filename(self, html_dir): - with open(remove_trailing_slash(html_dir) + '/info.txt') as f: - self.src_pdf_filename = f.read() - print_verbose(2, 'PDF-Filename: ' + self.src_pdf_filename) - - - def parse_html_directory(self, html_dir, page_wildcard): - - html_dir = remove_trailing_slash(html_dir) - - pathname = html_dir + '/' + page_wildcard - print_verbose(1, "PARSING DIR = " + str(pathname)) - - self.read_pdf_filename(html_dir) - - - for f in glob.glob(pathname): - - print_verbose(1, "ANALYZING HTML-FILE = " + str(f)) - - htmlpage = HTMLPage.parse_html_file(html_dir,f) - - print_verbose(1, "Discovered tables: ") - - print_verbose(1, htmlpage.repr_tables_only()) - - print_verbose(1, "Done with page = " + str(htmlpage.page_num)) - - self.htmlpages.append(htmlpage) - - - def render_to_png(self, base_dir, out_dir): - for it in self.htmlpages: - print_verbose(1, "Render to png : page = " + str(it.page_num)) - it.render_to_png(remove_trailing_slash(base_dir), remove_trailing_slash(out_dir)) - - def print_all_tables(self): - for it in self.htmlpages: - print(it.repr_tables_only()) - - - - def save_to_dir(self, out_dir): - for it in self.htmlpages: - print_verbose(1, "Save to JSON and CSV: page = " + str(it.page_num)) - it.save_to_file(remove_trailing_slash(out_dir) + r'/jpage'+"{:05d}".format(it.page_num) +'.json') - it.save_all_tables_to_csv(out_dir) - it.save_all_footnotes_to_txt(out_dir) - - def load_from_dir(self, html_dir, page_wildcard): - - html_dir = remove_trailing_slash(html_dir) - pathname = html_dir + '/' + page_wildcard - - self.read_pdf_filename(html_dir) - - for f in glob.glob(pathname): - #if not (f.endswith('0052.json') or f.endswith('0053.json')): # can be used for debugging, esp. multipage analyzing - # continue - - print_verbose(1, "LOADING JSON-FILE = " + str(f)) - - htmlpage = HTMLPage.load_from_file(f) - - self.htmlpages.append(htmlpage) - \ No newline at end of file + self.htmlpages.append(htmlpage) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLItem.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLItem.py index 6db98b5..d2e54e5 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLItem.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLItem.py @@ -14,283 +14,316 @@ class HTMLItem: - line_num = None - tot_line_num = None - pos_x = None # in pixels - pos_y = None # in pixels - width = None # in pixels - height = None # in pixels - initial_height = None # in pixels - font_size = None - txt = None - is_bold = None - brightness = None - alignment = None - font_file = None - - this_id = None - next_id = None # next line, -1 if None - prev_id = None # prev line, -1 if None - left_id = None # item to the left, -1 if None - right_id = None # item to the right, -1 if None - - category = None - temp_assignment = None # an integer, that is normally set to 0. 
It has greater values while table extraction is in progress - - merged_list = None #indexes of items, that this item had been merged with - words = None #list of all words (each a HTMLWord) - space_width = None - has_been_split = None - - rendering_color = None # only used for PNG rendering. not related with KPI extraction - - page_num = None - - - def __init__(self): - self.line_num = 0 - self.tot_line_num = 0 - self.pos_x = 0 - self.pos_y = 0 - self.width = 0 - self.height = 0 - self.initial_height = None - self.font_size = 0 - self.txt = "" - self.is_bold = False - self.brightness = 255 - self.alignment = ALIGN_LEFT - self.font_file = "" - self.this_id = -1 - self.next_id = -1 - self.prev_id = -1 - self.category = CAT_DEFAULT - self.temp_assignment = 0 - self.merged_list = [] - self.words = [] - self.space_width = 0 - self.has_been_split = False - self.left_id = -1 - self.right_id = -1 - self.rendering_color = (0,0,0,255) #black by default - self.page_num = -1 - - - def is_connected(self): - return next_id != -1 or prev_id != -1 - - def get_depth(self): - size = self.font_size - if(len(self.words)>0): - size = 0 - for w in self.words: - size = max(size, w.rect.y1 - w.rect.y0) - if(size < self.font_size * 0.8): - size = self.font_size *0.8 - if(size > self.font_size * 1.2): - size = self.font_size * 1.2 - return 10000 - int(size * 10 + (5 if self.is_bold else 0) + (3 * (255 - self.brightness)) / 255) - - def get_aligned_pos_x(self): - if(self.alignment == ALIGN_LEFT): - return self.pos_x - if(self.alignment == ALIGN_RIGHT): - return self.pos_x + self.width - if(self.alignment == ALIGN_CENTER): # New-27.06.2022 - return self.pos_x + self.width * 0.5 - # not yet implemented: - return None - - def is_text_component(self): - return self.category == CAT_HEADLINE or self.category == CAT_OTHER_TEXT or self.category == CAT_RUNNING_TEXT or self.category == CAT_FOOTER - - def has_category(self): - return self.category != CAT_DEFAULT - - def has_category_besides(self, category_to_neglect): - return self.category != CAT_DEFAULT and self.category != category_to_neglect - - def get_rect(self): - return Rect(self.pos_x,self.pos_y,self.pos_x+self.width,self.pos_y+self.height) - - @staticmethod - def find_item_by_id(items, id): - for it in items: - if(it.this_id == id): - return it - return None # not found. should never happen - - def reconnect(self, next_it, all_items): - if(self.next_id != -1): - old_next_it = HTMLItem.find_item_by_id(all_items, self.next_id) - old_next_it.prev_id = -1 - - if(next_it.prev_id != -1): - new_next_olds_prev_it = HTMLItem.find_item_by_id(all_items, next_it.prev_id) - new_next_olds_prev_it.next_id = -1 - - self.next_id = next_it.this_id - next_it.prev_id = self.this_id - - def is_mergable(self, it): - if(self.next_id == -1 and it.next_id == -1): - return False - if(self.prev_id == -1 and it.prev_id == -1): - return False - return (self.next_id == it.this_id or self.prev_id == it.this_id) \ - and self.pos_x == it.pos_x \ - and self.font_file == it.font_file \ - and self.height == it.height \ - and not Format_Analyzer.looks_numeric(self.txt) \ - and not Format_Analyzer.looks_numeric(it.txt) - - def is_weakly_mergable_after_reconnect(self, it): - return self.font_file == it.font_file \ - and self.font_size == it.font_size \ - and abs(self.get_initial_height() - it.get_initial_height()) < 0.1 - - def get_font_characteristics(self): - return self.font_file + '???' + str(self.font_size) + '???' + str(self.brightness) + '???' 
+str(self.is_bold) - - def get_initial_height(self): - if(self.initial_height is not None): - return self.initial_height - return self.height - - - def recalc_width(self): - span_font = ImageFont.truetype(self.font_file, self.font_size) - size = span_font.getsize(self.txt) - self.width = size[0] - if(self.width == 0): - #aproximate - size = span_font.getsize('x' * len(self.txt)) - self.width = size[0] - - - def merge(self, it): - # precondition : both items must be mergable - if(self.next_id == it.this_id): - self.txt += '\n' + it.txt - self.initial_height = self.get_initial_height() - self.height = it.pos_y + it.height - self.pos_y - self.width = max(self.width, it.width) - self.words = self.words + it.words - it.words = [] - it.txt = '' - elif(self.prev_id == it.this_id): - it.txt += '\n' + self.txt - it.initial_height = it.get_initial_height() - it.height = self.pos_y + self.height - it.pos_y - it.width = max(self.width, it.width) - it.words = self.words + it.words - self.txt = '' - self.words = [] - else: - raise ValueError('Items '+str(self)+' and '+str(it) + ' cannot be merged.') - - old_merged_list = self.merged_list.copy() - self.merged_list.append(it.this_id) - self.merged_list.extend(it.merged_list) - it.merged_list.append(self.this_id) - it.merged_list.extend(old_merged_list) - - - def fix_overlapping_words(self): - # assertion: all words are ordered by x asceding - for i in range(len(self.words)-1): - self.words[i].rect.x1 = min(self.words[i].rect.x1, self.words[i+1].rect.x0 - 0.00001) - - - def recalc_geometry(self): - self.pos_x = 9999999 - self.pos_y = 9999999 - x1 = -1 - y1 = -1 - for w in self.words: - self.pos_x = min(self.pos_x, w.rect.x0) - self.pos_y = min(self.pos_y, w.rect.y0) - x1 = max(x1, w.rect.x1) - y1 = max(y1, w.rect.y1) - self.width = x1 - self.pos_x - self.height = y1 - self.pos_y - - def rejoin_words(self): - self.txt = '' - for w in self.words: - if(self.txt != ''): - self.txt += ' ' - self.txt += w.txt - - - - def split(self, at_word, next_item_id): - # example "abc 123 def" -> split(1, 99) -> - # result "abc", and new item with item_id=99 "123 def" - new_item = HTMLItem() - new_item.line_num = self.line_num - new_item.tot_line_num = self.tot_line_num - #new_item.pos_x = self.pos_x - #new_item.pos_y = self.pos_y - #new_item.width = self.width - #new_item.height = self.height - new_item.font_size = self.font_size - new_item.words = self.words[at_word:] - #new_item.txt = self.txt - new_item.is_bold = self.is_bold - new_item.brightness = self.brightness - new_item.alignment = self.alignment - new_item.font_file = self.font_file - new_item.this_id = next_item_id - new_item.next_id = -1 - new_item.prev_id = -1 - new_item.category = self.category - new_item.temp_assignment= self.temp_assignment - new_item.merged_list = self.merged_list - new_item.space_width = self.space_width - new_item.has_been_split = True - - self.has_been_split = True - - new_item.left_id = self.this_id - new_item.right_id = self.right_id - new_item.page_num = self.page_num - self.right_id = new_item.this_id - - for k in range(at_word, len(self.words)): - self.words[k].item_id = next_item_id - - self.words = self.words[0:at_word] - self.recalc_geometry() - self.rejoin_words() - - new_item.recalc_geometry() - new_item.rejoin_words() - - return new_item - -# def unsplit(self, right): -# self.words = self.words + right.words -# self.right_id = right.right_id -# self.recalc_geometry() -# self.rejoin_words() -# # afterwards, right needs to be discarded - - - @staticmethod - def 
concat_txt(item_list, sep=' '): - res = '' - for it in item_list: - if(res != ''): - res += sep - res += it.txt - return res - - - def __repr__(self): - return "" + line_num = None + tot_line_num = None + pos_x = None # in pixels + pos_y = None # in pixels + width = None # in pixels + height = None # in pixels + initial_height = None # in pixels + font_size = None + txt = None + is_bold = None + brightness = None + alignment = None + font_file = None + this_id = None + next_id = None # next line, -1 if None + prev_id = None # prev line, -1 if None + left_id = None # item to the left, -1 if None + right_id = None # item to the right, -1 if None + + category = None + temp_assignment = ( + None # an integer, that is normally set to 0. It has greater values while table extraction is in progress + ) + + merged_list = None # indexes of items, that this item had been merged with + words = None # list of all words (each a HTMLWord) + space_width = None + has_been_split = None + + rendering_color = None # only used for PNG rendering. not related with KPI extraction + + page_num = None + + def __init__(self): + self.line_num = 0 + self.tot_line_num = 0 + self.pos_x = 0 + self.pos_y = 0 + self.width = 0 + self.height = 0 + self.initial_height = None + self.font_size = 0 + self.txt = "" + self.is_bold = False + self.brightness = 255 + self.alignment = ALIGN_LEFT + self.font_file = "" + self.this_id = -1 + self.next_id = -1 + self.prev_id = -1 + self.category = CAT_DEFAULT + self.temp_assignment = 0 + self.merged_list = [] + self.words = [] + self.space_width = 0 + self.has_been_split = False + self.left_id = -1 + self.right_id = -1 + self.rendering_color = (0, 0, 0, 255) # black by default + self.page_num = -1 + + def is_connected(self): + return next_id != -1 or prev_id != -1 + + def get_depth(self): + size = self.font_size + if len(self.words) > 0: + size = 0 + for w in self.words: + size = max(size, w.rect.y1 - w.rect.y0) + if size < self.font_size * 0.8: + size = self.font_size * 0.8 + if size > self.font_size * 1.2: + size = self.font_size * 1.2 + return 10000 - int(size * 10 + (5 if self.is_bold else 0) + (3 * (255 - self.brightness)) / 255) + + def get_aligned_pos_x(self): + if self.alignment == ALIGN_LEFT: + return self.pos_x + if self.alignment == ALIGN_RIGHT: + return self.pos_x + self.width + if self.alignment == ALIGN_CENTER: # New-27.06.2022 + return self.pos_x + self.width * 0.5 + # not yet implemented: + return None + + def is_text_component(self): + return ( + self.category == CAT_HEADLINE + or self.category == CAT_OTHER_TEXT + or self.category == CAT_RUNNING_TEXT + or self.category == CAT_FOOTER + ) + + def has_category(self): + return self.category != CAT_DEFAULT + + def has_category_besides(self, category_to_neglect): + return self.category != CAT_DEFAULT and self.category != category_to_neglect + + def get_rect(self): + return Rect(self.pos_x, self.pos_y, self.pos_x + self.width, self.pos_y + self.height) + + @staticmethod + def find_item_by_id(items, id): + for it in items: + if it.this_id == id: + return it + return None # not found. 
should never happen + + def reconnect(self, next_it, all_items): + if self.next_id != -1: + old_next_it = HTMLItem.find_item_by_id(all_items, self.next_id) + old_next_it.prev_id = -1 + + if next_it.prev_id != -1: + new_next_olds_prev_it = HTMLItem.find_item_by_id(all_items, next_it.prev_id) + new_next_olds_prev_it.next_id = -1 + + self.next_id = next_it.this_id + next_it.prev_id = self.this_id + + def is_mergable(self, it): + if self.next_id == -1 and it.next_id == -1: + return False + if self.prev_id == -1 and it.prev_id == -1: + return False + return ( + (self.next_id == it.this_id or self.prev_id == it.this_id) + and self.pos_x == it.pos_x + and self.font_file == it.font_file + and self.height == it.height + and not Format_Analyzer.looks_numeric(self.txt) + and not Format_Analyzer.looks_numeric(it.txt) + ) + + def is_weakly_mergable_after_reconnect(self, it): + return ( + self.font_file == it.font_file + and self.font_size == it.font_size + and abs(self.get_initial_height() - it.get_initial_height()) < 0.1 + ) + + def get_font_characteristics(self): + return self.font_file + "???" + str(self.font_size) + "???" + str(self.brightness) + "???" + str(self.is_bold) + + def get_initial_height(self): + if self.initial_height is not None: + return self.initial_height + return self.height + + def recalc_width(self): + span_font = ImageFont.truetype(self.font_file, self.font_size) + size = span_font.getsize(self.txt) + self.width = size[0] + if self.width == 0: + # aproximate + size = span_font.getsize("x" * len(self.txt)) + self.width = size[0] + + def merge(self, it): + # precondition : both items must be mergable + if self.next_id == it.this_id: + self.txt += "\n" + it.txt + self.initial_height = self.get_initial_height() + self.height = it.pos_y + it.height - self.pos_y + self.width = max(self.width, it.width) + self.words = self.words + it.words + it.words = [] + it.txt = "" + elif self.prev_id == it.this_id: + it.txt += "\n" + self.txt + it.initial_height = it.get_initial_height() + it.height = self.pos_y + self.height - it.pos_y + it.width = max(self.width, it.width) + it.words = self.words + it.words + self.txt = "" + self.words = [] + else: + raise ValueError("Items " + str(self) + " and " + str(it) + " cannot be merged.") + + old_merged_list = self.merged_list.copy() + self.merged_list.append(it.this_id) + self.merged_list.extend(it.merged_list) + it.merged_list.append(self.this_id) + it.merged_list.extend(old_merged_list) + + def fix_overlapping_words(self): + # assertion: all words are ordered by x asceding + for i in range(len(self.words) - 1): + self.words[i].rect.x1 = min(self.words[i].rect.x1, self.words[i + 1].rect.x0 - 0.00001) + + def recalc_geometry(self): + self.pos_x = 9999999 + self.pos_y = 9999999 + x1 = -1 + y1 = -1 + for w in self.words: + self.pos_x = min(self.pos_x, w.rect.x0) + self.pos_y = min(self.pos_y, w.rect.y0) + x1 = max(x1, w.rect.x1) + y1 = max(y1, w.rect.y1) + self.width = x1 - self.pos_x + self.height = y1 - self.pos_y + + def rejoin_words(self): + self.txt = "" + for w in self.words: + if self.txt != "": + self.txt += " " + self.txt += w.txt + + def split(self, at_word, next_item_id): + # example "abc 123 def" -> split(1, 99) -> + # result "abc", and new item with item_id=99 "123 def" + new_item = HTMLItem() + new_item.line_num = self.line_num + new_item.tot_line_num = self.tot_line_num + # new_item.pos_x = self.pos_x + # new_item.pos_y = self.pos_y + # new_item.width = self.width + # new_item.height = self.height + new_item.font_size = self.font_size + 
new_item.words = self.words[at_word:] + # new_item.txt = self.txt + new_item.is_bold = self.is_bold + new_item.brightness = self.brightness + new_item.alignment = self.alignment + new_item.font_file = self.font_file + new_item.this_id = next_item_id + new_item.next_id = -1 + new_item.prev_id = -1 + new_item.category = self.category + new_item.temp_assignment = self.temp_assignment + new_item.merged_list = self.merged_list + new_item.space_width = self.space_width + new_item.has_been_split = True + + self.has_been_split = True + + new_item.left_id = self.this_id + new_item.right_id = self.right_id + new_item.page_num = self.page_num + self.right_id = new_item.this_id + + for k in range(at_word, len(self.words)): + self.words[k].item_id = next_item_id + + self.words = self.words[0:at_word] + self.recalc_geometry() + self.rejoin_words() + + new_item.recalc_geometry() + new_item.rejoin_words() + + return new_item + + # def unsplit(self, right): + # self.words = self.words + right.words + # self.right_id = right.right_id + # self.recalc_geometry() + # self.rejoin_words() + # # afterwards, right needs to be discarded + + @staticmethod + def concat_txt(item_list, sep=" "): + res = "" + for it in item_list: + if res != "": + res += sep + res += it.txt + return res + + def __repr__(self): + return ( + "" + ) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLPage.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLPage.py index eaa79b1..bd0614d 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLPage.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLPage.py @@ -14,239 +14,248 @@ from HTMLTable import * from HTMLCluster import * import copy - + class HTMLPage: - page_num = None - page_width = None - page_height = None - items = None - left_distrib = None #distribution of pos_x values (left alignments) - tables = None - paragraphs = None - clusters = None - clusters_text = None #clusters for traversing raw text - footnotes_idx = None - page_start_y0 = None - - def __init__(self): - self.page_num = 0 - self.page_width = 0 - self.page_height = 0 - self.items = [] - self.left_distrib = {} - self.tables = [] - self.paragraphs = [] - self.clusters = [] - self.clusters_text = [] - self.footnotes_idx = [] - self.page_start_y0 = [0] - - - # ===================================================================================================================== - # Utilities for Handling Multiple pages - # ===================================================================================================================== - - - @staticmethod - def merge(p0, p1): # merge two pages (p1 will be below p0) - def get_max_line_num(p): - res = 0 - for it in p.items: - if(it.line_num > res): - res = it.line_num - return res - - def get_max_id(p): - res = -1 - for it in p.items: - if(it.this_id > res): - res = it.this_id - return res - - def transform_id(id, offset): - if(id==-1): - return -1 - return id + offset - - p0c = copy.deepcopy(p0) #p0's copy - p1c = copy.deepcopy(p1) #p1's copy - #p0c will be the result - p0c.page_start_y0.append(p0c.page_height) - p0c.page_height += p1c.page_height - p0c.page_width = max(p0c.page_width, p1c.page_width) - p1c_min_line_num = get_max_line_num(p0c) + 1 - p1c_min_id = get_max_id(p0c) + 1 - p0c_num_items = len(p0c.items) - - for it in p1c.items: - it.line_num += p1c_min_line_num - it.tot_line_num = p0c.page_num * 10000 + it.line_num - it.pos_y += p1c.page_height - it.this_id = 
transform_id(it.this_id, p1c_min_id) - it.next_id = transform_id(it.next_id, p1c_min_id) - it.prev_id = transform_id(it.prev_id, p1c_min_id) - it.left_id = transform_id(it.left_id, p1c_min_id) - it.right_id = transform_id(it.right_id, p1c_min_id) - it.merged_list = [transform_id(id, p1c_min_id) for id in it.merged_list] - for w in it.words: - w.item_id = transform_id(w.item_id, p1c_min_id) - w.rect.y0 += p0c.page_height - w.rect.y1 += p0c.page_height - #print(it) - p0c.items.append(it) - - for ky in p1c.left_distrib: - p0c.left_distrib[ky] = p0c.left_distrib.get(ky, 0) + p1c.left_distrib[ky] - - for t in p1c.tables: - t.recalc_geometry() - - p0c.tables.extend(p1c.tables) - p0c.find_paragraphs() - - for idx in p1c.footnotes_idx: - p0c.footnotes_idx.append(idx + p0c_num_items) - - p0c.generate_clusters() - - return p0c - - - - - # ===================================================================================================================== - # Utilities for Identifying Basic Structures - # ===================================================================================================================== - - # transform coordinates such that they are in [0,1) range on the current page - def transform_coords(self, x, y): - #print("TTTT") - #print(str(x) + "," + str(y) + "-> " + str(self.page_width) + ", " + str(self.page_start_y0) + ", " + str(self.page_height)) - res_x = x / self.page_width - res_y = 1 - for i in range(len(self.page_start_y0)): - start_y = self.page_start_y0[i] - end_y = self.page_start_y0[i+1] if i < len(self.page_start_y0)-1 else self.page_height - res_y = (y-start_y)/(end_y-start_y) - if(res_y<1): - break - return res_x, res_y - - def find_idx_of_item_by_txt(self, txt): - res = -1 - for i in range(len(self.items)): - if(self.items[i].txt == txt): - if(res != -1): - raise ValueError('Text "'+txt+'" occurs more than once in this page.') - res = i - return res - - def detect_split_items(self): - - def find_aligned_words_in_direction(all_words, k0, use_alignment, dir): - SPLIT_DETECTION_THRESHOLD = 1.0 / 609.0 - - threshold = SPLIT_DETECTION_THRESHOLD * self.page_width - - item_id0 = all_words[k0][0] - word_id0 = all_words[k0][1] - x0_0 = self.items[item_id0].words[word_id0].rect.x0 - x1_0 = self.items[item_id0].words[word_id0].rect.x1 - - k_max = len(all_words) - k = k0 + dir - res = [] - is_numeric = False - while(0<= k and k < k_max): - item_id = all_words[k][0] - word_id = all_words[k][1] - x0 = self.items[item_id].words[word_id].rect.x0 - x1 = self.items[item_id].words[word_id].rect.x1 - - if(use_alignment == ALIGN_LEFT and abs(x0-x0_0) < threshold): - # we found a left-aligned word - res.append(k) - is_numeric = is_numeric or Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id].txt) - elif(use_alignment == ALIGN_RIGHT and abs(x1-x1_0) < threshold): - # we found a right-aligned word - res.append(k) - is_numeric = is_numeric or Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id].txt) - #elif(x0 < x1_0 and x1 > x0_0): - elif((x0 < x0_0 and x1 > x0_0 and use_alignment == ALIGN_LEFT) or (x0 < x1_0 and x1 > x1_0 and use_alignment == ALIGN_RIGHT)): - #print('BAD:'+self.items[item_id].words[word_id].txt) - #print(x0, x0_0, threshold) - break # another word is in the way, but not correctly aligned - - k += dir - return res, is_numeric - - def find_space_in_direction(all_words, k0, dir): - item_id = all_words[k0][0] - word_id = all_words[k0][1] - num_words = len(self.items[item_id].words) - if(dir==-1): - return 
self.items[item_id].words[word_id].rect.x0 - self.items[item_id].words[word_id-1].rect.x1 if word_id > 0 else 9999999 - elif(dir==1): - return self.items[item_id].words[word_id+1].rect.x0 - self.items[item_id].words[word_id].rect.x1 if word_id < num_words-1 else 9999999 - raise ValueError('Invalid dir') - return -1 # should never happen - - - # prepare sorted order of all words - all_words = [] - - for i in range(len(self.items)): - for j in range(len(self.items[i].words)): - all_words.append((i, j, False)) - - all_words = sorted(all_words, key=lambda ij: self.items[ij[0]].words[ij[1]].rect.y0) - - do_split_l = [] - do_split_r = [] - - - # find words that can be split - for k in range(len(all_words)): - item_id = all_words[k][0] - word_id = all_words[k][1] - num_words = len(self.items[item_id].words) - - isnum = Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id].txt) - isnum_l = Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id-1].txt) if word_id > 0 else False - isnum_r = Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id+1].txt) if word_id < num_words-1 else False - - # left-aligned words: - #print("L") - left_aligned_words_down, isnum_ld = find_aligned_words_in_direction(all_words, k, ALIGN_LEFT, 1) - left_aligned_words_up, isnum_lu = find_aligned_words_in_direction(all_words, k, ALIGN_LEFT, -1) - left_aligned_words = left_aligned_words_down + left_aligned_words_up - - - # right-aligned words: - #print("R") - right_aligned_words_down, isnum_rd = find_aligned_words_in_direction(all_words, k, ALIGN_RIGHT, 1) - right_aligned_words_up, isnum_ru = find_aligned_words_in_direction(all_words, k, ALIGN_RIGHT, -1) - right_aligned_words = right_aligned_words_down + right_aligned_words_up - - - # It doesnt make sense to split for CENTER aligned words # New-27.06.2022 - - isnum_l = isnum_l or isnum_ld or isnum_lu - isnum_r = isnum_r or isnum_rd or isnum_ru + page_num = None + page_width = None + page_height = None + items = None + left_distrib = None # distribution of pos_x values (left alignments) + tables = None + paragraphs = None + clusters = None + clusters_text = None # clusters for traversing raw text + footnotes_idx = None + page_start_y0 = None - - threshold_rows_l = 1 if isnum and isnum_l else 2 - threshold_rows_r = 1 if isnum and isnum_r else 2 - - threshold_space_l = (1.2 if isnum and isnum_l else 1.5) * self.items[item_id].space_width - threshold_space_r = (1.2 if isnum and isnum_r else 1.5) * self.items[item_id].space_width - - # space between this and previous/next word - space_width_l = find_space_in_direction(all_words, k, -1) - space_width_r = find_space_in_direction(all_words, k, 1) - - """ + def __init__(self): + self.page_num = 0 + self.page_width = 0 + self.page_height = 0 + self.items = [] + self.left_distrib = {} + self.tables = [] + self.paragraphs = [] + self.clusters = [] + self.clusters_text = [] + self.footnotes_idx = [] + self.page_start_y0 = [0] + + # ===================================================================================================================== + # Utilities for Handling Multiple pages + # ===================================================================================================================== + + @staticmethod + def merge(p0, p1): # merge two pages (p1 will be below p0) + def get_max_line_num(p): + res = 0 + for it in p.items: + if it.line_num > res: + res = it.line_num + return res + + def get_max_id(p): + res = -1 + for it in p.items: + if it.this_id > res: + res = it.this_id 
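            # res is now the largest this_id on the page; merge() adds
            # get_max_id(p0c) + 1 to every id taken from the second page, so the
            # ids stay unique after the two pages are stacked.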
+ return res + + def transform_id(id, offset): + if id == -1: + return -1 + return id + offset + + p0c = copy.deepcopy(p0) # p0's copy + p1c = copy.deepcopy(p1) # p1's copy + # p0c will be the result + p0c.page_start_y0.append(p0c.page_height) + p0c.page_height += p1c.page_height + p0c.page_width = max(p0c.page_width, p1c.page_width) + p1c_min_line_num = get_max_line_num(p0c) + 1 + p1c_min_id = get_max_id(p0c) + 1 + p0c_num_items = len(p0c.items) + + for it in p1c.items: + it.line_num += p1c_min_line_num + it.tot_line_num = p0c.page_num * 10000 + it.line_num + it.pos_y += p1c.page_height + it.this_id = transform_id(it.this_id, p1c_min_id) + it.next_id = transform_id(it.next_id, p1c_min_id) + it.prev_id = transform_id(it.prev_id, p1c_min_id) + it.left_id = transform_id(it.left_id, p1c_min_id) + it.right_id = transform_id(it.right_id, p1c_min_id) + it.merged_list = [transform_id(id, p1c_min_id) for id in it.merged_list] + for w in it.words: + w.item_id = transform_id(w.item_id, p1c_min_id) + w.rect.y0 += p0c.page_height + w.rect.y1 += p0c.page_height + # print(it) + p0c.items.append(it) + + for ky in p1c.left_distrib: + p0c.left_distrib[ky] = p0c.left_distrib.get(ky, 0) + p1c.left_distrib[ky] + + for t in p1c.tables: + t.recalc_geometry() + + p0c.tables.extend(p1c.tables) + p0c.find_paragraphs() + + for idx in p1c.footnotes_idx: + p0c.footnotes_idx.append(idx + p0c_num_items) + + p0c.generate_clusters() + + return p0c + + # ===================================================================================================================== + # Utilities for Identifying Basic Structures + # ===================================================================================================================== + + # transform coordinates such that they are in [0,1) range on the current page + def transform_coords(self, x, y): + # print("TTTT") + # print(str(x) + "," + str(y) + "-> " + str(self.page_width) + ", " + str(self.page_start_y0) + ", " + str(self.page_height)) + res_x = x / self.page_width + res_y = 1 + for i in range(len(self.page_start_y0)): + start_y = self.page_start_y0[i] + end_y = self.page_start_y0[i + 1] if i < len(self.page_start_y0) - 1 else self.page_height + res_y = (y - start_y) / (end_y - start_y) + if res_y < 1: + break + return res_x, res_y + + def find_idx_of_item_by_txt(self, txt): + res = -1 + for i in range(len(self.items)): + if self.items[i].txt == txt: + if res != -1: + raise ValueError('Text "' + txt + '" occurs more than once in this page.') + res = i + return res + + def detect_split_items(self): + def find_aligned_words_in_direction(all_words, k0, use_alignment, dir): + SPLIT_DETECTION_THRESHOLD = 1.0 / 609.0 + + threshold = SPLIT_DETECTION_THRESHOLD * self.page_width + + item_id0 = all_words[k0][0] + word_id0 = all_words[k0][1] + x0_0 = self.items[item_id0].words[word_id0].rect.x0 + x1_0 = self.items[item_id0].words[word_id0].rect.x1 + + k_max = len(all_words) + k = k0 + dir + res = [] + is_numeric = False + while 0 <= k and k < k_max: + item_id = all_words[k][0] + word_id = all_words[k][1] + x0 = self.items[item_id].words[word_id].rect.x0 + x1 = self.items[item_id].words[word_id].rect.x1 + + if use_alignment == ALIGN_LEFT and abs(x0 - x0_0) < threshold: + # we found a left-aligned word + res.append(k) + is_numeric = is_numeric or Format_Analyzer.looks_weak_numeric( + self.items[item_id].words[word_id].txt + ) + elif use_alignment == ALIGN_RIGHT and abs(x1 - x1_0) < threshold: + # we found a right-aligned word + res.append(k) + is_numeric = is_numeric or 
Format_Analyzer.looks_weak_numeric( + self.items[item_id].words[word_id].txt + ) + # elif(x0 < x1_0 and x1 > x0_0): + elif (x0 < x0_0 and x1 > x0_0 and use_alignment == ALIGN_LEFT) or ( + x0 < x1_0 and x1 > x1_0 and use_alignment == ALIGN_RIGHT + ): + # print('BAD:'+self.items[item_id].words[word_id].txt) + # print(x0, x0_0, threshold) + break # another word is in the way, but not correctly aligned + + k += dir + return res, is_numeric + + def find_space_in_direction(all_words, k0, dir): + item_id = all_words[k0][0] + word_id = all_words[k0][1] + num_words = len(self.items[item_id].words) + if dir == -1: + return ( + self.items[item_id].words[word_id].rect.x0 - self.items[item_id].words[word_id - 1].rect.x1 + if word_id > 0 + else 9999999 + ) + elif dir == 1: + return ( + self.items[item_id].words[word_id + 1].rect.x0 - self.items[item_id].words[word_id].rect.x1 + if word_id < num_words - 1 + else 9999999 + ) + raise ValueError("Invalid dir") + return -1 # should never happen + + # prepare sorted order of all words + all_words = [] + + for i in range(len(self.items)): + for j in range(len(self.items[i].words)): + all_words.append((i, j, False)) + + all_words = sorted(all_words, key=lambda ij: self.items[ij[0]].words[ij[1]].rect.y0) + + do_split_l = [] + do_split_r = [] + + # find words that can be split + for k in range(len(all_words)): + item_id = all_words[k][0] + word_id = all_words[k][1] + num_words = len(self.items[item_id].words) + + isnum = Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id].txt) + isnum_l = ( + Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id - 1].txt) if word_id > 0 else False + ) + isnum_r = ( + Format_Analyzer.looks_weak_numeric(self.items[item_id].words[word_id + 1].txt) + if word_id < num_words - 1 + else False + ) + + # left-aligned words: + # print("L") + left_aligned_words_down, isnum_ld = find_aligned_words_in_direction(all_words, k, ALIGN_LEFT, 1) + left_aligned_words_up, isnum_lu = find_aligned_words_in_direction(all_words, k, ALIGN_LEFT, -1) + left_aligned_words = left_aligned_words_down + left_aligned_words_up + + # right-aligned words: + # print("R") + right_aligned_words_down, isnum_rd = find_aligned_words_in_direction(all_words, k, ALIGN_RIGHT, 1) + right_aligned_words_up, isnum_ru = find_aligned_words_in_direction(all_words, k, ALIGN_RIGHT, -1) + right_aligned_words = right_aligned_words_down + right_aligned_words_up + + # It doesnt make sense to split for CENTER aligned words # New-27.06.2022 + + isnum_l = isnum_l or isnum_ld or isnum_lu + isnum_r = isnum_r or isnum_rd or isnum_ru + + threshold_rows_l = 1 if isnum and isnum_l else 2 + threshold_rows_r = 1 if isnum and isnum_r else 2 + + threshold_space_l = (1.2 if isnum and isnum_l else 1.5) * self.items[item_id].space_width + threshold_space_r = (1.2 if isnum and isnum_r else 1.5) * self.items[item_id].space_width + + # space between this and previous/next word + space_width_l = find_space_in_direction(all_words, k, -1) + space_width_r = find_space_in_direction(all_words, k, 1) + + """ space_width_l_min = space_width_l for n in left_aligned_words: space_width_l_min = min(space_width_l_min, find_space_in_direction(all_words, n, -1)) @@ -256,993 +265,1016 @@ def find_space_in_direction(all_words, k0, dir): space_width_r_min = min(space_width_r_min, find_space_in_direction(all_words, n, 1)) """ - #if(self.items[item_id].words[word_id].txt=='2019-2q'): - # print(self.items[item_id].words[word_id].txt) - # print(len(left_aligned_words) , threshold_rows_l , word_id 
, num_words, space_width_l_min , threshold_space_l) - # print(len(right_aligned_words) , threshold_rows_r , word_id , num_words, space_width_r_min , threshold_space_r) - - - #if(len(left_aligned_words) > threshold_rows_l and word_id > 0 and space_width_l_min > threshold_space_l): - if(len(left_aligned_words) > threshold_rows_l and word_id > 0 and space_width_l > threshold_space_l): - #print(self.items[item_id].words[word_id].txt) - do_split_l.append(k) - #do_split.extend(left_aligned_words) - - #if(len(right_aligned_words) > threshold_rows_r and word_id < num_words-1 and space_width_r_min > threshold_space_r): - if(len(right_aligned_words) > threshold_rows_r and word_id < num_words-1 and space_width_r > threshold_space_r): - #print(self.items[item_id].words[word_id].txt) - #print('!!!') - do_split_r.append(k) - #do_split.extend(right_aligned_words) - - - if(space_width_l > 3 * self.items[item_id].space_width and word_id > 0 ): - # two words are are too far apart => split them in any case - do_split_l.append(k) - do_split_l.extend(left_aligned_words) - - - if(space_width_r > 3 * self.items[item_id].space_width and word_id < num_words-1 ): - # two words are are too far apart => split them in any case - do_split_r.append(k) - do_split_r.extend(right_aligned_words) - #print('!!!!!!!') - - - #print('???') + # if(self.items[item_id].words[word_id].txt=='2019-2q'): + # print(self.items[item_id].words[word_id].txt) + # print(len(left_aligned_words) , threshold_rows_l , word_id , num_words, space_width_l_min , threshold_space_l) + # print(len(right_aligned_words) , threshold_rows_r , word_id , num_words, space_width_r_min , threshold_space_r) - + # if(len(left_aligned_words) > threshold_rows_l and word_id > 0 and space_width_l_min > threshold_space_l): + if len(left_aligned_words) > threshold_rows_l and word_id > 0 and space_width_l > threshold_space_l: + # print(self.items[item_id].words[word_id].txt) + do_split_l.append(k) + # do_split.extend(left_aligned_words) - # do splitting - for k in do_split_l: - if(all_words[k][1] > 0): - all_words[k] = (all_words[k][0], all_words[k][1], True) + # if(len(right_aligned_words) > threshold_rows_r and word_id < num_words-1 and space_width_r_min > threshold_space_r): + if ( + len(right_aligned_words) > threshold_rows_r + and word_id < num_words - 1 + and space_width_r > threshold_space_r + ): + # print(self.items[item_id].words[word_id].txt) + # print('!!!') + do_split_r.append(k) + # do_split.extend(right_aligned_words) - word_map = {} - for k in range(len(all_words)): - word_map[(all_words[k][0],all_words[k][1])] = k - - - for k in do_split_r: - item_id = all_words[k][0] - word_id = all_words[k][1] - num_words = len(self.items[item_id].words) - if(word_id < num_words-1): - kr = word_map[(item_id, word_id+1)] - all_words[kr] = (all_words[kr][0], all_words[kr][1], True) - - all_words = sorted(all_words, key=lambda ij: -ij[1]) #split always beginning from the end - - next_id = len(self.items) - - for ij in all_words: - if(not ij[2]): - continue # do not split this one - if(self.items[ij[0]].words[ij[1]].rect.x0 - self.items[ij[0]].words[ij[1]-1].rect.x1 < self.items[item_id].space_width * 0): - continue # words are too close to split - - if(self.items[ij[0]].words[ij[1]].rect.x0 - self.items[ij[0]].words[ij[1]-1].rect.x1 < self.items[item_id].space_width * 1.5 and \ - not Format_Analyzer.looks_weak_numeric(self.items[ij[0]].words[ij[1]].txt) and \ - not Format_Analyzer.looks_weak_numeric(self.items[ij[0]].words[ij[1]-1].txt)): - continue # words are too close 
to split - - - - print_verbose(3, '---> Split item '+str(self.items[ij[0]]) + ' at word ' + \ - str(ij[1]) + '(x1='+str(self.items[ij[0]].words[ij[1]-1].rect.x1)+'<-> x0='+str(self.items[ij[0]].words[ij[1]].rect.x0)+ \ - ' , space_width= '+str(self.items[item_id].space_width)) - new_item = self.items[ij[0]].split(ij[1], next_id) - self.items.append(new_item) - print_verbose(3, '------> Result = "'+str(self.items[ij[0]].txt) + '" + "' + str(new_item.txt) + '"') - - next_id += 1 - - - - - #raise ValueError('XXX') - - - - def get_txt_unsplit(self, idx): - if(self.items[idx].right_id==-1): - return self.items[idx].txt - return self.items[idx].txt + " " + self.get_txt_unsplit(self.items[idx].right_id) - - - - - def find_left_distributions(self): - self.left_distrib = {} - for it in self.items: - cur_x = it.pos_x - self.left_distrib[cur_x] = self.left_distrib.get(cur_x, 0) + 1 - print_verbose(5, 'Left distrib: ' + str(self.left_distrib)) - - def find_paragraphs(self): - self.paragraphs = [] - - distrib = {} - for it in self.items: - if(it.category == CAT_RUNNING_TEXT or it.category == CAT_HEADLINE): - cur_x = it.pos_x - distrib[cur_x] = distrib.get(cur_x, 0) + 1 - - for pos_x, frequency in distrib.items(): - if(frequency > 5): - self.paragraphs.append(pos_x) - - self.paragraphs.sort() - - - def find_items_within_rect_all_categories(self, rect): # returns list of indices - res = [] - for i in range(len(self.items)): - if(Rect.calc_intersection_area(self.items[i].get_rect(), rect) > self.items[i].get_rect().get_area() * 0.3): #0.5? - res.append(i) - return res - - def find_items_within_rect(self, rect, categories): # returns list of indices - res = [] - for i in range(len(self.items)): - if(self.items[i].category in categories): - if(Rect.calc_intersection_area(self.items[i].get_rect(), rect) > self.items[i].get_rect().get_area() * 0.3): #0.5? 
- res.append(i) - return res - - def explode_item(self, idx, sep = ' '): #return concatenated txt - def expl_int(dir, idx, sep): - return ('' if self.items[idx].left_id == -1 or dir > 0 else expl_int(-1, self.items[idx].left_id, sep) + sep) + \ - self.items[idx].txt + \ - ('' if self.items[idx].right_id == -1 or dir < 0 else sep + expl_int(1, self.items[idx].right_id, sep) ) - return expl_int(0, idx, sep) - - - def explode_item_by_idx(self, idx): #return list of idx - def expl_int(dir, idx): - return ([] if self.items[idx].left_id == -1 or dir > 0 else expl_int(-1, self.items[idx].left_id)) + \ - [idx]+ \ - ([] if self.items[idx].right_id == -1 or dir < 0 else expl_int(1, self.items[idx].right_id) ) - return expl_int(0, idx) - - def find_vertical_aligned_items(self, item, alignment, threshold, do_print=False): #TODO: remove do_print - #returns indices of affected items - res = [] - - if(alignment != ALIGN_DEFAULT): - this_align = alignment - else: - this_align = item.alignment - - pos_x = item.pos_x - pos_y = item.pos_y - threshold_px = threshold * self.page_width - - score = 0.0 - - if(this_align == ALIGN_RIGHT): - pos_x += item.width - # New-27.06.2022: - if(this_align == ALIGN_CENTER): - pos_x += item.width*0.5 + if space_width_l > 3 * self.items[item_id].space_width and word_id > 0: + # two words are are too far apart => split them in any case + do_split_l.append(k) + do_split_l.extend(left_aligned_words) - - for i in range(len(self.items)): - if(alignment != ALIGN_DEFAULT): - cur_align = alignment - else: - cur_align = self.items[i].alignment - cur_x = self.items[i].pos_x - cur_y = self.items[i].pos_y - - if(cur_align == ALIGN_RIGHT): - cur_x += self.items[i].width - if(cur_align == ALIGN_CENTER): - cur_x += self.items[i].width*0.5 - - - delta = abs(cur_x-pos_x) - if(do_print): - print_verbose(7, '---> delta for '+str(self.items[i])+' to '+str(pos_x)+' is '+str(delta)) - if(delta <= threshold_px): - cur_score = ((threshold_px - delta)/self.page_width) * ( ((1.0 - abs(cur_y - pos_y) / self.page_height)) ** 5.0) - if(cur_score<0.003): - cur_score=0 - print_verbose(9, "VALIGN->"+str(self.items[i])+" has SCORE: "+str(cur_score)) - score += cur_score - res.append(i) - - return res, score - - - - - def find_horizontal_aligned_items(self, item): - #returns indices of affected items - res = [] - - y0 = item.pos_y - y1 = item.pos_y + item.height - - for i in range(len(self.items)): - it = self.items[i] - if(it.pos_y < y1 and it.pos_y + it.height > y0): - res.append(i) - - return res - - - def clear_all_temp_assignments(self): - for it in self.items: - it.temp_assignment = 0 - - def guess_all_alignments(self): - for it in self.items: - dummy, score_left = self.find_vertical_aligned_items(it, ALIGN_LEFT, DEFAULT_VTHRESHOLD) - dummy, score_right = self.find_vertical_aligned_items(it, ALIGN_RIGHT, DEFAULT_VTHRESHOLD) - dummy, score_center = self.find_vertical_aligned_items(it, ALIGN_CENTER, DEFAULT_VTHRESHOLD) # New-27.06.2022 - #it.alignment = ALIGN_LEFT if score_left >= score_right else ALIGN_RIGHT - if(score_left >= score_right and score_left >= score_center): - it.alignment = ALIGN_LEFT - elif(score_right >= score_left and score_right >= score_center): - it.alignment = ALIGN_RIGHT - else: - it.alignment = ALIGN_CENTER - - - def find_next_nonclassified_item(self): - for it in self.items: - if(not it.has_category()): - return it - return None - - - def identify_connected_txt_lines(self): - - def insert_next_id(cur_id, next_id, items): - if(items[next_id].pos_y <= items[cur_id].pos_y): - raise 
ValueError('Invalid item order:' + str(items[cur_id]) + ' --> ' +str(items[next_id])) - - - if(items[cur_id].next_id==-1): - items[cur_id].next_id = next_id - elif(items[cur_id].next_id==next_id): - return - else: - old_next_id = items[cur_id].next_id - if(items[next_id].pos_y < items[old_next_id].pos_y): - items[cur_id].next_id = next_id - insert_next_id(next_id, old_next_id, items) - elif(items[next_id].pos_y < items[old_next_id].pos_y): - insert_next_id(old_next_id, next_id, items) - else: - # sometimes this can happen, but then we dont update anything in order to avoid cycles - pass - - - threshold = int(0.03 * self.page_width + 0.5) # allow max 3% deviation to the left - cur_threshold = 0 - - for cur_x, cnt in self.left_distrib.items(): - if(cnt < 2): - # if we have less than 2 lines, we skip this "column" - continue - - cur_lines = {} # for each line in this column, we store its y position - last_pos_y = -1 - for i in range(len(self.items)): - if(self.items[i].pos_x >= cur_x and self.items[i].pos_x <= cur_x + cur_threshold and self.items[i].pos_y > last_pos_y): - cur_lines[i] = self.items[i].pos_y - last_pos_y = self.items[i].pos_y + self.items[i].height * 0.9 - - cur_lines = sorted(cur_lines.items(), key=lambda kv: kv[1]) - - # get row_spacing - row_spacings = [] - for i in range(len(cur_lines)-1): - cur_item_id, cur_y = cur_lines[i] - next_item_id, next_y = cur_lines[i+1] - cur_item = self.items[cur_item_id] - cur_spacing = next_y - (cur_y + cur_item.height) - row_spacings.append(cur_spacing) - - - if(len(row_spacings) == 0): - continue - - max_allowed_spacing = statistics.median(row_spacings) * 1.1 + if space_width_r > 3 * self.items[item_id].space_width and word_id < num_words - 1: + # two words are are too far apart => split them in any case + do_split_r.append(k) + do_split_r.extend(right_aligned_words) + # print('!!!!!!!') - - for i in range(len(cur_lines)-1): - cur_item_id, cur_y = cur_lines[i] - cur_item = self.items[cur_item_id] - - next_item_id, next_y = cur_lines[i+1] - next_item = self.items[next_item_id] - if((next_y > cur_y + min(cur_item.height + max_allowed_spacing, 2 * cur_item.height)) or # too far apart - (cur_item.font_size != next_item.font_size) or # different font sizes - (cur_item.font_file != next_item.font_file)): # different fonts faces - cur_threshold = 0 - continue - - cur_threshold = threshold - insert_next_id(cur_item_id, next_item_id, self.items) - #self.items[cur_item_id].next_id = next_item_id - #self.items[next_item_id].prev_id = cur_item_id - - # update all prev_ids - for i in range(len(self.items)): - if(self.items[i].next_id != -1): - self.items[self.items[i].next_id].prev_id = i - - - def mark_regular_text(self): - # mark connected text components - for it in self.items: - if(it.category != CAT_DEFAULT): - continue # already taken - if(it.prev_id != -1): - continue #has previous item => we look at that - - txt = Format_Analyzer.trim_whitespaces(it.txt) - - next = it.next_id - while(next!=-1): - #print(self.items[next], self.items[next].next_id) - txt += ' ' + Format_Analyzer.trim_whitespaces(self.items[next].txt) - next = self.items[next].next_id - - if(Format_Analyzer.looks_running_text(txt)): - it.category = CAT_RUNNING_TEXT - next = it.next_id - while(next!=-1): - self.items[next].category = CAT_RUNNING_TEXT - next = self.items[next].next_id - - def mark_other_text_components(self): - - threshold = int(0.03 * self.page_width + 0.5) # allow max 3% deviation to the left - - for cur_x, cnt in self.left_distrib.items(): - - cur_lines = {} # for 
each line in this column, we store its y position - cur_threshold = 0 - for i in range(len(self.items)): - if(self.items[i].pos_x >= cur_x and self.items[i].pos_x <= cur_x + cur_threshold): - cur_threshold = threshold - cur_lines[i] = self.items[i].pos_y - - cur_lines = sorted(cur_lines.items(), key=lambda kv: kv[1]) - - for i in range(len(cur_lines)): - cur_item_id, cur_y = cur_lines[i] - if(self.items[cur_item_id].category != CAT_DEFAULT): - continue # already taken - prev_item_id = -1 - prev_y = -1 - next_item_id = -1 - next_y = -1 - if(i>0): - prev_item_id, prev_y = cur_lines[i-1] - if(i self.items[next_item_id].height): - self.items[cur_item_id].category = CAT_HEADLINE - else: - print_verbose(10, "---->>> found CAT_OTHER_TEXT/1 for item " + str(cur_item_id)) - self.items[cur_item_id].category = CAT_OTHER_TEXT - - # single (head-)lines at the beginning - if((prev_item_id==-1 or self.items[cur_item_id].prev_id == -1) and next_item_id!=-1): - y_threshold = 2*max(self.items[cur_item_id].height, self.items[next_item_id].height) - - if(self.items[next_item_id].category == CAT_RUNNING_TEXT and - abs(cur_y-next_y) < y_threshold): - if(self.items[cur_item_id].height > self.items[next_item_id].height): - self.items[cur_item_id].category = CAT_HEADLINE - else: - print_verbose(10, "---->>> found CAT_OTHER_TEXT/2 for item " + str(cur_item_id)) - self.items[cur_item_id].category = CAT_OTHER_TEXT - - - # multiple rows spanning headlines at the beginning - for i in range(len(cur_lines)-1): - cur_item_id, cur_y = cur_lines[i] - if(self.items[cur_item_id].category != CAT_DEFAULT): - continue # already taken - - if(self.items[cur_item_id].next_id != -1 or self.items[cur_item_id].prev_id == -1): - continue # we are only interested at items that mark end of a block - - if(not Format_Analyzer.looks_words(self.items[cur_item_id].txt)): - continue # only text - - print_verbose(9, "--> mark_other_text_components \ multi-rows headline: "+str(self.items[cur_item_id])) - next_item_id, next_y = cur_lines[i+1] - - if(self.items[next_item_id].category != CAT_RUNNING_TEXT): - continue # only when followed by normal paragraph - - if(self.items[cur_item_id].font_size <= self.items[next_item_id].font_size and - not (self.items[cur_item_id].font_size == self.items[next_item_id].font_size and - self.items[cur_item_id].is_bold and not self.items[next_item_id].is_bold)): - continue # header must of greater font size, or: same font size but greater boldness (i.e, headline bold, but following text not) - - y_threshold = 2*max(self.items[cur_item_id].height, self.items[next_item_id].height) - - print_verbose(9, "----> cur_y, next_y , y_threshold = " + str(cur_y) + ","+str(next_y)+","+str(y_threshold)) - - - if(abs(cur_y - next_y) < y_threshold and self.items[cur_item_id].height > self.items[next_item_id].height): - # count number of affected lines - iter_item_id = cur_item_id - num_affected = 1 - while(iter_item_id != -1): - iter_item_id = self.items[iter_item_id].prev_id - num_affected += 1 - if(num_affected <= 3): # more than 3 lines would be too much for a headline - # match! 
- print_verbose(9, "------->> MATCH!") - iter_item_id = cur_item_id - while(iter_item_id != -1): - self.items[iter_item_id].category = CAT_HEADLINE - iter_item_id = self.items[iter_item_id].prev_id - + # print('???') - - # Page number and footer - pgnum_threshold = 0.9 * self.page_height - pgnum_id = -1 - pgnum_pos_y = 0 - - for i in range(len(self.items)): - if(self.items[i].pos_y < pgnum_threshold): - continue - if(Format_Analyzer.looks_pagenum(self.items[i].txt)): - cur_y = self.items[i].pos_y - if(cur_y > pgnum_pos_y): - pgnum_pos_y = cur_y - pgnum_id = i - - if(pgnum_id != -1): - print_verbose(10, "---->>> found CAT_FOOTER/3 for item " + str(pgnum_id)) - self.items[pgnum_id].category = CAT_FOOTER - for it in self.items: - if(it.pos_y == pgnum_pos_y): - print_verbose(10, "---->>> found CAT_FOOTER/4 for item " + str(it.this_id)) - it.category = CAT_FOOTER #footer - - - # Isolated items - iso_threshold = 0.05 * self.page_height - - for it in self.items: - if(it.category != CAT_DEFAULT): - continue #already taken - - min_dist = 99999999 - - for jt in self.items: - if(it==jt or jt.category==CAT_RUNNING_TEXT or jt.category==CAT_HEADLINE or jt.category==CAT_OTHER_TEXT or jt.category==CAT_FOOTER): - continue # we do not consider these ones (=> text is assumed to be isolated, even if any of these ones are near) - cur_dist = Rect.raw_rect_distance(it.pos_x, it.pos_y, it.pos_x+it.width, it.pos_y+it.height, jt.pos_x, jt.pos_y, jt.pos_x+jt.width, jt.pos_y+jt.height) - if(cur_dist iso_threshold): - print_verbose(10, "---->>> found CAT_OTHER_TEXT/5 for item " + str(it.this_id)) - it.category = CAT_OTHER_TEXT - - - - # ===================================================================================================================== - # Utilities for Table Extraction - # ===================================================================================================================== - - def find_vnearest_taken_components(self, initial_item): #search vertically, return nearest top and bottom component (if any) - FONT_RECT_TOLERANCE = 0.9 #be less restrictive, since the font rect's width is not always accurate - left_min = initial_item.pos_x - left_max = initial_item.pos_x + initial_item.width * FONT_RECT_TOLERANCE - - top = None - bottom = None - - for it in self.items: - if(it.has_category_besides(CAT_FOOTER)): - if(it.pos_x < left_max and it.pos_x + it.width * FONT_RECT_TOLERANCE > left_min): - if(it.pos_y <= initial_item.pos_y and (top is None or it.pos_y > top.pos_y)): - top = it - if(it.pos_y >= initial_item.pos_y and (bottom is None or it.pos_y < bottom.pos_y)): - bottom = it - - return top, bottom - - - def sort_out_non_vconnected_items(self, vitems, top_y, bottom_y): - res = [] - for i in vitems: - if(self.items[i].pos_y > top_y and self.items[i].pos_y < bottom_y): - res.append(i) - - return res - - - def sort_out_taken_items(self, any_items): - res = [] - for i in any_items: - if(not self.items[i].has_category_besides(CAT_FOOTER) and self.items[i].temp_assignment == 0): - res.append(i) - - return res - - - def sort_out_non_vertical_aligned_items_by_bounding_box(self, initial_item, vitems): - res = [] - - x0 = initial_item.pos_x - x1 = initial_item.pos_x + initial_item.width - - - - for i in vitems: - cur_x0 = self.items[i].pos_x - cur_x1 = self.items[i].pos_x + self.items[i].width - - if(cur_x0 < x1 and cur_x1 > x0 ): - res.append(i) - - return res - - - - - def sort_out_items_in_same_row(self, initial_item, vitems): - orig_pos_x = initial_item.get_aligned_pos_x() - res = [] - for i in 
vitems: - it_delta = abs(self.items[i].get_aligned_pos_x() - orig_pos_x) - better_item_in_same_row = False - for j in vitems: - if(i==j): - continue - if(self.items[j].pos_y + self.items[j].height >= self.items[i].pos_y and self.items[j].pos_y <= self.items[i].pos_y + self.items[i].height): - jt_delta = abs(self.items[j].get_aligned_pos_x() - orig_pos_x) - if(jt_delta < it_delta): - better_item_in_same_row = True - break - if(not better_item_in_same_row): - res.append(i) - - return res - - def sort_out_non_connected_row_items(self, hitems, initial_item): - # in the same row, we sort out all items, that are not connected with initial_item - # 2 items are connected, iff between them, there are no taken items - # initial_item will always be included - min_x = -1 - max_x = 9999999 - for i in hitems: - if(self.items[i].has_category_besides(CAT_FOOTER) or self.items[i].temp_assignment != 0): - cur_x = self.items[i].pos_x - if(cur_x < initial_item.pos_x): - min_x = max(min_x, cur_x) - if(cur_x > initial_item.pos_x): - max_x = min(max_x, cur_x) - - res = [] - for i in hitems: - if(self.items[i].pos_x > min_x and self.items[i].pos_x < max_x): - res.append(i) - - return res - - - - - - def discover_table_column(self, initial_item): - print_verbose(7, 'discover_table_column for item : ' + str(initial_item)) - - vitems, dummy = self.find_vertical_aligned_items(initial_item, ALIGN_DEFAULT, DEFAULT_VTHRESHOLD) - print_verbose(9, '---> 1. V-Items ') - print_subset(9, self.items, vitems) - vitems = self.sort_out_taken_items(vitems) - print_verbose(9, '---> 2. V-Items ') - print_subset(9, self.items, vitems) - vitems = self.sort_out_non_vertical_aligned_items_by_bounding_box(initial_item, vitems) - print_verbose(9, '---> 3. V-Items ') - print_subset(9, self.items, vitems) - vitems = self.sort_out_items_in_same_row(initial_item, vitems) - print_verbose(9, '---> 4. V-Items ') - print_subset(9, self.items, vitems) - - top, bottom = self.find_vnearest_taken_components(initial_item) - vitems = self.sort_out_non_vconnected_items(vitems, -1 if top is None else top.pos_y, 9999999 if bottom is None else bottom.pos_y) - - # make sure, that initial_item is always included - if(initial_item.this_id not in vitems): - vitems.append(initial_item.this_id) - - print_verbose(7, '---> top: ' + str(top) + ', and bottom:' + str(bottom)) - print_verbose(7, '---> V-Items ') - print_subset(7, self.items, vitems) - - sub_tab = HTMLTable() - if(len(vitems) > 0): - sub_tab.init_by_cols(vitems, self.items) - sub_tab.set_temp_assignment() - print_verbose(5, 'Sub Table for current column at: ' +str(initial_item) + " = " +str(sub_tab)) - - - return sub_tab - - - def discover_table_row(self, initial_item): - hitems = self.find_horizontal_aligned_items(initial_item) - hitems = self.sort_out_non_connected_row_items(hitems, initial_item) - - print_verbose(7, "discover row at item: "+ str(initial_item)) - print_subset(7, self.items, hitems) - - return hitems - - - def discover_subtables_recursively(self, initial_item, step): #step=0 => discover col; step=1 => discover row. 
each subtable is a column - print_verbose(5, "discover subtable rec, at item : " +str(initial_item) + " and step = " +str(step)) - - if(initial_item.has_category() or (initial_item.temp_assignment != 0 and step == 0)): - print_verbose(5, "---> recusion end") - return [] # end recursion - - if(step==0): # col - res = [] - cur_sub_table = self.discover_table_column(initial_item) - - if(cur_sub_table.count_actual_items() > 0): - print_verbose(5, "---> added new subtable") - res.append(cur_sub_table) - for i in cur_sub_table.idx: - if(i!=-1): - res.extend(self.discover_subtables_recursively(self.items[i], 1)) - return res - - elif(step==1): # row - res = [] - hitems = self.discover_table_row(initial_item) - print_verbose(5, "---> found hitems = "+str(hitems)) - for i in hitems: - res.extend(self.discover_subtables_recursively(self.items[i], 0)) - - return res - - return [] - - - def discover_table(self, initial_item): - print_verbose(2, "DISCOVER NEW TABLE AT " + str(initial_item)) - - done = False - - while(not done): - done = True - - initial_item.temp_assignment = 0 - self.clear_all_temp_assignments() - - sub_tables = self.discover_subtables_recursively(initial_item, 0) - if(len(sub_tables) == 0): - return None - - table = sub_tables[0] - print_verbose(2, "Starting with table: "+str(sub_tables[0])) - - for i in range(1,len(sub_tables)): - print_verbose(5, "Merging table: "+str(sub_tables[i])) - table = HTMLTable.merge(table, sub_tables[i], self.page_width) - print_verbose(5, "Next table:" + str(table)) - - # TODO!!! - #table.recalc_geometry() - #table.unfold_patched_numbers() - table.cleanup_table(self.page_width, self.paragraphs) - - if(table.is_good_table()): - # did we miss any items? - missing_items = self.find_items_within_rect(table.table_rect, [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_RUNNING_TEXT, CAT_FOOTER]) - if(len(missing_items)>0): - #yes => reclassify - print_verbose(2, "Found missing items : " +str(missing_items)) - for i in missing_items: - self.items[i].category = CAT_DEFAULT - done = False - - - - return table - - - def mark_all_tables(self): - while(True): - next = self.find_next_nonclassified_item() - if(next is None): - break # we are done - - table = self.discover_table(next) - print_verbose(2, "FOUND TABLE: "+str(table)) - if(config.global_force_special_items_into_table): - table.force_special_items_into_table() - - - if(table.is_good_table()): - print_verbose(2, "---> good") - table.categorize_as_table() - self.tables.append(table) - else: - print_verbose(2, "---> bad") - table.categorize_as_misc() - - # sort out all empty special items - for t in self.tables: - tmp_sp_idx = [] - for sp_idx in t.special_idx: - if(self.items[sp_idx].txt != ''): - tmp_sp_idx.append(sp_idx) - else: - self.items[sp_idx].category = CAT_MISC - t.special_idx = tmp_sp_idx - - # merge non-overlapping rows, if needed - if(config.global_table_merge_non_overlapping_rows): - for table in self.tables: - table.merge_non_overlapping_rows() - #pass - - def mark_all_footnotes(self): - - def apply_cat_unsplit(idx, cat): - self.items[idx].category = cat - if(self.items[idx].right_id!=-1): - apply_cat_unsplit(self.items[idx].right_id, cat) - - - print_verbose(5, "Marking all footnotes . . 
.") - for idx in range(len(self.items)): - if(self.items[idx].left_id != -1): - continue # skip this - txt = self.get_txt_unsplit(idx) - print_verbose(7, "Analyzing==>" + txt+ ", cat=" +str(self.items[idx].category)) - if(self.items[idx].category != CAT_OTHER_TEXT): - continue # skip this also - if(Format_Analyzer.looks_footnote(txt)): - # this is a footnote ! - print_verbose(7, ".....>>> Yes, footnote!") - apply_cat_unsplit(idx, CAT_FOOTNOTE) - self.footnotes_idx.append(idx) - - - - - + # do splitting + for k in do_split_l: + if all_words[k][1] > 0: + all_words[k] = (all_words[k][0], all_words[k][1], True) - - # ===================================================================================================================== - # Rendering - # ===================================================================================================================== - - def render_to_png(self, in_dir, out_dir): - - base = Image.open(in_dir+r'/page'+str(self.page_num)+'.png').convert('RGBA').resize((self.page_width, self.page_height)) - context = ImageDraw.Draw(base) - - if(not RENDERING_USE_CLUSTER_COLORS): - table_bg_color = (0,255,255,255) #if not RENDERING_USE_CLUSTER_COLORS else (255, 255, 255, 64) - - # table borders - for t in self.tables: - context.rectangle([(t.table_rect.x0,t.table_rect.y0),(t.table_rect.x1,t.table_rect.y1)], fill = table_bg_color, outline =(0,0,0,255)) - first = True - for r in t.rows: - if(first): - first = False - context.line([(r.x0, r.y0),(r.x1, r.y0)], fill =(0,0,0,255), width = 0) - context.line([(r.x0, r.y1),(r.x1, r.y1)], fill =(0,0,0,255), width = 0) - first = True - for c in t.cols: - if(first): - first = False - context.line([(c.x0, c.y0),(c.x0, c.y1)], fill =(0,0,0,255), width = 0) - context.line([(c.x1, c.y0),(c.x1, c.y1)], fill =(0,0,0,255), width = 0) - - - # text - if(RENDERING_USE_CLUSTER_COLORS): - self.clusters_text.generate_rendering_colors_rec() - - for it in self.items: - font_color = (0,0,255,255) #default - if(it.category in (CAT_RUNNING_TEXT, CAT_HEADLINE, CAT_OTHER_TEXT, CAT_FOOTER)): - font_color = (216, 216, 216, 255) - - #if(it.category == CAT_HEADLINE): - # font_color = (0, 128, 32, 255) - - #if(it.category == CAT_OTHER_TEXT): - # font_color = (0, 255, 0, 255) - - if(it.category == CAT_TABLE_DATA): - font_color = (0, 0, 0, 255) - - if(it.category == CAT_TABLE_HEADLINE): - font_color = (255, 0, 0, 255) - - if(it.category == CAT_TABLE_SPECIAL): - font_color = (255, 0, 128, 255) - - if(it.category == CAT_FOOTNOTE): - font_color = (127, 0, 255, 255) - - """ + word_map = {} + for k in range(len(all_words)): + word_map[(all_words[k][0], all_words[k][1])] = k + + for k in do_split_r: + item_id = all_words[k][0] + word_id = all_words[k][1] + num_words = len(self.items[item_id].words) + if word_id < num_words - 1: + kr = word_map[(item_id, word_id + 1)] + all_words[kr] = (all_words[kr][0], all_words[kr][1], True) + + all_words = sorted(all_words, key=lambda ij: -ij[1]) # split always beginning from the end + + next_id = len(self.items) + + for ij in all_words: + if not ij[2]: + continue # do not split this one + if ( + self.items[ij[0]].words[ij[1]].rect.x0 - self.items[ij[0]].words[ij[1] - 1].rect.x1 + < self.items[item_id].space_width * 0 + ): + continue # words are too close to split + + if ( + self.items[ij[0]].words[ij[1]].rect.x0 - self.items[ij[0]].words[ij[1] - 1].rect.x1 + < self.items[item_id].space_width * 1.5 + and not Format_Analyzer.looks_weak_numeric(self.items[ij[0]].words[ij[1]].txt) + and not 
Format_Analyzer.looks_weak_numeric(self.items[ij[0]].words[ij[1] - 1].txt) + ): + continue # words are too close to split + + print_verbose( + 3, + "---> Split item " + + str(self.items[ij[0]]) + + " at word " + + str(ij[1]) + + "(x1=" + + str(self.items[ij[0]].words[ij[1] - 1].rect.x1) + + "<-> x0=" + + str(self.items[ij[0]].words[ij[1]].rect.x0) + + " , space_width= " + + str(self.items[item_id].space_width), + ) + new_item = self.items[ij[0]].split(ij[1], next_id) + self.items.append(new_item) + print_verbose(3, '------> Result = "' + str(self.items[ij[0]].txt) + '" + "' + str(new_item.txt) + '"') + + next_id += 1 + + # raise ValueError('XXX') + + def get_txt_unsplit(self, idx): + if self.items[idx].right_id == -1: + return self.items[idx].txt + return self.items[idx].txt + " " + self.get_txt_unsplit(self.items[idx].right_id) + + def find_left_distributions(self): + self.left_distrib = {} + for it in self.items: + cur_x = it.pos_x + self.left_distrib[cur_x] = self.left_distrib.get(cur_x, 0) + 1 + print_verbose(5, "Left distrib: " + str(self.left_distrib)) + + def find_paragraphs(self): + self.paragraphs = [] + + distrib = {} + for it in self.items: + if it.category == CAT_RUNNING_TEXT or it.category == CAT_HEADLINE: + cur_x = it.pos_x + distrib[cur_x] = distrib.get(cur_x, 0) + 1 + + for pos_x, frequency in distrib.items(): + if frequency > 5: + self.paragraphs.append(pos_x) + + self.paragraphs.sort() + + def find_items_within_rect_all_categories(self, rect): # returns list of indices + res = [] + for i in range(len(self.items)): + if ( + Rect.calc_intersection_area(self.items[i].get_rect(), rect) > self.items[i].get_rect().get_area() * 0.3 + ): # 0.5? + res.append(i) + return res + + def find_items_within_rect(self, rect, categories): # returns list of indices + res = [] + for i in range(len(self.items)): + if self.items[i].category in categories: + if ( + Rect.calc_intersection_area(self.items[i].get_rect(), rect) + > self.items[i].get_rect().get_area() * 0.3 + ): # 0.5? 
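# --- Illustrative aside (not applied by this patch): the containment rule in
# find_items_within_rect above counts an item as "inside" a rectangle when their
# intersection covers more than roughly 30% of the item's own bounding-box area.
# Minimal standalone sketch using plain (x0, y0, x1, y1) tuples instead of the
# module's Rect class; the names and the 0.3 default are illustrative only.
def intersection_area(a, b):
    # a, b: (x0, y0, x1, y1) with x0 <= x1 and y0 <= y1
    w = min(a[2], b[2]) - max(a[0], b[0])
    h = min(a[3], b[3]) - max(a[1], b[1])
    return w * h if w > 0 and h > 0 else 0.0


def lies_within(item_rect, region_rect, min_fraction=0.3):
    area = (item_rect[2] - item_rect[0]) * (item_rect[3] - item_rect[1])
    return area > 0 and intersection_area(item_rect, region_rect) > area * min_fraction


# half of the item overlaps the region -> inside; only a thin sliver -> outside
assert lies_within((10, 10, 20, 20), (0, 0, 15, 30))
assert not lies_within((10, 10, 20, 20), (0, 0, 11, 30))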
+ res.append(i) + return res + + def explode_item(self, idx, sep=" "): # return concatenated txt + def expl_int(dir, idx, sep): + return ( + ("" if self.items[idx].left_id == -1 or dir > 0 else expl_int(-1, self.items[idx].left_id, sep) + sep) + + self.items[idx].txt + + ( + "" + if self.items[idx].right_id == -1 or dir < 0 + else sep + expl_int(1, self.items[idx].right_id, sep) + ) + ) + + return expl_int(0, idx, sep) + + def explode_item_by_idx(self, idx): # return list of idx + def expl_int(dir, idx): + return ( + ([] if self.items[idx].left_id == -1 or dir > 0 else expl_int(-1, self.items[idx].left_id)) + + [idx] + + ([] if self.items[idx].right_id == -1 or dir < 0 else expl_int(1, self.items[idx].right_id)) + ) + + return expl_int(0, idx) + + def find_vertical_aligned_items(self, item, alignment, threshold, do_print=False): # TODO: remove do_print + # returns indices of affected items + res = [] + + if alignment != ALIGN_DEFAULT: + this_align = alignment + else: + this_align = item.alignment + + pos_x = item.pos_x + pos_y = item.pos_y + threshold_px = threshold * self.page_width + + score = 0.0 + + if this_align == ALIGN_RIGHT: + pos_x += item.width + # New-27.06.2022: + if this_align == ALIGN_CENTER: + pos_x += item.width * 0.5 + + for i in range(len(self.items)): + if alignment != ALIGN_DEFAULT: + cur_align = alignment + else: + cur_align = self.items[i].alignment + cur_x = self.items[i].pos_x + cur_y = self.items[i].pos_y + + if cur_align == ALIGN_RIGHT: + cur_x += self.items[i].width + if cur_align == ALIGN_CENTER: + cur_x += self.items[i].width * 0.5 + + delta = abs(cur_x - pos_x) + if do_print: + print_verbose(7, "---> delta for " + str(self.items[i]) + " to " + str(pos_x) + " is " + str(delta)) + if delta <= threshold_px: + cur_score = ((threshold_px - delta) / self.page_width) * ( + ((1.0 - abs(cur_y - pos_y) / self.page_height)) ** 5.0 + ) + if cur_score < 0.003: + cur_score = 0 + print_verbose(9, "VALIGN->" + str(self.items[i]) + " has SCORE: " + str(cur_score)) + score += cur_score + res.append(i) + + return res, score + + def find_horizontal_aligned_items(self, item): + # returns indices of affected items + res = [] + + y0 = item.pos_y + y1 = item.pos_y + item.height + + for i in range(len(self.items)): + it = self.items[i] + if it.pos_y < y1 and it.pos_y + it.height > y0: + res.append(i) + + return res + + def clear_all_temp_assignments(self): + for it in self.items: + it.temp_assignment = 0 + + def guess_all_alignments(self): + for it in self.items: + dummy, score_left = self.find_vertical_aligned_items(it, ALIGN_LEFT, DEFAULT_VTHRESHOLD) + dummy, score_right = self.find_vertical_aligned_items(it, ALIGN_RIGHT, DEFAULT_VTHRESHOLD) + dummy, score_center = self.find_vertical_aligned_items( + it, ALIGN_CENTER, DEFAULT_VTHRESHOLD + ) # New-27.06.2022 + # it.alignment = ALIGN_LEFT if score_left >= score_right else ALIGN_RIGHT + if score_left >= score_right and score_left >= score_center: + it.alignment = ALIGN_LEFT + elif score_right >= score_left and score_right >= score_center: + it.alignment = ALIGN_RIGHT + else: + it.alignment = ALIGN_CENTER + + def find_next_nonclassified_item(self): + for it in self.items: + if not it.has_category(): + return it + return None + + def identify_connected_txt_lines(self): + def insert_next_id(cur_id, next_id, items): + if items[next_id].pos_y <= items[cur_id].pos_y: + raise ValueError("Invalid item order:" + str(items[cur_id]) + " --> " + str(items[next_id])) + + if items[cur_id].next_id == -1: + items[cur_id].next_id = next_id + elif 
items[cur_id].next_id == next_id: + return + else: + old_next_id = items[cur_id].next_id + if items[next_id].pos_y < items[old_next_id].pos_y: + items[cur_id].next_id = next_id + insert_next_id(next_id, old_next_id, items) + elif items[next_id].pos_y < items[old_next_id].pos_y: + insert_next_id(old_next_id, next_id, items) + else: + # sometimes this can happen, but then we dont update anything in order to avoid cycles + pass + + threshold = int(0.03 * self.page_width + 0.5) # allow max 3% deviation to the left + cur_threshold = 0 + + for cur_x, cnt in self.left_distrib.items(): + if cnt < 2: + # if we have less than 2 lines, we skip this "column" + continue + + cur_lines = {} # for each line in this column, we store its y position + last_pos_y = -1 + for i in range(len(self.items)): + if ( + self.items[i].pos_x >= cur_x + and self.items[i].pos_x <= cur_x + cur_threshold + and self.items[i].pos_y > last_pos_y + ): + cur_lines[i] = self.items[i].pos_y + last_pos_y = self.items[i].pos_y + self.items[i].height * 0.9 + + cur_lines = sorted(cur_lines.items(), key=lambda kv: kv[1]) + + # get row_spacing + row_spacings = [] + for i in range(len(cur_lines) - 1): + cur_item_id, cur_y = cur_lines[i] + next_item_id, next_y = cur_lines[i + 1] + cur_item = self.items[cur_item_id] + cur_spacing = next_y - (cur_y + cur_item.height) + row_spacings.append(cur_spacing) + + if len(row_spacings) == 0: + continue + + max_allowed_spacing = statistics.median(row_spacings) * 1.1 + + for i in range(len(cur_lines) - 1): + cur_item_id, cur_y = cur_lines[i] + cur_item = self.items[cur_item_id] + + next_item_id, next_y = cur_lines[i + 1] + next_item = self.items[next_item_id] + if ( + (next_y > cur_y + min(cur_item.height + max_allowed_spacing, 2 * cur_item.height)) + or (cur_item.font_size != next_item.font_size) # too far apart + or (cur_item.font_file != next_item.font_file) # different font sizes + ): # different fonts faces + cur_threshold = 0 + continue + + cur_threshold = threshold + insert_next_id(cur_item_id, next_item_id, self.items) + # self.items[cur_item_id].next_id = next_item_id + # self.items[next_item_id].prev_id = cur_item_id + + # update all prev_ids + for i in range(len(self.items)): + if self.items[i].next_id != -1: + self.items[self.items[i].next_id].prev_id = i + + def mark_regular_text(self): + # mark connected text components + for it in self.items: + if it.category != CAT_DEFAULT: + continue # already taken + if it.prev_id != -1: + continue # has previous item => we look at that + + txt = Format_Analyzer.trim_whitespaces(it.txt) + + next = it.next_id + while next != -1: + # print(self.items[next], self.items[next].next_id) + txt += " " + Format_Analyzer.trim_whitespaces(self.items[next].txt) + next = self.items[next].next_id + + if Format_Analyzer.looks_running_text(txt): + it.category = CAT_RUNNING_TEXT + next = it.next_id + while next != -1: + self.items[next].category = CAT_RUNNING_TEXT + next = self.items[next].next_id + + def mark_other_text_components(self): + threshold = int(0.03 * self.page_width + 0.5) # allow max 3% deviation to the left + + for cur_x, cnt in self.left_distrib.items(): + cur_lines = {} # for each line in this column, we store its y position + cur_threshold = 0 + for i in range(len(self.items)): + if self.items[i].pos_x >= cur_x and self.items[i].pos_x <= cur_x + cur_threshold: + cur_threshold = threshold + cur_lines[i] = self.items[i].pos_y + + cur_lines = sorted(cur_lines.items(), key=lambda kv: kv[1]) + + for i in range(len(cur_lines)): + cur_item_id, cur_y = 
cur_lines[i] + if self.items[cur_item_id].category != CAT_DEFAULT: + continue # already taken + prev_item_id = -1 + prev_y = -1 + next_item_id = -1 + next_y = -1 + if i > 0: + prev_item_id, prev_y = cur_lines[i - 1] + if i < len(cur_lines) - 1: + next_item_id, next_y = cur_lines[i + 1] + + # between to running texts (paragraphs): + if prev_item_id != -1 and next_item_id != -1: + y_threshold = 2 * max( + self.items[cur_item_id].height, self.items[prev_item_id].height, self.items[next_item_id].height + ) + + if ( + self.items[prev_item_id].category == CAT_RUNNING_TEXT + and self.items[next_item_id].category == CAT_RUNNING_TEXT + and abs(cur_y - prev_y) < y_threshold + and abs(cur_y - next_y) < y_threshold + ): + if self.items[cur_item_id].height > self.items[next_item_id].height: + self.items[cur_item_id].category = CAT_HEADLINE + else: + print_verbose(10, "---->>> found CAT_OTHER_TEXT/1 for item " + str(cur_item_id)) + self.items[cur_item_id].category = CAT_OTHER_TEXT + + # single (head-)lines at the beginning + if (prev_item_id == -1 or self.items[cur_item_id].prev_id == -1) and next_item_id != -1: + y_threshold = 2 * max(self.items[cur_item_id].height, self.items[next_item_id].height) + + if self.items[next_item_id].category == CAT_RUNNING_TEXT and abs(cur_y - next_y) < y_threshold: + if self.items[cur_item_id].height > self.items[next_item_id].height: + self.items[cur_item_id].category = CAT_HEADLINE + else: + print_verbose(10, "---->>> found CAT_OTHER_TEXT/2 for item " + str(cur_item_id)) + self.items[cur_item_id].category = CAT_OTHER_TEXT + + # multiple rows spanning headlines at the beginning + for i in range(len(cur_lines) - 1): + cur_item_id, cur_y = cur_lines[i] + if self.items[cur_item_id].category != CAT_DEFAULT: + continue # already taken + + if self.items[cur_item_id].next_id != -1 or self.items[cur_item_id].prev_id == -1: + continue # we are only interested at items that mark end of a block + + if not Format_Analyzer.looks_words(self.items[cur_item_id].txt): + continue # only text + + print_verbose( + 9, "--> mark_other_text_components \ multi-rows headline: " + str(self.items[cur_item_id]) + ) + next_item_id, next_y = cur_lines[i + 1] + + if self.items[next_item_id].category != CAT_RUNNING_TEXT: + continue # only when followed by normal paragraph + + if self.items[cur_item_id].font_size <= self.items[next_item_id].font_size and not ( + self.items[cur_item_id].font_size == self.items[next_item_id].font_size + and self.items[cur_item_id].is_bold + and not self.items[next_item_id].is_bold + ): + continue # header must of greater font size, or: same font size but greater boldness (i.e, headline bold, but following text not) + + y_threshold = 2 * max(self.items[cur_item_id].height, self.items[next_item_id].height) + + print_verbose( + 9, "----> cur_y, next_y , y_threshold = " + str(cur_y) + "," + str(next_y) + "," + str(y_threshold) + ) + + if ( + abs(cur_y - next_y) < y_threshold + and self.items[cur_item_id].height > self.items[next_item_id].height + ): + # count number of affected lines + iter_item_id = cur_item_id + num_affected = 1 + while iter_item_id != -1: + iter_item_id = self.items[iter_item_id].prev_id + num_affected += 1 + if num_affected <= 3: # more than 3 lines would be too much for a headline + # match! 
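# --- Illustrative aside (not applied by this patch): the block above promotes a
# candidate to CAT_HEADLINE only after walking the prev_id chain and checking
# that the whole block spans no more than a few linked lines. Simplified
# standalone sketch of that chain walk; the Item class and max_lines value are
# stand-ins for the real HTMLItem fields, purely for illustration.
from dataclasses import dataclass

CAT_DEFAULT, CAT_HEADLINE = 0, 1


@dataclass
class Item:
    txt: str
    prev_id: int = -1
    category: int = CAT_DEFAULT


def mark_headline_block(items, last_id, max_lines=3):
    # count the lines in the chain ending at last_id
    count, i = 0, last_id
    while i != -1:
        count += 1
        i = items[i].prev_id
    if count > max_lines:
        return False  # too many rows to plausibly be a headline
    # short enough: mark every line in the chain as a headline
    i = last_id
    while i != -1:
        items[i].category = CAT_HEADLINE
        i = items[i].prev_id
    return True


items = [Item("Our climate"), Item("strategy in", prev_id=0), Item("2022", prev_id=1)]
assert mark_headline_block(items, last_id=2)  # 3 linked lines -> all become headlines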
+ print_verbose(9, "------->> MATCH!") + iter_item_id = cur_item_id + while iter_item_id != -1: + self.items[iter_item_id].category = CAT_HEADLINE + iter_item_id = self.items[iter_item_id].prev_id + + # Page number and footer + pgnum_threshold = 0.9 * self.page_height + pgnum_id = -1 + pgnum_pos_y = 0 + + for i in range(len(self.items)): + if self.items[i].pos_y < pgnum_threshold: + continue + if Format_Analyzer.looks_pagenum(self.items[i].txt): + cur_y = self.items[i].pos_y + if cur_y > pgnum_pos_y: + pgnum_pos_y = cur_y + pgnum_id = i + + if pgnum_id != -1: + print_verbose(10, "---->>> found CAT_FOOTER/3 for item " + str(pgnum_id)) + self.items[pgnum_id].category = CAT_FOOTER + for it in self.items: + if it.pos_y == pgnum_pos_y: + print_verbose(10, "---->>> found CAT_FOOTER/4 for item " + str(it.this_id)) + it.category = CAT_FOOTER # footer + + # Isolated items + iso_threshold = 0.05 * self.page_height + + for it in self.items: + if it.category != CAT_DEFAULT: + continue # already taken + + min_dist = 99999999 + + for jt in self.items: + if ( + it == jt + or jt.category == CAT_RUNNING_TEXT + or jt.category == CAT_HEADLINE + or jt.category == CAT_OTHER_TEXT + or jt.category == CAT_FOOTER + ): + continue # we do not consider these ones (=> text is assumed to be isolated, even if any of these ones are near) + cur_dist = Rect.raw_rect_distance( + it.pos_x, + it.pos_y, + it.pos_x + it.width, + it.pos_y + it.height, + jt.pos_x, + jt.pos_y, + jt.pos_x + jt.width, + jt.pos_y + jt.height, + ) + if cur_dist < min_dist: + min_dist = cur_dist + + if min_dist > iso_threshold: + print_verbose(10, "---->>> found CAT_OTHER_TEXT/5 for item " + str(it.this_id)) + it.category = CAT_OTHER_TEXT + + # ===================================================================================================================== + # Utilities for Table Extraction + # ===================================================================================================================== + + def find_vnearest_taken_components( + self, initial_item + ): # search vertically, return nearest top and bottom component (if any) + FONT_RECT_TOLERANCE = 0.9 # be less restrictive, since the font rect's width is not always accurate + left_min = initial_item.pos_x + left_max = initial_item.pos_x + initial_item.width * FONT_RECT_TOLERANCE + + top = None + bottom = None + + for it in self.items: + if it.has_category_besides(CAT_FOOTER): + if it.pos_x < left_max and it.pos_x + it.width * FONT_RECT_TOLERANCE > left_min: + if it.pos_y <= initial_item.pos_y and (top is None or it.pos_y > top.pos_y): + top = it + if it.pos_y >= initial_item.pos_y and (bottom is None or it.pos_y < bottom.pos_y): + bottom = it + + return top, bottom + + def sort_out_non_vconnected_items(self, vitems, top_y, bottom_y): + res = [] + for i in vitems: + if self.items[i].pos_y > top_y and self.items[i].pos_y < bottom_y: + res.append(i) + + return res + + def sort_out_taken_items(self, any_items): + res = [] + for i in any_items: + if not self.items[i].has_category_besides(CAT_FOOTER) and self.items[i].temp_assignment == 0: + res.append(i) + + return res + + def sort_out_non_vertical_aligned_items_by_bounding_box(self, initial_item, vitems): + res = [] + + x0 = initial_item.pos_x + x1 = initial_item.pos_x + initial_item.width + + for i in vitems: + cur_x0 = self.items[i].pos_x + cur_x1 = self.items[i].pos_x + self.items[i].width + + if cur_x0 < x1 and cur_x1 > x0: + res.append(i) + + return res + + def sort_out_items_in_same_row(self, initial_item, vitems): + 
orig_pos_x = initial_item.get_aligned_pos_x() + res = [] + for i in vitems: + it_delta = abs(self.items[i].get_aligned_pos_x() - orig_pos_x) + better_item_in_same_row = False + for j in vitems: + if i == j: + continue + if ( + self.items[j].pos_y + self.items[j].height >= self.items[i].pos_y + and self.items[j].pos_y <= self.items[i].pos_y + self.items[i].height + ): + jt_delta = abs(self.items[j].get_aligned_pos_x() - orig_pos_x) + if jt_delta < it_delta: + better_item_in_same_row = True + break + if not better_item_in_same_row: + res.append(i) + + return res + + def sort_out_non_connected_row_items(self, hitems, initial_item): + # in the same row, we sort out all items, that are not connected with initial_item + # 2 items are connected, iff between them, there are no taken items + # initial_item will always be included + min_x = -1 + max_x = 9999999 + for i in hitems: + if self.items[i].has_category_besides(CAT_FOOTER) or self.items[i].temp_assignment != 0: + cur_x = self.items[i].pos_x + if cur_x < initial_item.pos_x: + min_x = max(min_x, cur_x) + if cur_x > initial_item.pos_x: + max_x = min(max_x, cur_x) + + res = [] + for i in hitems: + if self.items[i].pos_x > min_x and self.items[i].pos_x < max_x: + res.append(i) + + return res + + def discover_table_column(self, initial_item): + print_verbose(7, "discover_table_column for item : " + str(initial_item)) + + vitems, dummy = self.find_vertical_aligned_items(initial_item, ALIGN_DEFAULT, DEFAULT_VTHRESHOLD) + print_verbose(9, "---> 1. V-Items ") + print_subset(9, self.items, vitems) + vitems = self.sort_out_taken_items(vitems) + print_verbose(9, "---> 2. V-Items ") + print_subset(9, self.items, vitems) + vitems = self.sort_out_non_vertical_aligned_items_by_bounding_box(initial_item, vitems) + print_verbose(9, "---> 3. V-Items ") + print_subset(9, self.items, vitems) + vitems = self.sort_out_items_in_same_row(initial_item, vitems) + print_verbose(9, "---> 4. V-Items ") + print_subset(9, self.items, vitems) + + top, bottom = self.find_vnearest_taken_components(initial_item) + vitems = self.sort_out_non_vconnected_items( + vitems, -1 if top is None else top.pos_y, 9999999 if bottom is None else bottom.pos_y + ) + + # make sure, that initial_item is always included + if initial_item.this_id not in vitems: + vitems.append(initial_item.this_id) + + print_verbose(7, "---> top: " + str(top) + ", and bottom:" + str(bottom)) + print_verbose(7, "---> V-Items ") + print_subset(7, self.items, vitems) + + sub_tab = HTMLTable() + if len(vitems) > 0: + sub_tab.init_by_cols(vitems, self.items) + sub_tab.set_temp_assignment() + print_verbose(5, "Sub Table for current column at: " + str(initial_item) + " = " + str(sub_tab)) + + return sub_tab + + def discover_table_row(self, initial_item): + hitems = self.find_horizontal_aligned_items(initial_item) + hitems = self.sort_out_non_connected_row_items(hitems, initial_item) + + print_verbose(7, "discover row at item: " + str(initial_item)) + print_subset(7, self.items, hitems) + + return hitems + + def discover_subtables_recursively( + self, initial_item, step + ): # step=0 => discover col; step=1 => discover row. 
each subtable is a column + print_verbose(5, "discover subtable rec, at item : " + str(initial_item) + " and step = " + str(step)) + + if initial_item.has_category() or (initial_item.temp_assignment != 0 and step == 0): + print_verbose(5, "---> recusion end") + return [] # end recursion + + if step == 0: # col + res = [] + cur_sub_table = self.discover_table_column(initial_item) + + if cur_sub_table.count_actual_items() > 0: + print_verbose(5, "---> added new subtable") + res.append(cur_sub_table) + for i in cur_sub_table.idx: + if i != -1: + res.extend(self.discover_subtables_recursively(self.items[i], 1)) + return res + + elif step == 1: # row + res = [] + hitems = self.discover_table_row(initial_item) + print_verbose(5, "---> found hitems = " + str(hitems)) + for i in hitems: + res.extend(self.discover_subtables_recursively(self.items[i], 0)) + + return res + + return [] + + def discover_table(self, initial_item): + print_verbose(2, "DISCOVER NEW TABLE AT " + str(initial_item)) + + done = False + + while not done: + done = True + + initial_item.temp_assignment = 0 + self.clear_all_temp_assignments() + + sub_tables = self.discover_subtables_recursively(initial_item, 0) + if len(sub_tables) == 0: + return None + + table = sub_tables[0] + print_verbose(2, "Starting with table: " + str(sub_tables[0])) + + for i in range(1, len(sub_tables)): + print_verbose(5, "Merging table: " + str(sub_tables[i])) + table = HTMLTable.merge(table, sub_tables[i], self.page_width) + print_verbose(5, "Next table:" + str(table)) + + # TODO!!! + # table.recalc_geometry() + # table.unfold_patched_numbers() + table.cleanup_table(self.page_width, self.paragraphs) + + if table.is_good_table(): + # did we miss any items? + missing_items = self.find_items_within_rect( + table.table_rect, [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_RUNNING_TEXT, CAT_FOOTER] + ) + if len(missing_items) > 0: + # yes => reclassify + print_verbose(2, "Found missing items : " + str(missing_items)) + for i in missing_items: + self.items[i].category = CAT_DEFAULT + done = False + + return table + + def mark_all_tables(self): + while True: + next = self.find_next_nonclassified_item() + if next is None: + break # we are done + + table = self.discover_table(next) + print_verbose(2, "FOUND TABLE: " + str(table)) + if config.global_force_special_items_into_table: + table.force_special_items_into_table() + + if table.is_good_table(): + print_verbose(2, "---> good") + table.categorize_as_table() + self.tables.append(table) + else: + print_verbose(2, "---> bad") + table.categorize_as_misc() + + # sort out all empty special items + for t in self.tables: + tmp_sp_idx = [] + for sp_idx in t.special_idx: + if self.items[sp_idx].txt != "": + tmp_sp_idx.append(sp_idx) + else: + self.items[sp_idx].category = CAT_MISC + t.special_idx = tmp_sp_idx + + # merge non-overlapping rows, if needed + if config.global_table_merge_non_overlapping_rows: + for table in self.tables: + table.merge_non_overlapping_rows() + # pass + + def mark_all_footnotes(self): + def apply_cat_unsplit(idx, cat): + self.items[idx].category = cat + if self.items[idx].right_id != -1: + apply_cat_unsplit(self.items[idx].right_id, cat) + + print_verbose(5, "Marking all footnotes . . 
.") + for idx in range(len(self.items)): + if self.items[idx].left_id != -1: + continue # skip this + txt = self.get_txt_unsplit(idx) + print_verbose(7, "Analyzing==>" + txt + ", cat=" + str(self.items[idx].category)) + if self.items[idx].category != CAT_OTHER_TEXT: + continue # skip this also + if Format_Analyzer.looks_footnote(txt): + # this is a footnote ! + print_verbose(7, ".....>>> Yes, footnote!") + apply_cat_unsplit(idx, CAT_FOOTNOTE) + self.footnotes_idx.append(idx) + + # ===================================================================================================================== + # Rendering + # ===================================================================================================================== + + def render_to_png(self, in_dir, out_dir): + base = ( + Image.open(in_dir + r"/page" + str(self.page_num) + ".png") + .convert("RGBA") + .resize((self.page_width, self.page_height)) + ) + context = ImageDraw.Draw(base) + + if not RENDERING_USE_CLUSTER_COLORS: + table_bg_color = (0, 255, 255, 255) # if not RENDERING_USE_CLUSTER_COLORS else (255, 255, 255, 64) + + # table borders + for t in self.tables: + context.rectangle( + [(t.table_rect.x0, t.table_rect.y0), (t.table_rect.x1, t.table_rect.y1)], + fill=table_bg_color, + outline=(0, 0, 0, 255), + ) + first = True + for r in t.rows: + if first: + first = False + context.line([(r.x0, r.y0), (r.x1, r.y0)], fill=(0, 0, 0, 255), width=0) + context.line([(r.x0, r.y1), (r.x1, r.y1)], fill=(0, 0, 0, 255), width=0) + first = True + for c in t.cols: + if first: + first = False + context.line([(c.x0, c.y0), (c.x0, c.y1)], fill=(0, 0, 0, 255), width=0) + context.line([(c.x1, c.y0), (c.x1, c.y1)], fill=(0, 0, 0, 255), width=0) + + # text + if RENDERING_USE_CLUSTER_COLORS: + self.clusters_text.generate_rendering_colors_rec() + + for it in self.items: + font_color = (0, 0, 255, 255) # default + if it.category in (CAT_RUNNING_TEXT, CAT_HEADLINE, CAT_OTHER_TEXT, CAT_FOOTER): + font_color = (216, 216, 216, 255) + + # if(it.category == CAT_HEADLINE): + # font_color = (0, 128, 32, 255) + + # if(it.category == CAT_OTHER_TEXT): + # font_color = (0, 255, 0, 255) + + if it.category == CAT_TABLE_DATA: + font_color = (0, 0, 0, 255) + + if it.category == CAT_TABLE_HEADLINE: + font_color = (255, 0, 0, 255) + + if it.category == CAT_TABLE_SPECIAL: + font_color = (255, 0, 128, 255) + + if it.category == CAT_FOOTNOTE: + font_color = (127, 0, 255, 255) + + """ # Color by alignment: if(it.alignment == ALIGN_LEFT): font_color = (255, 0, 0, 255) else: font_color = (0, 0, 255, 255) """ - - """ + + """ # Color by split: if(it.has_been_split): font_color = hsv_to_rgba((it.this_id % 6)/6.0, 1, 1) else: font_color = (128, 128, 128, 255) """ - - if(RENDERING_USE_CLUSTER_COLORS): - font_color = it.rendering_color - - span_font = ImageFont.truetype(it.font_file if config.global_rendering_font_override == "" else config.global_rendering_font_override, it.font_size) - context.text((it.pos_x,it.pos_y), it.txt, font=span_font, fill=font_color) - - #for w in it.words: - # context.text(( w.rect.x0, w.rect.y0), w.txt, font=span_font, fill=font_color) - - - - # conntected lines - for it in self.items: - if(it.next_id != -1 and it.category == CAT_RUNNING_TEXT): - line_shape = [(it.pos_x + 5, int(it.pos_y + it.height / 2)), (self.items[it.next_id].pos_x + 5, self.items[it.next_id].pos_y + int(self.items[it.next_id].height / 2))] - context.line(line_shape, fill =(80,80,80,255), width = 0) - - - base.save(out_dir+r'/output'+"{:05d}".format(self.page_num) +'.png') - 
- - - # ===================================================================================================================== - # Clustering procedures - # ===================================================================================================================== - - def generate_clusters(self): - self.clusters = HTMLCluster.generate_clusters(self.items, CLUSTER_DISTANCE_MODE_EUCLIDIAN) - self.clusters_text = HTMLCluster.generate_clusters(self.items, CLUSTER_DISTANCE_MODE_RAW_TEXT) - - - # ===================================================================================================================== - # Other procedures - # ===================================================================================================================== - - def remove_certain_items(self, txt, threshold): #they would confuse tables and are not needed - count = 0 - for it in self.items: - for w in it.words: - if(w.txt==txt): - count += 1 - - #print('-->count: '+str(count)) - - if(count >= threshold): - new_items = [] - cur_id = 0 - for it in self.items: - new_words = [] - for w in it.words: - if(w.txt!=txt): - w.item_id = cur_id - new_words.append(w) - if(len(new_words)==0): - continue # skip this item - - it.this_id = cur_id - it.words = new_words - it.recalc_geometry() - it.rejoin_words() - new_items.append(it) - cur_id += 1 - - self.items = new_items - - def remove_flyspeck(self): - threshold = DEFAULT_FLYSPECK_HEIGHT * self.page_height - new_items = [] - cur_id = 0 - for it in self.items: - if(it.height>threshold): - it.this_id = cur_id - new_items.append(it) - cur_id += 1 - else: - print_verbose(6,"Removing flyspeck item : " + str(it)) - - - self.items = new_items - - - - def remove_overlapping_items(self): - - keep = [True] * len(self.items) - - for i in range(len(self.items)-1): - if(keep[i]): - for j in range(i+1, len(self.items)): - if(Rect.calc_intersection_area(self.items[i].get_rect(), self.items[j].get_rect()) > 0.): - #overlapping items => remove it - keep[j] = False - print_verbose(5, "Removing item : " + str(self.items[j]) + ", because overlap with : " + str(self.items[i])) - - new_items = [] - cur_id = 0 - for i in range(len(self.items)): - if(keep[i]): - self.items[i].this_id = cur_id - new_items.append(self.items[i]) - cur_id += 1 - - self.items = new_items - - def save_all_tables_to_csv(self, outdir): - for i in range(len(self.tables)): - self.tables[i].save_to_csv(remove_trailing_slash(outdir) + r'/tab_' + str(self.page_num) + r'_' + str(i+1) + r'.csv') - - def save_all_footnotes_to_txt(self, outdir): - if(len(self.footnotes_idx) == 0): - return # dont save if empty - res = "" - for idx in self.footnotes_idx: - res += self.get_txt_unsplit(idx) + "\n" - - save_txt_to_file(res, remove_trailing_slash(outdir) + r'/footnotes_' + str(self.page_num) + r'.txt') - - - def preprocess_data(self): - self.remove_flyspeck() - self.remove_certain_items('.', 50) - #self.remove_certain_items('', 1) - self.remove_overlapping_items() - self.detect_split_items() - self.find_left_distributions() - self.guess_all_alignments() - self.identify_connected_txt_lines() - self.mark_regular_text() - self.mark_other_text_components() - self.find_paragraphs() - self.mark_all_tables() - self.mark_all_footnotes() - self.generate_clusters() - - - #tmp = self.discover_table_column(self.items[self.find_idx_of_item_by_txt(r'Group statement of changes in equity a')]) - #print(tmp) - - """ + + if RENDERING_USE_CLUSTER_COLORS: + font_color = it.rendering_color + + span_font = ImageFont.truetype( + 
it.font_file if config.global_rendering_font_override == "" else config.global_rendering_font_override, + it.font_size, + ) + context.text((it.pos_x, it.pos_y), it.txt, font=span_font, fill=font_color) + + # for w in it.words: + # context.text(( w.rect.x0, w.rect.y0), w.txt, font=span_font, fill=font_color) + + # conntected lines + for it in self.items: + if it.next_id != -1 and it.category == CAT_RUNNING_TEXT: + line_shape = [ + (it.pos_x + 5, int(it.pos_y + it.height / 2)), + ( + self.items[it.next_id].pos_x + 5, + self.items[it.next_id].pos_y + int(self.items[it.next_id].height / 2), + ), + ] + context.line(line_shape, fill=(80, 80, 80, 255), width=0) + + base.save(out_dir + r"/output" + "{:05d}".format(self.page_num) + ".png") + + # ===================================================================================================================== + # Clustering procedures + # ===================================================================================================================== + + def generate_clusters(self): + self.clusters = HTMLCluster.generate_clusters(self.items, CLUSTER_DISTANCE_MODE_EUCLIDIAN) + self.clusters_text = HTMLCluster.generate_clusters(self.items, CLUSTER_DISTANCE_MODE_RAW_TEXT) + + # ===================================================================================================================== + # Other procedures + # ===================================================================================================================== + + def remove_certain_items(self, txt, threshold): # they would confuse tables and are not needed + count = 0 + for it in self.items: + for w in it.words: + if w.txt == txt: + count += 1 + + # print('-->count: '+str(count)) + + if count >= threshold: + new_items = [] + cur_id = 0 + for it in self.items: + new_words = [] + for w in it.words: + if w.txt != txt: + w.item_id = cur_id + new_words.append(w) + if len(new_words) == 0: + continue # skip this item + + it.this_id = cur_id + it.words = new_words + it.recalc_geometry() + it.rejoin_words() + new_items.append(it) + cur_id += 1 + + self.items = new_items + + def remove_flyspeck(self): + threshold = DEFAULT_FLYSPECK_HEIGHT * self.page_height + new_items = [] + cur_id = 0 + for it in self.items: + if it.height > threshold: + it.this_id = cur_id + new_items.append(it) + cur_id += 1 + else: + print_verbose(6, "Removing flyspeck item : " + str(it)) + + self.items = new_items + + def remove_overlapping_items(self): + keep = [True] * len(self.items) + + for i in range(len(self.items) - 1): + if keep[i]: + for j in range(i + 1, len(self.items)): + if Rect.calc_intersection_area(self.items[i].get_rect(), self.items[j].get_rect()) > 0.0: + # overlapping items => remove it + keep[j] = False + print_verbose( + 5, + "Removing item : " + str(self.items[j]) + ", because overlap with : " + str(self.items[i]), + ) + + new_items = [] + cur_id = 0 + for i in range(len(self.items)): + if keep[i]: + self.items[i].this_id = cur_id + new_items.append(self.items[i]) + cur_id += 1 + + self.items = new_items + + def save_all_tables_to_csv(self, outdir): + for i in range(len(self.tables)): + self.tables[i].save_to_csv( + remove_trailing_slash(outdir) + r"/tab_" + str(self.page_num) + r"_" + str(i + 1) + r".csv" + ) + + def save_all_footnotes_to_txt(self, outdir): + if len(self.footnotes_idx) == 0: + return # dont save if empty + res = "" + for idx in self.footnotes_idx: + res += self.get_txt_unsplit(idx) + "\n" + + save_txt_to_file(res, remove_trailing_slash(outdir) + r"/footnotes_" + 
str(self.page_num) + r".txt") + + def preprocess_data(self): + self.remove_flyspeck() + self.remove_certain_items(".", 50) + # self.remove_certain_items('', 1) + self.remove_overlapping_items() + self.detect_split_items() + self.find_left_distributions() + self.guess_all_alignments() + self.identify_connected_txt_lines() + self.mark_regular_text() + self.mark_other_text_components() + self.find_paragraphs() + self.mark_all_tables() + self.mark_all_footnotes() + self.generate_clusters() + + # tmp = self.discover_table_column(self.items[self.find_idx_of_item_by_txt(r'Group statement of changes in equity a')]) + # print(tmp) + + """ # TODO: Remove this procedure def test(self): tmp = self.discover_table_column(self.items[self.find_idx_of_item_by_txt(r'$ per barrel')]) @@ -1252,318 +1284,333 @@ def test(self): print(tmp) print(tmp1) print(xx) - """ - - - def __repr__(self): - res = "====>>> HTMLPage : No. = " + str(self.page_num) + ", Width = " + str(self.page_width) + ", Height = " + str(self.page_height) - for i in range(len(self.items)): - res = res + "\n" +str(i)+". "+ str(self.items[i]) - - res += "\n=====>>> LIST OF TABLES:\n" - for t in self.tables: - res = res + str(t) - return res - - def repr_tables_only(self): - res = "====>>> HTMLPage : No. = " + str(self.page_num) - - res += "\n=====>>> LIST OF TABLES:\n" - for t in self.tables: - res = res + t.get_printed_repr() - return res - - - def to_json(self): - for t in self.tables: - t.items = None - - if(self.clusters is not None): - self.clusters.cleanup_for_export() - if(self.clusters_text is not None): - self.clusters_text.cleanup_for_export() - - #data = json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) - jsonpickle.set_preferred_backend('json') - jsonpickle.set_encoder_options('json', sort_keys=True, indent=4) - data = jsonpickle.encode(self) - - for t in self.tables: - t.items = self.items - - if(self.clusters is not None): - self.clusters.regenerate_not_exported(self.items) - if(self.clusters_text is not None): - self.clusters_text.regenerate_not_exported(self.items) - - - return data - - def save_to_file(self, json_file): - data = self.to_json() - f = open(json_file, "w") - f.write(data) - f.close() - - @staticmethod - def load_from_json(data): - obj = jsonpickle.decode(data) - - for t in obj.tables: - t.items = obj.items - - # regenerate clustes, if they are not available - if(obj.clusters is None or obj.clusters_text is None): - obj.generate_clusters() - else: - #just fill up clusters with missing values - obj.clusters.regenerate_not_exported(obj.items) - obj.clusters_text.regenerate_not_exported(obj.items) - - - - return obj - - @staticmethod - def load_from_file(json_file): - f = open(json_file, "r") - data = f.read() - f.close() - return HTMLPage.load_from_json(data) - - - - @staticmethod - def fix_strange_encryption(htmlfile): - #print(htmlfile) - bak_file = htmlfile + ".bak" - - new_bytes = [] - old_bytes = [] - detected_strange_char = 0 - - with open(htmlfile, "rb") as f: - b0 = f.read(1) - while b0: - b_out = 0 - c0 = int.from_bytes(b0, byteorder='big') - if(c0==195): - #strange character found - detected_strange_char += 1 - - b1 = f.read(1) - c1 = int.from_bytes(b1, byteorder='big') - chr = 0 - - if(c1 >= 95 and c1 <= 191): - chr = 191 - c1 + 33 - else: - print("BAD CHARACTER FOUND: 195+-->" + str(c1)) - exit() - - b_out = chr.to_bytes(1, 'big') - - - elif(c0==194): - #strange character found - detected_strange_char += 1 - - b1 = f.read(1) - c1 = int.from_bytes(b1, byteorder='big') - chr = 0 - - 
if(c1 >= 160 and c1 <= 255): - chr = 255 - c1 + 33 - else: - print("BAD CHARACTER FOUND: 194+-->" + str(c1)) - exit() - - b_out = chr.to_bytes(1, 'big') - else: - b_out = b0 - - nb = int.from_bytes(b_out, byteorder='big') - if(nb == c0 or (nb != 60 and nb!= 62)): #ignore < and > - new_bytes.append(nb) - else: - new_bytes.append(32) #space - - old_bytes.append(c0) - - #print(b_out) - b0 = f.read(1) - - if(detected_strange_char > 10): - if(len(old_bytes) != len(new_bytes)): - #whoops, something went really wrong - print("old_bytes doesnt match new_bytes!") - exit() - - # replace stupid "s" case and check if we really have strange case - hit_body = False - is_txt = True - good_old_bytes = 0 - for i in range(len(new_bytes)): - if(i < len(new_bytes) - 6): - # = 60,98,111,100,121,62 - if(new_bytes[i]==60 and new_bytes[i+1]==98 and new_bytes[i+2]==111 and new_bytes[i+3]==100 and new_bytes[i+4]==121 and new_bytes[i+5]==62): - hit_body = True - if(i>0 and new_bytes[i-1] == 62): - is_txt = True - if(new_bytes[i] == 60): - is_txt = False - if(hit_body and is_txt and new_bytes[i]==45): - new_bytes[i] = 115 #s - if(hit_body and is_txt and old_bytes[i]==new_bytes[i]): - good_old_bytes +=1 - - if(detected_strange_char < good_old_bytes * 2): - return # too few strange characters - - shutil.copy(htmlfile, bak_file) - with open(htmlfile, "wb") as f: - f.write(bytes(new_bytes)) - - - - - @staticmethod - def parse_html_file(fonts_dir, htmlfile): - pattern_pgnum = re.compile('.*page([0-9]+)\\.html') - pattern_background = re.compile('') - pattern_div = re.compile('
(.*)
') - pattern_span = re.compile('(.*)') - #
direct GHGs)
- - pattern_bbox = re.compile('\(([0-9]+\.[0-9]+),([0-9]+\.[0-9]+)\)-\(([0-9]+\.[0-9]+),([0-9]+\.[0-9]+)\)-->(.*)') - # - - pattern_font = re.compile('#f([0-9]+) { font-family:ff([0-9]+)[^0-9].*; }') - pattern_bold = re.compile('#f([0-9]+) { font-family:.*; font-weight:bold; font-style:.*; }') - ##f1 { font-family:ff0,sans-serif; font-weight:bold; font-style:normal; } - - pattern_font_url = re.compile(r'@font-face { font-family: ff([0-9]+); src: url\("([0-9]+\.ttf|([0-9]+)\.otf)"\); }') - #@font-face { font-family: ff24; src: url("24.ttf"); } - - - print_verbose(2, "PARSING HTML-FILE " + htmlfile) - - page_num = int(pattern_pgnum.match(htmlfile).groups()[0]) - print_verbose(4, "---> Page: " + str(page_num)) - bold_styles = [] - font_dict = {} - font_url_dict = {} - - res = HTMLPage() - - cur_item_id = 0 - - with open(htmlfile, errors='ignore', encoding=config.global_html_encoding) as f: - html_file = f.readlines() - - for i in range(0, len(html_file)): - h = html_file[i].strip() - print_verbose(7, '---->' + h) - if(pattern_background.match(h)): - bg = pattern_background.match(h).groups() - res.page_num = int(bg[2+2]) - res.page_width = int(bg[0+2]) - res.page_height = int(bg[1+2]) - if(pattern_font.match(h)): - f = pattern_font.match(h).groups() - font_dict[int(f[0])] = int(f[1]) - print_verbose(7, 'Font-> f='+str(f)) - if(pattern_font_url.match(h)): - fu = pattern_font_url.match(h).groups() - font_url_dict[int(fu[0])] = fu[1] - if(pattern_bold.match(h)): - b = pattern_bold.match(h).groups() - bold_styles.append(b[0]) - print_verbose(7, 'Bold->' + str(b)) - if(pattern_div.match(h)): - g = pattern_div.match(h).groups() - print_verbose(7, '-------->' + str(g)) - spans = g[2+1].split('
') - print_verbose(7, '-------->' + str(spans)) - item = None - item = HTMLItem() - item.line_num = i - item.tot_line_num = page_num * 10000 + i - item.this_id = cur_item_id - #item.pos_x = int(g[0+1]) - #item.pos_y = int(g[1+1]) - item.txt = '' - item.is_bold = False - #item.width = 0 - #item.height = 0 - item.font_size = 0 - item.brightness = 255 - item.page_num = page_num - space_width = 0 - for s in spans: - if(pattern_span.match(s)): - gs = pattern_span.match(s).groups() - print_verbose(7, '---------->' + str(gs)) - if(gs[0] in bold_styles): - item.is_bold = True - - #print("font_url_dict="+str(font_url_dict)) - #print("font_dict="+str(font_dict)) - - span_font = None - if(int(gs[0]) in font_dict and font_dict[int(gs[0])] in font_url_dict): - span_font = ImageFont.truetype(fonts_dir + '/' + font_url_dict[font_dict[int(gs[0])]] , int(gs[1])) - item.font_file = fonts_dir + '/' + font_url_dict[font_dict[int(gs[0])]] - else: - span_font = ImageFont.truetype(config.global_approx_font_name , int(gs[1])) - item.font_file = config.global_approx_font_name - - try: - space_width = max(space_width, get_text_width(' ', span_font), get_text_width('x', span_font)) - except: - span_font = ImageFont.truetype(config.global_approx_font_name , int(gs[1])) - item.font_file = config.global_approx_font_name - space_width = max(space_width, get_text_width(' ', span_font), get_text_width('x', span_font)) - - #text_width = get_text_width(gs[6], int(gs[1]), span_font) - #if(text_width == 0): - # print_verbose(5, "Warning! Found font with 0 for "+ str(fonts_dir + '\\' + font_url_dict[font_dict[int(gs[0])]]) + ", h="+str(int(gs[1]))+", txt='"+gs[6]) - # text_width = get_text_width("x"*max(1,len(gs[6])), int(gs[1]), span_font) #approximate - #item.width += text_width - #item.height = max(item.height, int(gs[1])) - item.font_size = max(item.font_size, int(gs[1])) - #item.txt += (' ' if item.txt != '' else '') + gs[6] if gs[6] != '' else '-' - item.brightness = min(item.brightness, (int(gs[2])+int(gs[2])+int(gs[2]))/3) - span_text = gs[6] - bboxes = span_text.split('" + str(c1)) + exit() + + b_out = chr.to_bytes(1, "big") + + elif c0 == 194: + # strange character found + detected_strange_char += 1 + + b1 = f.read(1) + c1 = int.from_bytes(b1, byteorder="big") + chr = 0 + + if c1 >= 160 and c1 <= 255: + chr = 255 - c1 + 33 + else: + print("BAD CHARACTER FOUND: 194+-->" + str(c1)) + exit() + + b_out = chr.to_bytes(1, "big") + else: + b_out = b0 + + nb = int.from_bytes(b_out, byteorder="big") + if nb == c0 or (nb != 60 and nb != 62): # ignore < and > + new_bytes.append(nb) + else: + new_bytes.append(32) # space + + old_bytes.append(c0) + + # print(b_out) + b0 = f.read(1) + + if detected_strange_char > 10: + if len(old_bytes) != len(new_bytes): + # whoops, something went really wrong + print("old_bytes doesnt match new_bytes!") + exit() + + # replace stupid "s" case and check if we really have strange case + hit_body = False + is_txt = True + good_old_bytes = 0 + for i in range(len(new_bytes)): + if i < len(new_bytes) - 6: + # = 60,98,111,100,121,62 + if ( + new_bytes[i] == 60 + and new_bytes[i + 1] == 98 + and new_bytes[i + 2] == 111 + and new_bytes[i + 3] == 100 + and new_bytes[i + 4] == 121 + and new_bytes[i + 5] == 62 + ): + hit_body = True + if i > 0 and new_bytes[i - 1] == 62: + is_txt = True + if new_bytes[i] == 60: + is_txt = False + if hit_body and is_txt and new_bytes[i] == 45: + new_bytes[i] = 115 # s + if hit_body and is_txt and old_bytes[i] == new_bytes[i]: + good_old_bytes += 1 + + if detected_strange_char 
< good_old_bytes * 2: + return # too few strange characters + + shutil.copy(htmlfile, bak_file) + with open(htmlfile, "wb") as f: + f.write(bytes(new_bytes)) + + @staticmethod + def parse_html_file(fonts_dir, htmlfile): + pattern_pgnum = re.compile(".*page([0-9]+)\\.html") + pattern_background = re.compile( + '' + ) + pattern_div = re.compile( + '
(.*)
' + ) + pattern_span = re.compile( + '(.*)' + ) + #
direct GHGs)
+ + pattern_bbox = re.compile("\(([0-9]+\.[0-9]+),([0-9]+\.[0-9]+)\)-\(([0-9]+\.[0-9]+),([0-9]+\.[0-9]+)\)-->(.*)") + # + + pattern_font = re.compile("#f([0-9]+) { font-family:ff([0-9]+)[^0-9].*; }") + pattern_bold = re.compile("#f([0-9]+) { font-family:.*; font-weight:bold; font-style:.*; }") + ##f1 { font-family:ff0,sans-serif; font-weight:bold; font-style:normal; } + + pattern_font_url = re.compile( + r'@font-face { font-family: ff([0-9]+); src: url\("([0-9]+\.ttf|([0-9]+)\.otf)"\); }' + ) + # @font-face { font-family: ff24; src: url("24.ttf"); } + + print_verbose(2, "PARSING HTML-FILE " + htmlfile) + + page_num = int(pattern_pgnum.match(htmlfile).groups()[0]) + print_verbose(4, "---> Page: " + str(page_num)) + bold_styles = [] + font_dict = {} + font_url_dict = {} + + res = HTMLPage() + + cur_item_id = 0 + + with open(htmlfile, errors="ignore", encoding=config.global_html_encoding) as f: + html_file = f.readlines() + + for i in range(0, len(html_file)): + h = html_file[i].strip() + print_verbose(7, "---->" + h) + if pattern_background.match(h): + bg = pattern_background.match(h).groups() + res.page_num = int(bg[2 + 2]) + res.page_width = int(bg[0 + 2]) + res.page_height = int(bg[1 + 2]) + if pattern_font.match(h): + f = pattern_font.match(h).groups() + font_dict[int(f[0])] = int(f[1]) + print_verbose(7, "Font-> f=" + str(f)) + if pattern_font_url.match(h): + fu = pattern_font_url.match(h).groups() + font_url_dict[int(fu[0])] = fu[1] + if pattern_bold.match(h): + b = pattern_bold.match(h).groups() + bold_styles.append(b[0]) + print_verbose(7, "Bold->" + str(b)) + if pattern_div.match(h): + g = pattern_div.match(h).groups() + print_verbose(7, "-------->" + str(g)) + spans = g[2 + 1].split("
") + print_verbose(7, "-------->" + str(spans)) + item = None + item = HTMLItem() + item.line_num = i + item.tot_line_num = page_num * 10000 + i + item.this_id = cur_item_id + # item.pos_x = int(g[0+1]) + # item.pos_y = int(g[1+1]) + item.txt = "" + item.is_bold = False + # item.width = 0 + # item.height = 0 + item.font_size = 0 + item.brightness = 255 + item.page_num = page_num + space_width = 0 + for s in spans: + if pattern_span.match(s): + gs = pattern_span.match(s).groups() + print_verbose(7, "---------->" + str(gs)) + if gs[0] in bold_styles: + item.is_bold = True + + # print("font_url_dict="+str(font_url_dict)) + # print("font_dict="+str(font_dict)) + + span_font = None + if int(gs[0]) in font_dict and font_dict[int(gs[0])] in font_url_dict: + span_font = ImageFont.truetype( + fonts_dir + "/" + font_url_dict[font_dict[int(gs[0])]], int(gs[1]) + ) + item.font_file = fonts_dir + "/" + font_url_dict[font_dict[int(gs[0])]] + else: + span_font = ImageFont.truetype(config.global_approx_font_name, int(gs[1])) + item.font_file = config.global_approx_font_name + + try: + space_width = max( + space_width, get_text_width(" ", span_font), get_text_width("x", span_font) + ) + except: + span_font = ImageFont.truetype(config.global_approx_font_name, int(gs[1])) + item.font_file = config.global_approx_font_name + space_width = max( + space_width, get_text_width(" ", span_font), get_text_width("x", span_font) + ) + + # text_width = get_text_width(gs[6], int(gs[1]), span_font) + # if(text_width == 0): + # print_verbose(5, "Warning! Found font with 0 for "+ str(fonts_dir + '\\' + font_url_dict[font_dict[int(gs[0])]]) + ", h="+str(int(gs[1]))+", txt='"+gs[6]) + # text_width = get_text_width("x"*max(1,len(gs[6])), int(gs[1]), span_font) #approximate + # item.width += text_width + # item.height = max(item.height, int(gs[1])) + item.font_size = max(item.font_size, int(gs[1])) + # item.txt += (' ' if item.txt != '' else '') + gs[6] if gs[6] != '' else '-' + item.brightness = min(item.brightness, (int(gs[2]) + int(gs[2]) + int(gs[2])) / 3) + span_text = gs[6] + bboxes = span_text.split(" Can't merge rows " + + str(r0) + + " and next one, because c0=" + + str(c0) + + ", c1=" + + str(c1) + + " would overlap", + ) + print_verbose( + 7, "These are the items " + str(self.get_item(r0, c0)) + ", and " + str(self.get_item(r1, c1)) + ) + return False # two columns would overlaping in the same row + if x1_0 < x1_1: + c0 += 1 + else: + c1 += 1 + elif self.has_item_at(r0, c0): + c1 += 1 + else: + c0 += 1 + + print_verbose(7, "----> Merge r0=" + str(r0) + " where y0_max,y1_min = " + str(y0_max) + "," + str(y1_min)) + return True + + def is_col_mergable(self, c0): # return True iff col c0 and c0+1 can be merged + if c0 < 0 or c0 >= self.num_cols - 1: + raise ValueError("Cols c0=" + str(c0) + " and c0 out of range") + + for i in range(self.num_rows): + if self.has_item_at(i, c0) and self.has_item_at(i, c0 + 1): + return False + + return True + + def merge_rows(self, r0, reconnect=False): # merge rows r0 and r0+1 + if r0 < 0 or r0 >= self.num_rows - 1: + raise ValueError("Rows r0=" + str(r0) + " and r0+1 out of range.") + + for j in range(self.num_cols): + ix0 = self.get_ix(r0, j) + ix1 = self.get_ix(r0 + 1, j) + if self.idx[ix0] == -1: + self.idx[ix0] = self.idx[ix1] + elif self.idx[ix1] == -1: + # self.idx[ix0] = ix0 + pass + else: + if reconnect: + # print("\n\n\n\nRECONNECT: "+str(self.items[self.idx[ix0]].this_id) + " <-> " + str(self.items[self.idx[ix1]].this_id))#zz + 
self.items[self.idx[ix0]].reconnect(self.items[self.idx[ix1]], self.items) + self.items[self.idx[ix0]].merge(self.items[self.idx[ix1]]) + # if(reconnect):#zz + # print("After reconnect: " + str(self.items))#zz + self.delete_rows(r0 + 1, r0 + 2, False) + + def merge_cols(self, c0): # merge cols c0 and c0+1 + if c0 < 0 or c0 >= self.num_cols - 1: + raise ValueError("Cols c0=" + str(c0) + " and c0 out of range") + + for i in range(self.num_rows): + ix0 = self.get_ix(i, c0) + ix1 = self.get_ix(i, c0 + 1) + if self.idx[ix0] == -1: + self.idx[ix0] = self.idx[ix1] + + self.delete_cols(c0 + 1, c0 + 2, False) + + def merge_down_all_rows(self): + # print_subset(0, self.items, self.idx) + for i in range(self.num_rows - 1): + if self.is_row_mergable(i): + self.merge_rows(i) + self.merge_down_all_rows() + return + + def merge_down_all_cols(self): + for j in range(self.num_cols - 1): + if self.is_col_mergable(j): + self.merge_cols(j) + self.merge_down_all_cols() + return + + def is_empty_row(self, r): + for j in range(self.num_cols): + if self.has_item_at(r, j): + return False + return True + + def is_empty_col(self, c): + for i in range(self.num_rows): + if self.has_item_at(i, c): # and self.get_item(i,c).txt != ''): + # print('HAS ITEM AT '+str(i)+','+str(c)+', namely: '+str(self.get_item(i,c))) + return False + return True + + def compactify(self): # remove all empty rows and cols + has_changed = True + while has_changed: + has_changed = False + # rows + for i in range(self.num_rows - 1, -1, -1): + if self.is_empty_row(i): + print_verbose(7, "Delete empty row : " + str(i)) + has_changed = True + self.delete_rows(i, i + 1, False) + # cols + for j in range(self.num_cols - 1, -1, -1): + if self.is_empty_col(j): + print_verbose(7, "Delete empty column : " + str(j)) + has_changed = True + self.delete_cols(j, j + 1, False) + + self.recalc_geometry() + + def throw_away_non_connected_rows(self): # throw away rows that are probably not connected to the table + def is_connected_row(r0): + min_y0 = 9999999 + max_y0 = -1 + min_y1 = 9999999 + for j in range(self.num_cols): + if self.has_item_at(r0, j): + min_y0 = min(min_y0, self.get_item(r0, j).pos_y) + max_y0 = max(max_y0, self.get_item(r0, j).pos_y + self.get_item(r0, j).height) + if self.has_item_at(r0 + 1, j): + min_y1 = min(min_y1, self.get_item(r0 + 1, j).pos_y) + + if min_y1 == 9999999 or min_y0 == 9999999: + return True # at least one row was empty => consider always as connected + + y_limit = ( + min_y0 + (max_y0 - min_y0) * config.global_row_connection_threshold + ) # at least space that woulld occupy 4x row r0 are (not) empty #New: 27.06.2022 (was previously: * 4) + + if min_y1 <= y_limit: + return True + + sp_idx = self.find_special_item_idx_in_rect(self.rows[r0]) + + for k in sp_idx: + cur_y = self.items[k].pos_y + if cur_y <= y_limit: + return True + + return False + + for i in range(self.num_rows - 1): + if not is_connected_row(i): + print_verbose(5, "Throw away non-connected rows after /excl. 
:" + str(i)) + print_verbose(5, "Current table: " + str(self)) + self.delete_rows(i + 1, self.num_rows) + return + + def throw_away_rows_after_new_header(self): + if self.num_cols < 2 or self.num_rows < 3: + return + + num_numeric_rows = 0 + num_rows_with_left_txt = 0 + last_pos_y = 9999999 + last_delta_y = 9999999 + + for i in range(self.num_rows): + if self.has_item_at(i, 0) and Format_Analyzer.looks_words(self.get_item(i, 0).txt): + num_rows_with_left_txt += 1 + cur_numeric_values = 0 + cur_header_values = 0 + cur_other_values = 0 + cur_pos_y = 9999999 + for j in range(1, self.num_cols): + if self.has_item_at(i, j): + if cur_pos_y == 9999999: + cur_pos_y = self.get_item(i, j).pos_y + txt = self.get_item(i, j).txt + if Format_Analyzer.looks_numeric(txt) and not Format_Analyzer.looks_year(txt): + cur_numeric_values += 1 + elif (Format_Analyzer.looks_words(txt) and txt[0].isupper()) or Format_Analyzer.looks_year(txt): + cur_header_values += 1 + else: + cur_other_values += 1 + + if cur_numeric_values > max(self.num_cols * 0.6, 1): + num_numeric_rows += 1 + + cur_delta_y = cur_pos_y - last_pos_y + + if num_rows_with_left_txt > 2 and num_numeric_rows > 2 and cur_delta_y > last_delta_y * 1.05 + 2: + if cur_numeric_values == 0 and cur_header_values > 0 and cur_other_values == 0: + print_verbose( + 5, + "Throw away non-connected rows after probably new headline at row = " + + str(i) + + ", cur/last_delta_y=" + + str(cur_delta_y) + + "/" + + str(last_delta_y), + ) + self.delete_rows(i, self.num_rows) + return + + last_pos_y = cur_pos_y + last_delta_y = cur_delta_y + + def throw_away_last_headline(self): + # a headline at the end of a table probably doesnt belong to it, rather, it belongs to the next table + if self.num_rows < 2 or self.num_cols < 2: + return + + if not self.has_non_empty_item_at(self.num_rows - 1, 0): + return + + if Format_Analyzer.looks_numeric(self.get_item(self.num_rows - 1, 0).txt): + return + + for j in range(1, self.num_cols): + if self.has_non_empty_item_at(self.num_rows - 1, j): + return + + self.delete_rows(self.num_rows - 1, self.num_rows) + + def throw_away_non_connected_cols(self, page_width): # throw away cols that are probably not connected to the table + def is_connected_col(c0): + min_x0 = 9999999 + max_x0 = -1 + min_x1 = 9999999 + for i in range(self.num_rows): + if self.has_item_at(i, c0): + min_x0 = min(min_x0, self.get_item(i, c0).pos_x) + max_x0 = max(max_x0, self.get_item(i, c0).pos_x + self.get_item(i, c0).width) + if self.has_item_at(i, c0 + 1): + min_x1 = min(min_x1, self.get_item(i, c0 + 1).pos_x) + + if min_x1 == 9999999 or min_x0 == 9999999: + return True # at least one col was empty => consider always as connected + + if min_x1 <= max_x0 + DEFAULT_HTHROWAWAY_DIST * page_width: # at least that much space is not empty + return True + + num_reg_text = 0 + for i in range(self.num_rows): + if not self.has_item_at(i, c0): + continue + y0 = self.get_item(i, c0).pos_y + y1 = self.get_item(i, c0).pos_y + self.get_item(i, c0).height + x0 = max((min_x0 + max_x0) / 2.0, self.get_item(i, c0).pos_x) + x1 = min_x1 + cur_rect = Rect(x0, y0, x1, y1) + for it in self.items: + if it.category not in [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_RUNNING_TEXT, CAT_FOOTER]: + continue + it_rect = it.get_rect() + if Rect.calc_intersection_area(cur_rect, it_rect) > 0: + print_verbose( + 2, "----->> With " + str(self.get_item(i, c0)) + " the item " + str(it) + " overlaps" + ) + num_reg_text += 1 + + # num_reg_text = self.count_regular_items_in_rect(Rect( (min_x0 + max_x0)/2.0, 
self.table_rect.y0, min_x1, self.table_rect.y1) ) + + return num_reg_text == 0 + + for j in range(self.num_cols - 1): + if not is_connected_col(j): + print_verbose(5, "Throw away non-connected cols after /excl. :" + str(j)) + print_verbose(5, "Current table: " + str(self)) + self.delete_cols(j + 1, self.num_cols) + return + + def throw_away_cols_at_next_paragraph(self, paragraphs): + def find_cur_paragraph_idx(x0, x1, my_paragraphs): + max_overlap = 0 + res = -1 + for i in range(len(my_paragraphs)): + p0 = my_paragraphs[i] + p1 = my_paragraphs[i + 1] if i < len(my_paragraphs) - 1 else 9999999 + overlap = max(min(x1, p1) - max(x0, p0), 0) + if overlap > max_overlap: + max_overlap = overlap + res = i + return res + + if self.num_cols == 0 or len(paragraphs) < 2: + return + + self.recalc_geometry() + + # Find relevant paragraphs + my_paragraphs = [] + for p in paragraphs: + cnt = 0 + for it in self.items: + if ( + it.pos_x == p + and it.category in (CAT_RUNNING_TEXT, CAT_HEADLINE) + and it.pos_x < self.table_rect.x1 + and it.pos_x + it.width > self.table_rect.x0 + ): + if ( + it.pos_y < self.table_rect.y1 and it.pos_y + it.height * 5 > self.table_rect.y0 + ) or ( # TODO 5 was 10, test it + it.pos_y > self.table_rect.y0 and it.pos_y - it.height * 5 < self.table_rect.y1 + ): # TODO 5 was 10, test it + cnt += 1 + # print(it) + if cnt > 3: + my_paragraphs.append(p) + + my_paragraphs.sort() + + print_verbose(5, "Relevant paragraphs:" + str(my_paragraphs)) + + last_para_idx = -1 + + for j in range(self.num_cols): + x0 = 9999999 + x1 = -1 + for i in range(self.num_rows): + if self.has_item_at(i, j): + x0 = min(x0, self.get_item(i, j).pos_x) + x1 = max(x1, self.get_item(i, j).pos_x + self.get_item(i, j).width) + if x1 == -1: + continue # b/c empty col + + # print(j, x0, x1, my_paragraphs) + cur_para_idx = find_cur_paragraph_idx(x0, x1, my_paragraphs) + print_verbose( + 7, + "--> Col =" + + str(j) + + ", x0/x1=" + + str(self.cols[j].x0) + + "/" + + str(self.cols[j].x1) + + " belong to paragraph p_idx = " + + str(cur_para_idx) + + ", which is at " + + str(my_paragraphs[cur_para_idx] if cur_para_idx != -1 else None) + + " px", + ) + if j > 1 and cur_para_idx != last_para_idx: # TODO 1 was 0, test is + # table is here probably split between two text paragraphs + print_verbose(5, "Throw away cols at next paragraph: col at next j=" + str(j)) + self.delete_cols(j, self.num_cols) + return + last_para_idx = cur_para_idx + + def throw_away_cols_after_year_list(self): # throw away cols after a list of years is over + class YearCols: + r = None + c0 = None + c1 = None + + def __init__(self, r, c0): + self.r = r + self.c0 = c0 + + def __repr__(self): + return "(r=" + str(self.r) + ",c0=" + str(self.c0) + ",c1=" + str(self.c1) + ")" + + if self.num_cols < 5: + return + + year_cols = [] + + # print(str(self)) + # print("Cols:"+str(self.num_cols)) + + for i in range(self.num_rows): + j = 0 + while j < self.num_cols - 1: + if ( + self.has_item_at(i, j) + and self.has_item_at(i, j + 1) + and Format_Analyzer.looks_year(self.get_item(i, j).txt) + and Format_Analyzer.looks_year(self.get_item(i, j + 1).txt) + and abs( + Format_Analyzer.to_year(self.get_item(i, j + 1).txt) + - Format_Analyzer.to_year(self.get_item(i, j).txt) + ) + == 1 + ): + dir = Format_Analyzer.to_year(self.get_item(i, j + 1).txt) - Format_Analyzer.to_year( + self.get_item(i, j).txt + ) + cur_year_cols = YearCols(i, j) + # find last year col + # print("Now at cell:"+str(i)+","+str(j)) + for j1 in range(j + 1, self.num_cols): + if ( + self.has_item_at(i, 
j1) + and Format_Analyzer.looks_year(self.get_item(i, j1).txt) + and Format_Analyzer.to_year(self.get_item(i, j1).txt) + - Format_Analyzer.to_year(self.get_item(i, j1 - 1).txt) + == dir + ): + cur_year_cols.c1 = j1 + else: + break + year_cols.append(cur_year_cols) + j = cur_year_cols.c1 + j += 1 + + print_verbose(6, "----->> Found year lists at: " + str(year_cols)) + + if len(year_cols) < 2: + return + + # test if we can throw away + + # find min + + min_yc = -1 + for k in range(len(year_cols)): + if min_yc == -1 or (year_cols[k].c0 < year_cols[min_yc].c0): + min_yc = k + + print_verbose(6, "------->> min:" + str(year_cols[min_yc]) + " at idx " + str(min_yc)) + + # find overlapping max + max_overlap_yc = -1 + for k in range(len(year_cols)): + if year_cols[k].c0 < year_cols[min_yc].c1: + if max_overlap_yc == -1 or (year_cols[k].c1 > year_cols[max_overlap_yc].c1): + max_overlap_yc = k + + print_verbose(6, "------->> max overlap:" + str(year_cols[max_overlap_yc]) + " at idx " + str(max_overlap_yc)) + + # are there any year cols after max overlap? if so, we can throw that part away + can_throw_away = False + for yc in year_cols: + if yc.c0 > year_cols[max_overlap_yc].c1: + print_verbose(6, "-------->> throw away because : " + str(yc)) + can_throw_away = True + + if can_throw_away: + self.delete_cols(year_cols[max_overlap_yc].c1 + 1, self.num_cols) + + def throw_away_duplicate_looking_cols( + self, + ): # throw away columns that are looking like duplicates, and indicating another table + def are_cols_similar(c0, c1): + return ( + self.col_looks_like_text_col(c0) + and self.col_looks_like_text_col(c1) + and Format_Analyzer.cnt_overlapping_items(self.get_all_cols_as_text(c0), self.get_all_cols_as_text(c1)) + > 3 + ) + + if self.num_cols < 3: + return + + for i in range(2, self.num_cols): + if are_cols_similar(0, i): + print_verbose(7, "------->> cols 0 and " + str(i) + " are similar. 
Throw away from " + str(i)) + self.delete_cols(i, self.num_cols) + return + + def identify_headline(self): + if self.num_rows == 0 or self.num_cols == 0: + return + + if not self.has_item_at(0, 0) or not Format_Analyzer.looks_words(self.get_item(0, 0).txt): + return + + for j in range(1, self.num_cols): + if self.has_item_at(0, j): + return + + self.headline_idx.append(self.get_idx(0, 0)) + self.delete_rows(0, 1) + + def identify_non_numeric_special_items(self): + def col_looks_numeric(c0): + num_numbers = 0 + num_words = 0 + for i in range(self.num_rows): + if self.has_item_at(i, c0): + txt = self.get_item(i, c0).txt + if Format_Analyzer.looks_numeric(txt): + num_numbers += 1 + elif Format_Analyzer.looks_words(txt): + num_words += 1 + return num_numbers >= 3 and num_words < num_numbers * 0.4 + + def is_only_item_in_row(r0, c0): + for j in range(c0): + if self.has_item_at(r0, j): + return False + for j in range(c0 + 1, self.num_cols): + if self.has_item_at(r0, j): + return False + return True + + for j in range(self.num_cols): + if col_looks_numeric(j): + print_verbose(5, "Numeric col found : " + str(j)) + # find first possible special item of this col + r0 = self.find_first_non_empty_row_in_col(j) + r1 = r0 + font_char = self.get_item(r0, j).get_font_characteristics() + for i in range(r0, self.num_rows): + r1 = i + if not self.has_item_at(i, j): + r1 = i + 1 + break # empty line + if Format_Analyzer.looks_numeric(self.get_item(i, j).txt): + r1 = i + 1 + break # number + if self.get_item(i, j).get_font_characteristics() != font_char: + break # different font + + # remove now special items (but not before first occurence) + for i in range(r1, self.num_rows): + if self.has_item_at(i, j): + txt = self.get_item(i, j).txt + if Format_Analyzer.looks_words(txt) or Format_Analyzer.looks_other_special_item(txt): + cur_idx = self.get_idx(i, j) + self.idx[self.get_ix(i, j)] = -1 + self.special_idx.append(cur_idx) + + # remove even special items from headline, but only if there are no other headline items + if ( + self.has_item_at(r0, j) + and is_only_item_in_row(r0, j) + and ( + Format_Analyzer.looks_words(self.get_item(r0, j).txt) + or Format_Analyzer.looks_other_special_item(self.get_item(r0, j).txt) + ) + ): + cur_idx = self.get_idx(r0, j) + self.idx[self.get_ix(r0, j)] = -1 + self.special_idx.append(cur_idx) + """ + if(self.has_item_at(0, j) and is_only_item_in_row(0, j) and (Format_Analyzer.looks_words(self.get_item(0, j).txt) or Format_Analyzer.looks_other_special_item(self.get_item(0, j).txt))): + cur_idx = self.get_idx(0, j) + self.idx[self.get_ix(0,j)] = -1 + self.special_idx.append(cur_idx) + """ -class HTMLTable: - # refers to a table extracted by the HTML-files, to which PDFs have been converted - # this is not related with an acutal table in HTML syntax - - rows = None - cols = None - idx = None - num_rows = None - num_cols = None - items = None #dont export - marks = None - col_aligned_pos_x = None - headline_idx = None - special_idx = None # e.g., annonations - table_rect = None - - - def __init__(self): - self.rows = [] - self.cols = [] - self.idx = [] - self.items = [] - self.num_rows = 0 - self.num_cols = 0 - self.marks = [] - self.col_aligned_pos_x = [] - self.headline_idx = [] - self.special_idx = [] - self.table_rect = Rect(99999,99999,-1,-1) - - - def get_ix(self, i, j): # i=row, j=col - return i * self.num_cols + j - - def get_row_and_col_by_ix(self, ix): - return ix//self.num_cols, ix%self.num_cols - - - def get_idx(self, i, j): # i=row, j=col - return self.idx[i * 
self.num_cols + j] - - def get_item(self, i, j): # i=row, j=col - ix = self.get_idx(i, j) - return self.items[ix] if ix >= 0 else None - - def get_item_by_ix(self, i): # i=ix - if(i < 0): - return None - ix = self.idx[i] - return self.items[self.idx[i]] if ix >= 0 else None - - def has_item_at_ix(self, i): # i=ix - if(i < 0): - return False - return self.idx[i] >= 0 - - def has_item_at(self, i, j): # i=row, j=col - return self.idx[i * self.num_cols + j] >= 0 - - def has_non_empty_item_at(self, i, j): - return self.has_item_at(i, j) and self.get_item(i, j).txt != '' - - def count_marks(self, mark): - return self.marks.count(mark) - - def reset_marks(self): - for i in range(len(self.marks)): - self.marks[i] = 0 if self.idx[i] >= 0 else 9999999 - - def set_temp_assignment(self, value=1): - for i in self.idx: - self.items[i].temp_assignment = value - - def count_actual_items(self): - return len(self.idx) - self.idx.count(-1) - - def get_all_idx(self): - return self.idx + self.headline_idx + self.special_idx - - def find_applying_special_item_ix(self, r0): - # precondition : special_idx must be sorted pos_y ascending - for i in range(len(self.special_idx)-1, -1, -1): - if(self.rows[r0].y0 + self.items[self.special_idx[i]].height * 0.2 >= self.items[self.special_idx[i]].pos_y): - return i - return None - - def find_applying_special_item(self, r0): - ix = self.find_applying_special_item_ix(r0) - return None if ix is None else self.items[self.special_idx[ix]] - - - def find_special_item_idx_in_rect(self, rect): - res = [] - for i in self.special_idx: - if(Rect.calc_intersection_area(self.items[i].get_rect(), rect) / self.items[i].get_rect().get_area() > 0.1): - res.append(i) - return res - - def find_nearest_cell_ix(self, item): - pos_x = item.get_aligned_pos_x() - pos_y = item.pos_y + item.height - r0 = self.num_rows-1 - c0 = self.num_cols-1 - for c in range(len(self.cols)): - if(self.cols[c].x0 > pos_x): - c0 = max(0, c-1) - break - - for r in range(len(self.rows)): - if(self.rows[r].y0 > pos_y): - r0 = max(0, r-1) - break - - return self.get_ix(r0, c0) - - def count_regular_items_in_rect(self, rect): - res = 0 - for it in self.items: - if(it.category in [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_RUNNING_TEXT, CAT_FOOTER] and Rect.calc_intersection_area(it.get_rect(), rect) / it.get_rect().get_area() > 0.1 ): - res += 1 - return res - - def unfold_idx_to_items(self, idx_list): - res = [] - for i in idx_list: - res.append(self.items[i]) - return res - - def unfold_ix_to_idx(self, ix_list): - res = [] - for i in ix_list: - res.append(self.idx[i]) - return res - - - def find_applying_special_item(self, r0): - # precondition : special_idx must be sorted pos_y ascending - ix = self.find_applying_special_item_ix(r0) - return self.items[self.special_idx[ix]] if ix is not None else None - - def find_first_non_empty_row_in_col(self, c0): - for i in range(self.num_rows): - if(self.has_item_at(i, c0)): - return i - return self.num_rows - - - - def col_looks_like_text_col(self, c0): - num_numbers = 0 - num_words = 0 - for i in range(self.num_rows): - if(self.has_item_at(i, c0)): - txt = self.get_item(i, c0).txt - #print(txt) - if(Format_Analyzer.looks_numeric(txt)): - num_numbers += 1 - #print('.. looks numeric') - elif(Format_Analyzer.looks_words(txt)): - num_words += 1 - #print('.. 
looks words') - return num_words >= 5 and num_words > num_numbers * 0.3 - - def get_all_cols_as_text(self, c0): - res = [] - for i in range(self.num_rows): - if(self.has_item_at(i, c0)): - txt = self.get_item(i, c0).txt - res.append(txt) - return res - - - - def recalc_geometry(self): - - - self.rows = [] - self.cols = [] - self.col_aligned_pos_x = [] - - table_rect = Rect(9999999, 9999999, -1, -1) - - # Calc table rect - for i in range(self.num_rows): - for j in range(self.num_cols): - if(self.has_item_at(i, j)): - table_rect.grow(self.get_item(i, j).get_rect()) - - self.table_rect = table_rect - - # Calc row rects - for i in range(self.num_rows): - row_rect = Rect(table_rect.x0, 9999999, table_rect.x1, table_rect.y1) - for j in range(self.num_cols): - if(self.has_item_at(i, j)): - row_rect.y0 = min(row_rect.y0, self.get_item(i, j).pos_y) - self.rows.append(row_rect) - if(i>0): - self.rows[i-1].y1 = max(row_rect.y0, self.rows[i-1].y0+1) - - # Calc column rects - for j in range(self.num_cols): - col_rect = Rect(9999999, table_rect.y0, table_rect.x1, table_rect.y1) - for i in range(self.num_rows): - if(self.has_item_at(i, j)): - col_rect.x0 = min(col_rect.x0, self.get_item(i, j).pos_x) - self.cols.append(col_rect) - if(j>0): - self.cols[j-1].x1 = max(col_rect.x0, self.cols[j-1].x0+1) - - # Calc column aligned positions - for j in range(self.num_cols): - # find main alignment - ctr_left = 0 - ctr_right = 0 - ctr_center = 0 # New-27.06.2022 - for i in range(self.num_rows): - if(self.has_item_at(i,j)): - if(self.get_item(i,j).alignment == ALIGN_LEFT): - ctr_left += 1 - elif(self.get_item(i,j).alignment == ALIGN_RIGHT): - ctr_right += 1 - else: - ctr_center += 1 - - - #col_alignment = ALIGN_LEFT if ctr_left >= ctr_right else ALIGN_RIGHT - col_alignment = ALIGN_LEFT #default - if(ctr_left >= ctr_right and ctr_left >= ctr_center): - col_alignment = ALIGN_LEFT - elif(ctr_right >= ctr_left and ctr_right >= ctr_center): - col_alignment = ALIGN_RIGHT - else: - col_alignment = ALIGN_CENTER - - - sum_align_pos_x = 0 - ctr = 0 - - for i in range(self.num_rows): - if(self.has_item_at(i, j)): - # New-27.06.2022 - #sum_align_pos_x += self.get_item(i, j).pos_x if col_alignment == ALIGN_LEFT else (self.get_item(i, j).pos_x + self.get_item(i, j).width) - to_add = 0 - if(col_alignment == ALIGN_LEFT): - to_add = self.get_item(i, j).pos_x - if(col_alignment == ALIGN_RIGHT): - to_add = self.get_item(i, j).pos_x + self.get_item(i, j).width - if(col_alignment == ALIGN_CENTER): - to_add = self.get_item(i, j).pos_x + self.get_item(i, j).width * 0.5 - sum_align_pos_x += to_add - - ctr += 1 - - self.col_aligned_pos_x.append(sum_align_pos_x / ctr if ctr > 0 else -1) - - - def insert_row(self, r0): # row will be inserted AFTER r0 - if(r0 < -1 or r0 > self.num_rows): - raise ValueError('Rows r0='+str(r0)+' out of range.') - - c = self.num_cols - self.idx = self.idx[0:c*(r0+1)] + ([-1] * c) + self.idx[c*(r0+1):] - self.marks = self.marks[0:c*(r0+1)] + ([9999999] * c) + self.marks[c*(r0+1):] - self.num_rows += 1 - - def is_row_insertion_possible(self, r0, pos_y): # can we insert row, starting at pos_y, right below r0? 
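# Illustrative aside (not part of the patch): the bookkeeping above stores the
# table as one flat, row-major list of indices into self.items, with -1 marking
# an empty cell. A minimal standalone model of that layout (names simplified):
class Grid:
    def __init__(self, num_rows, num_cols):
        self.num_rows, self.num_cols = num_rows, num_cols
        self.idx = [-1] * (num_rows * num_cols)     # flat, row-major cell index

    def get_ix(self, i, j):                         # (row, col) -> flat position
        return i * self.num_cols + j

    def get_row_and_col_by_ix(self, ix):            # flat position -> (row, col)
        return ix // self.num_cols, ix % self.num_cols

    def has_item_at(self, i, j):
        return self.idx[self.get_ix(i, j)] >= 0

g = Grid(2, 3)
g.idx[g.get_ix(1, 2)] = 7                           # put item 7 into row 1, col 2
assert g.has_item_at(1, 2) and g.get_row_and_col_by_ix(5) == (1, 2)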
- - y1 = self.rows[r0].y0 - for j in range(self.num_cols): - if(self.has_item_at(r0, j)): - y1 = max(y1, self.get_item(r0, j).get_rect().y1) - - return pos_y > y1 - - def delete_rows(self, r0, r1, do_recalc_geometry = True): #delete rows r, where r0 <= r < r1 - if(r0 < 0 or r0 > self.num_rows): - raise ValueError('Rows r0='+str(r0)+' out of range.') - if(r1 < 0 or r1 > self.num_rows): - raise ValueError('Rows r1='+str(r1)+' out of range.') - if(r1 <= r0): - raise ValueError('Rows r1='+str(r1)+' <= r0='+str(r0)) - - c = self.num_cols - self.idx = self.idx[0:c*r0] + self.idx[c*r1:] - self.marks = self.marks[0:c*r0] + self.marks[c*r1:] - self.num_rows -= (r1-r0) - if(do_recalc_geometry): - self.recalc_geometry() - - def delete_cols(self, c0, c1, do_recalc_geometry = True): #delete cols c, where c0 <= c < c1 - if(c0 < 0 or c0 > self.num_cols): - raise ValueError('Cols c0='+str(c0)+' out of range.') - if(c1 < 0 or c1 > self.num_cols): - raise ValueError('Cols c1='+str(c1)+' out of range.') - if(c1 <= c0): - raise ValueError('Cols c1='+str(c1)+' <= c0='+str(c0)) - - c = self.num_cols - tmp_idx = [] - tmp_marks = [] - for i in range(self.num_rows): - tmp_idx += self.idx[i*c:i*c+c0] + self.idx[i*c+c1:(i+1)*c] - tmp_marks += self.marks[i*c:i*c+c0] + self.marks[i*c+c1:(i+1)*c] - - self.idx = tmp_idx - self.marks = tmp_marks - self.num_cols -= (c1-c0) - if(do_recalc_geometry): - self.recalc_geometry() - - def is_row_mergable(self, r0): # return True iff row r0 and r0+1 can be merged - if(r0 < 0 or r0 >= self.num_rows - 1): - raise ValueError('Rows r0='+str(r0)+' and r0+1 out of range.') - - y0_max = 0 - y1_min = 9999999 - for j in range(self.num_cols): - if(self.has_item_at(r0, j) and self.has_item_at(r0+1, j) and not self.get_item(r0,j).is_mergable(self.get_item(r0+1,j))): - return False - cur_y0 = self.get_item(r0, j).pos_y + self.get_item(r0, j).height if self.has_item_at(r0, j) else 0 - cur_y1 = self.get_item(r0+1, j).pos_y if self.has_item_at(r0+1, j) else 9999999 - y0_max = max(y0_max, cur_y0) - y1_min = min(y1_min, cur_y1) - - - if(y0_max < y1_min): - return False # rows would not overlap - - #rows are overlapping, now check for items what would block merging - - c0 = 0 - c1 = 0 - r1 = r0+1 - while(c0 < self.num_cols and c1 < self.num_cols): - if(self.has_item_at(r0, c0) and self.has_item_at(r1, c1)): - x0_0 = self.get_item(r0, c0).pos_x - x1_0 = self.get_item(r0, c0).pos_x+self.get_item(r0, c0).width - x0_1 = self.get_item(r1, c1).pos_x - x1_1 = self.get_item(r1, c1).pos_x+self.get_item(r1, c1).width - if(min(x1_0, x1_1) - max(x0_0, x0_1) >= 0): - print_verbose(7, "-----> Can't merge rows "+str(r0)+" and next one, because c0="+str(c0)+", c1="+str(c1)+" would overlap") - print_verbose(7, "These are the items "+str(self.get_item(r0,c0))+", and "+str(self.get_item(r1,c1))) - return False # two columns would overlaping in the same row - if(x1_0 < x1_1): - c0 += 1 - else: - c1 += 1 - elif(self.has_item_at(r0, c0)): - c1 += 1 - else: - c0 += 1 - - print_verbose(7, "----> Merge r0="+str(r0)+" where y0_max,y1_min = "+str(y0_max)+","+str(y1_min)) - return True - - - def is_col_mergable(self, c0): # return True iff col c0 and c0+1 can be merged - if(c0 < 0 or c0 >= self.num_cols -1): - raise ValueError('Cols c0='+str(c0)+' and c0 out of range') - - for i in range(self.num_rows): - if(self.has_item_at(i, c0) and self.has_item_at(i, c0+1)): - return False - - return True - - - def merge_rows(self, r0, reconnect=False): # merge rows r0 and r0+1 - if(r0 < 0 or r0 >= self.num_rows - 1): - raise 
ValueError('Rows r0='+str(r0)+' and r0+1 out of range.') - - for j in range(self.num_cols): - ix0 = self.get_ix(r0, j) - ix1 = self.get_ix(r0+1, j) - if(self.idx[ix0]==-1): - self.idx[ix0] = self.idx[ix1] - elif(self.idx[ix1]==-1): - #self.idx[ix0] = ix0 - pass - else: - if(reconnect): - #print("\n\n\n\nRECONNECT: "+str(self.items[self.idx[ix0]].this_id) + " <-> " + str(self.items[self.idx[ix1]].this_id))#zz - self.items[self.idx[ix0]].reconnect(self.items[self.idx[ix1]], self.items) - self.items[self.idx[ix0]].merge(self.items[self.idx[ix1]]) - #if(reconnect):#zz - # print("After reconnect: " + str(self.items))#zz - self.delete_rows(r0+1, r0+2, False) - - def merge_cols(self, c0): # merge cols c0 and c0+1 - if(c0 < 0 or c0 >= self.num_cols -1): - raise ValueError('Cols c0='+str(c0)+' and c0 out of range') - - for i in range(self.num_rows): - ix0 = self.get_ix(i, c0) - ix1 = self.get_ix(i, c0+1) - if(self.idx[ix0] == -1): - self.idx[ix0] = self.idx[ix1] - - self.delete_cols(c0+1, c0+2, False) - - - - def merge_down_all_rows(self): - #print_subset(0, self.items, self.idx) - for i in range(self.num_rows-1): - if(self.is_row_mergable(i)): - self.merge_rows(i) - self.merge_down_all_rows() - return - - def merge_down_all_cols(self): - for j in range(self.num_cols-1): - if(self.is_col_mergable(j)): - self.merge_cols(j) - self.merge_down_all_cols() - return - - - - def is_empty_row(self, r): - for j in range(self.num_cols): - if(self.has_item_at(r, j)): - return False - return True - - def is_empty_col(self, c): - for i in range(self.num_rows): - if(self.has_item_at(i, c)): # and self.get_item(i,c).txt != ''): - #print('HAS ITEM AT '+str(i)+','+str(c)+', namely: '+str(self.get_item(i,c))) - return False - return True - - def compactify(self): #remove all empty rows and cols - - has_changed = True - while(has_changed): - has_changed = False - #rows - for i in range(self.num_rows-1, -1, -1): - if(self.is_empty_row(i)): - print_verbose(7, "Delete empty row : "+str(i)) - has_changed = True - self.delete_rows(i, i+1, False) - #cols - for j in range(self.num_cols-1, -1, -1): - if(self.is_empty_col(j)): - print_verbose(7, "Delete empty column : "+str(j)) - has_changed = True - self.delete_cols(j, j+1, False) - - self.recalc_geometry() - - - - def throw_away_non_connected_rows(self): # throw away rows that are probably not connected to the table - def is_connected_row(r0): - min_y0 = 9999999 - max_y0 = -1 - min_y1 = 9999999 - for j in range(self.num_cols): - if(self.has_item_at(r0, j)): - min_y0 = min(min_y0, self.get_item(r0, j).pos_y) - max_y0 = max(max_y0, self.get_item(r0, j).pos_y + self.get_item(r0, j).height) - if(self.has_item_at(r0+1, j)): - min_y1 = min(min_y1, self.get_item(r0+1, j).pos_y) - - if(min_y1 == 9999999 or min_y0 == 9999999): - return True# at least one row was empty => consider always as connected - - - y_limit = min_y0 + (max_y0 - min_y0) * config.global_row_connection_threshold #at least space that woulld occupy 4x row r0 are (not) empty #New: 27.06.2022 (was previously: * 4) - - if(min_y1 <= y_limit): - return True - - sp_idx = self.find_special_item_idx_in_rect(self.rows[r0]) - - for k in sp_idx: - cur_y = self.items[k].pos_y - if(cur_y <= y_limit): - return True - - return False - - - for i in range(self.num_rows-1): - if(not is_connected_row(i)): - print_verbose(5, "Throw away non-connected rows after /excl. 
:"+str(i)) - print_verbose(5, "Current table: "+str(self)) - self.delete_rows(i+1, self.num_rows) - return - - def throw_away_rows_after_new_header(self): - - if(self.num_cols < 2 or self.num_rows < 3): - return - - num_numeric_rows = 0 - num_rows_with_left_txt = 0 - last_pos_y = 9999999 - last_delta_y = 9999999 - - for i in range(self.num_rows): - if(self.has_item_at(i,0) and Format_Analyzer.looks_words(self.get_item(i,0).txt)): - num_rows_with_left_txt += 1 - cur_numeric_values = 0 - cur_header_values = 0 - cur_other_values = 0 - cur_pos_y = 9999999 - for j in range(1, self.num_cols): - if(self.has_item_at(i,j)): - if(cur_pos_y == 9999999): - cur_pos_y = self.get_item(i,j).pos_y - txt = self.get_item(i,j).txt - if(Format_Analyzer.looks_numeric(txt) and not Format_Analyzer.looks_year(txt)): - cur_numeric_values += 1 - elif((Format_Analyzer.looks_words(txt) and txt[0].isupper()) or Format_Analyzer.looks_year(txt)): - cur_header_values += 1 - else: - cur_other_values += 1 - - if(cur_numeric_values > max(self.num_cols * 0.6,1)): - num_numeric_rows += 1 - - cur_delta_y = cur_pos_y - last_pos_y - - if(num_rows_with_left_txt > 2 and num_numeric_rows > 2 and cur_delta_y > last_delta_y * 1.05 + 2): - if(cur_numeric_values == 0 and cur_header_values > 0 and cur_other_values == 0): - print_verbose(5, "Throw away non-connected rows after probably new headline at row = "+str(i)+", cur/last_delta_y="+str(cur_delta_y)+"/"+str(last_delta_y)) - self.delete_rows(i, self.num_rows) - return - - last_pos_y = cur_pos_y - last_delta_y = cur_delta_y - - - def throw_away_last_headline(self): - # a headline at the end of a table probably doesnt belong to it, rather, it belongs to the next table - if(self.num_rows < 2 or self.num_cols < 2): - return - - if(not self.has_non_empty_item_at(self.num_rows-1, 0)): - return - - if(Format_Analyzer.looks_numeric(self.get_item(self.num_rows-1, 0).txt)): - return - - for j in range(1, self.num_cols): - if(self.has_non_empty_item_at(self.num_rows-1, j)): - return - - self.delete_rows(self.num_rows-1, self.num_rows) - - - - - def throw_away_non_connected_cols(self, page_width): # throw away cols that are probably not connected to the table - def is_connected_col(c0): - min_x0 = 9999999 - max_x0 = -1 - min_x1 = 9999999 - for i in range(self.num_rows): - if(self.has_item_at(i, c0)): - min_x0 = min(min_x0, self.get_item(i, c0).pos_x) - max_x0 = max(max_x0, self.get_item(i, c0).pos_x + self.get_item(i, c0).width) - if(self.has_item_at(i, c0+1)): - min_x1 = min(min_x1, self.get_item(i, c0+1).pos_x) - - if(min_x1 == 9999999 or min_x0 == 9999999): - return True# at least one col was empty => consider always as connected - - if (min_x1 <= max_x0 + DEFAULT_HTHROWAWAY_DIST * page_width): #at least that much space is not empty - return True - - - num_reg_text = 0 - for i in range(self.num_rows): - if(not self.has_item_at(i,c0)): - continue - y0 = self.get_item(i, c0).pos_y - y1 = self.get_item(i, c0).pos_y + self.get_item(i, c0).height - x0 = max((min_x0 + max_x0)/2.0, self.get_item(i, c0).pos_x) - x1 = min_x1 - cur_rect = Rect(x0, y0, x1, y1) - for it in self.items: - if(it.category not in [CAT_HEADLINE, CAT_OTHER_TEXT, CAT_RUNNING_TEXT, CAT_FOOTER]): - continue - it_rect = it.get_rect() - if(Rect.calc_intersection_area(cur_rect, it_rect) > 0): - print_verbose(2, '----->> With ' +str(self.get_item(i, c0))+ ' the item ' +str(it) + ' overlaps') - num_reg_text += 1 - - - #num_reg_text = self.count_regular_items_in_rect(Rect( (min_x0 + max_x0)/2.0, self.table_rect.y0, min_x1, 
self.table_rect.y1) ) - - return num_reg_text == 0 - - - for j in range(self.num_cols-1): - if(not is_connected_col(j)): - print_verbose(5, "Throw away non-connected cols after /excl. :"+str(j)) - print_verbose(5, "Current table: "+str(self)) - self.delete_cols(j+1, self.num_cols) - return - - - def throw_away_cols_at_next_paragraph(self, paragraphs): - def find_cur_paragraph_idx(x0, x1, my_paragraphs): - max_overlap = 0 - res = -1 - for i in range(len(my_paragraphs)): - p0 = my_paragraphs[i] - p1 = my_paragraphs[i+1] if i < len(my_paragraphs) -1 else 9999999 - overlap = max(min(x1, p1) - max(x0, p0), 0) - if(overlap > max_overlap): - max_overlap = overlap - res = i - return res - - - - if(self.num_cols == 0 or len(paragraphs) < 2): - return - - self.recalc_geometry() - - # Find relevant paragraphs - my_paragraphs = [] - for p in paragraphs: - cnt = 0 - for it in self.items: - if(it.pos_x == p and it.category in (CAT_RUNNING_TEXT , CAT_HEADLINE) and it.pos_x < self.table_rect.x1 and it.pos_x+it.width > self.table_rect.x0): - if((it.pos_y < self.table_rect.y1 and it.pos_y + it.height * 5 > self.table_rect.y0) or #TODO 5 was 10, test it - (it.pos_y > self.table_rect.y0 and it.pos_y - it.height * 5 < self.table_rect.y1)): #TODO 5 was 10, test it - cnt += 1 - #print(it) - if(cnt > 3): - my_paragraphs.append(p) - - my_paragraphs.sort() - - print_verbose(5, "Relevant paragraphs:"+str(my_paragraphs)) - - - last_para_idx = -1 - - for j in range(self.num_cols): - x0 = 9999999 - x1 = -1 - for i in range(self.num_rows): - if(self.has_item_at(i,j)): - x0 = min(x0, self.get_item(i,j).pos_x) - x1 = max(x1, self.get_item(i,j).pos_x + self.get_item(i,j).width) - if(x1 == -1): - continue #b/c empty col - - #print(j, x0, x1, my_paragraphs) - cur_para_idx = find_cur_paragraph_idx(x0, x1, my_paragraphs) - print_verbose(7, "--> Col ="+str(j) + ", x0/x1="+str(self.cols[j].x0)+"/"+str(self.cols[j].x1)+" belong to paragraph p_idx = " + str(cur_para_idx) + ", which is at " + str(my_paragraphs[cur_para_idx] if cur_para_idx != -1 else None) + " px") - if(j>1 and cur_para_idx != last_para_idx): #TODO 1 was 0, test is - # table is here probably split between two text paragraphs - print_verbose(5, "Throw away cols at next paragraph: col at next j="+str(j)) - self.delete_cols(j, self.num_cols) - return - last_para_idx = cur_para_idx - - - - def throw_away_cols_after_year_list(self): #throw away cols after a list of years is over - - class YearCols: - r = None - c0 = None - c1 = None - def __init__(self, r, c0): - self.r=r - self.c0=c0 - - def __repr__(self): - return "(r="+str(self.r)+",c0="+str(self.c0)+",c1="+str(self.c1)+")" - - if(self.num_cols < 5): - return - - year_cols = [] - - #print(str(self)) - #print("Cols:"+str(self.num_cols)) - - for i in range(self.num_rows): - j = 0 - while(j < self.num_cols - 1): - if(self.has_item_at(i,j) and self.has_item_at(i,j+1) and \ - Format_Analyzer.looks_year(self.get_item(i,j).txt) and \ - Format_Analyzer.looks_year(self.get_item(i,j+1).txt) and \ - abs(Format_Analyzer.to_year(self.get_item(i,j+1).txt) - Format_Analyzer.to_year(self.get_item(i,j).txt)) == 1): - dir = Format_Analyzer.to_year(self.get_item(i,j+1).txt) - Format_Analyzer.to_year(self.get_item(i,j).txt) - cur_year_cols = YearCols(i, j) - #find last year col - #print("Now at cell:"+str(i)+","+str(j)) - for j1 in range(j+1, self.num_cols): - if(self.has_item_at(i,j1) and \ - Format_Analyzer.looks_year(self.get_item(i,j1).txt) and \ - Format_Analyzer.to_year(self.get_item(i,j1).txt) - 
Format_Analyzer.to_year(self.get_item(i,j1-1).txt) == dir): - cur_year_cols.c1 = j1 - else: - break - year_cols.append(cur_year_cols) - j = cur_year_cols.c1 - j += 1 - - print_verbose(6, '----->> Found year lists at: ' +str(year_cols)) - - - if(len(year_cols) < 2): - return - - # test if we can throw away - - # find min - - min_yc = -1 - for k in range(len(year_cols)): - if(min_yc == -1 or (year_cols[k].c0 < year_cols[min_yc].c0)): - min_yc = k - - print_verbose(6, "------->> min:" + str(year_cols[min_yc]) + " at idx " + str(min_yc)) - - #find overlapping max - max_overlap_yc = -1 - for k in range(len(year_cols)): - if(year_cols[k].c0 < year_cols[min_yc].c1): - if(max_overlap_yc == -1 or (year_cols[k].c1 > year_cols[max_overlap_yc].c1)): - max_overlap_yc = k - - print_verbose(6, "------->> max overlap:" + str(year_cols[max_overlap_yc]) + " at idx " + str(max_overlap_yc)) - - # are there any year cols after max overlap? if so, we can throw that part away - can_throw_away = False - for yc in year_cols: - if(yc.c0 > year_cols[max_overlap_yc].c1): - print_verbose(6, "-------->> throw away because : " + str(yc) ) - can_throw_away = True - - if(can_throw_away): - self.delete_cols(year_cols[max_overlap_yc].c1 + 1, self.num_cols) - - - def throw_away_duplicate_looking_cols(self): #throw away columns that are looking like duplicates, and indicating another table - def are_cols_similar(c0, c1): - return self.col_looks_like_text_col(c0) and self.col_looks_like_text_col(c1) \ - and Format_Analyzer.cnt_overlapping_items(self.get_all_cols_as_text(c0), self.get_all_cols_as_text(c1)) > 3 - - - if(self.num_cols < 3): - return - - for i in range(2, self.num_cols): - if(are_cols_similar(0, i)): - print_verbose(7, "------->> cols 0 and " + str(i) + " are similar. Throw away from " + str(i)) - self.delete_cols(i, self.num_cols) - return - - - - - def identify_headline(self): - if(self.num_rows == 0 or self.num_cols == 0): - return - - if(not self.has_item_at(0,0) or not Format_Analyzer.looks_words(self.get_item(0,0).txt)): - return - - for j in range(1, self.num_cols): - if(self.has_item_at(0, j)): - return - - self.headline_idx.append(self.get_idx(0,0)) - self.delete_rows(0, 1) - - - def identify_non_numeric_special_items(self): - def col_looks_numeric(c0): - num_numbers = 0 - num_words = 0 - for i in range(self.num_rows): - if(self.has_item_at(i, c0)): - txt = self.get_item(i, c0).txt - if(Format_Analyzer.looks_numeric(txt)): - num_numbers += 1 - elif(Format_Analyzer.looks_words(txt)): - num_words += 1 - return num_numbers >= 3 and num_words < num_numbers * 0.4 - - def is_only_item_in_row(r0, c0): - for j in range(c0): - if(self.has_item_at(r0, j)): - return False - for j in range(c0+1, self.num_cols): - if(self.has_item_at(r0, j)): - return False - return True - - for j in range(self.num_cols): - if(col_looks_numeric(j)): - print_verbose(5, 'Numeric col found : '+str(j)) - # find first possible special item of this col - r0 = self.find_first_non_empty_row_in_col(j) - r1 = r0 - font_char = self.get_item(r0, j).get_font_characteristics() - for i in range(r0, self.num_rows): - r1 = i - if(not self.has_item_at(i,j)): - r1 = i + 1 - break #empty line - if(Format_Analyzer.looks_numeric(self.get_item(i,j).txt)): - r1 = i + 1 - break #number - if(self.get_item(i,j).get_font_characteristics() != font_char): - break #different font - - # remove now special items (but not before first occurence) - for i in range(r1, self.num_rows): - if(self.has_item_at(i, j)): - txt = self.get_item(i, j).txt - 
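# Illustrative aside (not part of the patch): col_looks_numeric() above reads as
# "at least three numeric cells and clearly more numbers than words". A standalone
# sketch with a deliberately simplified looks_numeric stand-in:
import re

def looks_numeric(txt):
    return bool(re.fullmatch(r"-?[\d.,]+%?", txt.strip()))

def col_looks_numeric(cells):
    nums = sum(1 for t in cells if looks_numeric(t))
    words = sum(1 for t in cells if t.strip() and not looks_numeric(t))
    return nums >= 3 and words < nums * 0.4

assert col_looks_numeric(["12.3", "4,500", "7%", "n/a"])                # 3 numbers, 1 stray word
assert not col_looks_numeric(["Scope 1", "Scope 2", "12", "13", "14"])  # too many words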
if(Format_Analyzer.looks_words(txt) or Format_Analyzer.looks_other_special_item(txt)): - cur_idx = self.get_idx(i,j) - self.idx[self.get_ix(i,j)] = -1 - self.special_idx.append(cur_idx) - - #remove even special items from headline, but only if there are no other headline items - if(self.has_item_at(r0, j) and is_only_item_in_row(r0, j) and (Format_Analyzer.looks_words(self.get_item(r0, j).txt) or Format_Analyzer.looks_other_special_item(self.get_item(r0, j).txt))): - cur_idx = self.get_idx(r0, j) - self.idx[self.get_ix(r0,j)] = -1 - self.special_idx.append(cur_idx) - """ - if(self.has_item_at(0, j) and is_only_item_in_row(0, j) and (Format_Analyzer.looks_words(self.get_item(0, j).txt) or Format_Analyzer.looks_other_special_item(self.get_item(0, j).txt))): - cur_idx = self.get_idx(0, j) - self.idx[self.get_ix(0,j)] = -1 - self.special_idx.append(cur_idx) - """ - - - - def identify_overlapping_special_items(self): - # * Identify all remaing items that must be set to be special items - # because otherwise they would overlap with other columns. - # * This is tricky, because we need to find a minimal set of such items, - # which leads to an NP-complete problem. - # * To solve it (fast in most cases), we employ a Backtracking algorithm - - rec_counter = 0 - timeout = False - t_start = 0 - looks_numeric = [] - tmp_boundaries = [] - - def calc_single_col_boundary(tmp_idx, col0): - cur_bdry = (9999999, -1) - for i in range(self.num_rows): - cur_ix = self.get_ix(i,col0) - cur_idx = tmp_idx[cur_ix] - if(cur_idx != -1): - cur_bdry =( min(cur_bdry[0], self.items[cur_idx].pos_x), max(cur_bdry[1], self.items[cur_idx].pos_x + self.items[cur_idx].width) ) - return cur_bdry - - def calc_col_boundaries(tmp_idx): - res = [(0,0)] * self.num_cols - - for j in range(self.num_cols): - res[j] = calc_single_col_boundary(tmp_idx, j) - return res - - def find_first_overlapping_col(boundaries): - for j in range(self.num_cols-1): - if(boundaries[j][1] == -1): - continue # skip empty columns - - #find nextnon-empty col - k = j+1 - while(k boundaries[k][0]): - #0.99 tolerance, because sometimes font width is overestimated - return j - return -1 - - def find_possible_overlapping_ix(c0, tmp_idx, boundaries): - bdry = (min(boundaries[c0][0],boundaries[c0+1][0]), max(boundaries[c0][1], boundaries[c0+1][1])) - res = [] - found_something = False - for j in range(self.num_cols): - if(boundaries[j][1] > bdry[0] and boundaries[j][0] < bdry[1]): - #all items from this col might be overlaping - for i in range(self.num_rows): - cur_ix = self.get_ix(i,j) - if(tmp_idx[cur_ix] != -1): - res.append(cur_ix) - found_something = True - res = sorted(res, key=lambda ix: - self.items[tmp_idx[ix]].width) #sort descending - if(not found_something): - raise ValueError('Some columns are overlapping, but there are no relevant items.') - return res - - - - def find_allowed_set_rec(tmp_idx, num_sp_items, last_sp_ix, lowest_num_so_far): - # returns set of ix'es, such that after removing them, the rest is allowed (e.g., no overlap) - nonlocal rec_counter - nonlocal t_start - nonlocal timeout - nonlocal looks_numeric - nonlocal tmp_boundaries - - rec_counter += 1 - - if(rec_counter % 1000 == 0): - t_now = time.time() - if(t_now - t_start > config.global_max_identify_complex_items_timeout): #max 5 sec TODO - timeout = True - - - if(num_sp_items >= lowest_num_so_far or timeout): - print_verbose(20,"No better solution exists") - return 9999999, [] # we cant find a better solution - - first_overlapping_col = find_first_overlapping_col(tmp_boundaries) - 
if(first_overlapping_col == -1): - print_verbose(9,"Found solution, num_sp_items="+str(num_sp_items)) - return num_sp_items, [last_sp_ix] #we found allowed set, where only num_sp_items items are excluded - - - #raise ValueError('first_overlapping_col='+str(first_overlapping_col)) - - possible_overlap_ix = find_possible_overlapping_ix(first_overlapping_col, tmp_idx, tmp_boundaries) - best_sp_ix = [] - - #raise ValueError('possible_overlap_ix='+str(possible_overlap_ix)) - - #print_verbose(15, "num_sp_item="+str(num_sp_items)+", Boundaries = "+str(tmp_boundaries)+", first_overlapping_col = " \ - # +str(first_overlapping_col) + " lowest_num_so_far="+str(lowest_num_so_far)) # ", possible_overlap_ix= " +str(possible_overlap_ix) + - - for ix in possible_overlap_ix: - # try this ix - if(looks_numeric[ix]): - continue # never use numbers as special items - - #if(num_sp_items==0): - # print_verbose(0, "---> Trying: " +str(self.items[tmp_idx[ix]])) - - col = ix % self.num_cols - - old = tmp_idx[ix] - tmp_idx[ix] = -1 - old_bdry = tmp_boundaries[col] - tmp_boundaries[col] = calc_single_col_boundary(tmp_idx, col) - cur_lowest_num, cur_sp_ix = find_allowed_set_rec(tmp_idx, num_sp_items+1, ix, lowest_num_so_far) - tmp_idx[ix] = old - tmp_boundaries[col] = old_bdry - if(cur_lowest_num < lowest_num_so_far): - lowest_num_so_far = cur_lowest_num - best_sp_ix = cur_sp_ix - if(cur_lowest_num == num_sp_items+1): - break # we cannot find a better solution, so escape early - - return lowest_num_so_far, best_sp_ix + [last_sp_ix] - - - if(self.num_cols == 0): - return # nothing to do - - - tmp_idx = self.idx.copy() - - for cur_idx in tmp_idx: - looks_numeric.append(Format_Analyzer.looks_numeric(self.items[cur_idx].txt)) - - - tmp_boundaries = calc_col_boundaries(tmp_idx) - - t_start = time.time() - lowest_num_so_far, sp_ix = find_allowed_set_rec(tmp_idx, 0, -1, 9999999) - t_end = time.time() - print_verbose(3, "---> find_allowed_set_rec completed after time="+str(t_end-t_start)+"sec, recursions="+str(rec_counter)) - - if(timeout and lowest_num_so_far == 9999999): - #we coulnt find a solution => give up on this table - print_verbose(3, "---> No solution. Give up") - #for k in range(len(self.idx)): - # self.special_idx.append(self.idx[k]) - # self.idx[k] = -1 - return - - - if(lowest_num_so_far == 9999999): - #print(str(self)) - #raise ValueError('Some columns are overlapping, but after backtracking, no items were found that could be removed.') - return - - # make sure, that we dont throw out too much items - for ix in sp_ix: - if(ix != -1): - tmp_idx[ix] = -1 - - - - tmp_bdry = calc_col_boundaries(tmp_idx) - #print(tmp_bdry) - - sp_ix_final = [] - for ix in sp_ix: - if(ix != -1): - #could this one stay? - tmp_idx[ix] = self.idx[ix] - tmp_bdry = calc_col_boundaries(tmp_idx) - #print(self.items[self.idx[ix]].txt) - #print(tmp_bdry) - if(find_first_overlapping_col(tmp_bdry) != -1): - #no, it can't - sp_ix_final.append(ix) - #print('----> No!') - tmp_idx[ix] = -1 - - - - - - - - - for ix in sp_ix_final: - if(ix != -1): - self.special_idx.append(self.idx[ix]) - self.idx[ix] = -1 - - - self.recalc_geometry() - - - - return - - - - """ + def identify_overlapping_special_items(self): + # * Identify all remaing items that must be set to be special items + # because otherwise they would overlap with other columns. + # * This is tricky, because we need to find a minimal set of such items, + # which leads to an NP-complete problem. 
+ # * To solve it (fast in most cases), we employ a Backtracking algorithm + + rec_counter = 0 + timeout = False + t_start = 0 + looks_numeric = [] + tmp_boundaries = [] + + def calc_single_col_boundary(tmp_idx, col0): + cur_bdry = (9999999, -1) + for i in range(self.num_rows): + cur_ix = self.get_ix(i, col0) + cur_idx = tmp_idx[cur_ix] + if cur_idx != -1: + cur_bdry = ( + min(cur_bdry[0], self.items[cur_idx].pos_x), + max(cur_bdry[1], self.items[cur_idx].pos_x + self.items[cur_idx].width), + ) + return cur_bdry + + def calc_col_boundaries(tmp_idx): + res = [(0, 0)] * self.num_cols + + for j in range(self.num_cols): + res[j] = calc_single_col_boundary(tmp_idx, j) + return res + + def find_first_overlapping_col(boundaries): + for j in range(self.num_cols - 1): + if boundaries[j][1] == -1: + continue # skip empty columns + + # find nextnon-empty col + k = j + 1 + while k < self.num_cols and boundaries[k][1] == -1: + k += 1 + + if k == self.num_cols: + return -1 # only empty columns left + + if boundaries[j][0] + (boundaries[j][1] - boundaries[j][0]) * 0.99 > boundaries[k][0]: + # 0.99 tolerance, because sometimes font width is overestimated + return j + return -1 + + def find_possible_overlapping_ix(c0, tmp_idx, boundaries): + bdry = (min(boundaries[c0][0], boundaries[c0 + 1][0]), max(boundaries[c0][1], boundaries[c0 + 1][1])) + res = [] + found_something = False + for j in range(self.num_cols): + if boundaries[j][1] > bdry[0] and boundaries[j][0] < bdry[1]: + # all items from this col might be overlaping + for i in range(self.num_rows): + cur_ix = self.get_ix(i, j) + if tmp_idx[cur_ix] != -1: + res.append(cur_ix) + found_something = True + res = sorted(res, key=lambda ix: -self.items[tmp_idx[ix]].width) # sort descending + if not found_something: + raise ValueError("Some columns are overlapping, but there are no relevant items.") + return res + + def find_allowed_set_rec(tmp_idx, num_sp_items, last_sp_ix, lowest_num_so_far): + # returns set of ix'es, such that after removing them, the rest is allowed (e.g., no overlap) + nonlocal rec_counter + nonlocal t_start + nonlocal timeout + nonlocal looks_numeric + nonlocal tmp_boundaries + + rec_counter += 1 + + if rec_counter % 1000 == 0: + t_now = time.time() + if t_now - t_start > config.global_max_identify_complex_items_timeout: # max 5 sec TODO + timeout = True + + if num_sp_items >= lowest_num_so_far or timeout: + print_verbose(20, "No better solution exists") + return 9999999, [] # we cant find a better solution + + first_overlapping_col = find_first_overlapping_col(tmp_boundaries) + if first_overlapping_col == -1: + print_verbose(9, "Found solution, num_sp_items=" + str(num_sp_items)) + return num_sp_items, [last_sp_ix] # we found allowed set, where only num_sp_items items are excluded + + # raise ValueError('first_overlapping_col='+str(first_overlapping_col)) + + possible_overlap_ix = find_possible_overlapping_ix(first_overlapping_col, tmp_idx, tmp_boundaries) + best_sp_ix = [] + + # raise ValueError('possible_overlap_ix='+str(possible_overlap_ix)) + + # print_verbose(15, "num_sp_item="+str(num_sp_items)+", Boundaries = "+str(tmp_boundaries)+", first_overlapping_col = " \ + # +str(first_overlapping_col) + " lowest_num_so_far="+str(lowest_num_so_far)) # ", possible_overlap_ix= " +str(possible_overlap_ix) + + + for ix in possible_overlap_ix: + # try this ix + if looks_numeric[ix]: + continue # never use numbers as special items + + # if(num_sp_items==0): + # print_verbose(0, "---> Trying: " +str(self.items[tmp_idx[ix]])) + + col = ix % 
self.num_cols + + old = tmp_idx[ix] + tmp_idx[ix] = -1 + old_bdry = tmp_boundaries[col] + tmp_boundaries[col] = calc_single_col_boundary(tmp_idx, col) + cur_lowest_num, cur_sp_ix = find_allowed_set_rec(tmp_idx, num_sp_items + 1, ix, lowest_num_so_far) + tmp_idx[ix] = old + tmp_boundaries[col] = old_bdry + if cur_lowest_num < lowest_num_so_far: + lowest_num_so_far = cur_lowest_num + best_sp_ix = cur_sp_ix + if cur_lowest_num == num_sp_items + 1: + break # we cannot find a better solution, so escape early + + return lowest_num_so_far, best_sp_ix + [last_sp_ix] + + if self.num_cols == 0: + return # nothing to do + + tmp_idx = self.idx.copy() + + for cur_idx in tmp_idx: + looks_numeric.append(Format_Analyzer.looks_numeric(self.items[cur_idx].txt)) + + tmp_boundaries = calc_col_boundaries(tmp_idx) + + t_start = time.time() + lowest_num_so_far, sp_ix = find_allowed_set_rec(tmp_idx, 0, -1, 9999999) + t_end = time.time() + print_verbose( + 3, + "---> find_allowed_set_rec completed after time=" + + str(t_end - t_start) + + "sec, recursions=" + + str(rec_counter), + ) + + if timeout and lowest_num_so_far == 9999999: + # we coulnt find a solution => give up on this table + print_verbose(3, "---> No solution. Give up") + # for k in range(len(self.idx)): + # self.special_idx.append(self.idx[k]) + # self.idx[k] = -1 + return + + if lowest_num_so_far == 9999999: + # print(str(self)) + # raise ValueError('Some columns are overlapping, but after backtracking, no items were found that could be removed.') + return + + # make sure, that we dont throw out too much items + for ix in sp_ix: + if ix != -1: + tmp_idx[ix] = -1 + + tmp_bdry = calc_col_boundaries(tmp_idx) + # print(tmp_bdry) + + sp_ix_final = [] + for ix in sp_ix: + if ix != -1: + # could this one stay? + tmp_idx[ix] = self.idx[ix] + tmp_bdry = calc_col_boundaries(tmp_idx) + # print(self.items[self.idx[ix]].txt) + # print(tmp_bdry) + if find_first_overlapping_col(tmp_bdry) != -1: + # no, it can't + sp_ix_final.append(ix) + # print('----> No!') + tmp_idx[ix] = -1 + + for ix in sp_ix_final: + if ix != -1: + self.special_idx.append(self.idx[ix]) + self.idx[ix] = -1 + + self.recalc_geometry() + + return + + """ @@ -1240,9 +1253,7 @@ def find_allowed_set_rec(tmp_idx, num_sp_items, sp_ix, lowest_num_so_far): """ - - - """def identify_remaining_special_items(self, page_width): + """def identify_remaining_special_items(self, page_width): # return True iff something was found and changed for i in range(self.num_rows): for j in range(self.num_cols): @@ -1254,155 +1265,143 @@ def find_allowed_set_rec(tmp_idx, num_sp_items, sp_ix, lowest_num_so_far): self.special_idx.append(cur_idx) return True return False - """ - - - - - - - def throw_away_distant_special_items(self, page_width): - tmp = [] - for i in self.special_idx: - dist = Rect.distance(self.table_rect, self.items[i].get_rect()) / page_width - if(dist <= DEFAULT_SPECIAL_ITEM_MAX_DIST and self.items[i].pos_y <= self.table_rect.y1): - tmp.append(i) - else: - print_verbose(5, "Throw away special item: "+str(self.items[i]) + " from table with rect : " + - str(self.table_rect) + " and distance: " + str(dist)) - self.special_idx = sorted(tmp, key=lambda i: self.items[i].pos_y ) - - - - - def cleanup_table(self, page_width, paragraphs): - - bak_items = deepcopy(self.items) - bak_idx = deepcopy(self.idx) - - num_cells = -1 - old_num_actual_items = -1 - old_num_rows = -1 - old_num_cols = -1 - - #global config.global_verbosity - #config.global_verbosity = 6 - - - print_verbose(3, 'Table before cleanup: '+str(self)) 
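
The identify_overlapping_special_items rewrite above searches, via backtracking with a timeout, for a minimal set of (non-numeric) items to pull out of the table so that no two column bounding boxes overlap on the x-axis. A rough stand-alone sketch of that idea, with the timeout, numeric-item skipping and special-item bookkeeping omitted and all helper names hypothetical:

def column_boundary(col):
    # Bounding x-interval of the remaining items in one column; None if empty.
    xs0 = [x0 for x0, _ in col]
    xs1 = [x1 for _, x1 in col]
    return (min(xs0), max(xs1)) if xs0 else None

def first_overlap(cols):
    # Index of the first column whose boundary overlaps the next non-empty one.
    bounds = [column_boundary(c) for c in cols]
    for j in range(len(bounds) - 1):
        if bounds[j] is None:
            continue
        k = j + 1
        while k < len(bounds) and bounds[k] is None:
            k += 1
        if k < len(bounds) and bounds[j][1] > bounds[k][0]:
            return j
    return -1

def minimal_removal(cols, removed=0, best=float("inf")):
    # Backtracking: remove one candidate item at a time from the first
    # overlapping column pair and keep the branch that removes the fewest.
    if removed >= best:
        return best, None               # cannot beat the best known solution
    j = first_overlap(cols)
    if j == -1:
        return removed, []              # current state is already overlap-free
    best_plan = None
    for col_idx in (j, j + 1):
        for item in list(cols[col_idx]):
            cols[col_idx].remove(item)
            score, plan = minimal_removal(cols, removed + 1, best)
            cols[col_idx].append(item)
            if score < best:
                best, best_plan = score, [(col_idx, item)] + plan
    return best, best_plan

if __name__ == "__main__":
    # Removing the wide (0, 20) item from the first column resolves the overlap.
    cols = [[(0, 5), (0, 20)], [(10, 15)], [(18, 25)]]
    print(minimal_removal(cols))
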
- - - - - while(True): - cur_num_actual_items = self.count_actual_items() - if(cur_num_actual_items == old_num_actual_items and self.num_rows == old_num_rows and self.num_cols == old_num_cols): - break #no more changes - - old_num_actual_items = cur_num_actual_items - old_num_rows = self.num_rows - old_num_cols = self.num_cols - - - print_verbose(3, "--> Next cleanuptable iteration") - self.compactify() - print_verbose(6, "------>> After compactify:" + str(self.get_printed_repr())) - - self.recalc_geometry() - print_verbose(6, "------>> After recalc_geometry:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_non_connected_rows() - print_verbose(6, "------>> After throw_away_non_connected_rows:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_rows_after_new_header() - print_verbose(6, "------>> After throw_away_rows_after_new_header:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_non_connected_cols(page_width) - print_verbose(6, "------>> After throw_away_non_connected_cols:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_cols_at_next_paragraph(paragraphs) - print_verbose(6, "------>> After throw_away_cols_at_next_paragraph:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_cols_after_year_list() - print_verbose(6, "------>> After throw_away_cols_after_year_list:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_duplicate_looking_cols() - print_verbose(6, "------>> After throw_away_duplicate_looking_cols:" + str(self.get_printed_repr())) - - self.compactify() - self.merge_down_all_rows() - print_verbose(6, "------>> After merge_down_all_rows:" + str(self.get_printed_repr())) - - self.compactify() - self.merge_down_all_cols() - print_verbose(6, "------>> After merge_down_all_cols:" + str(self.get_printed_repr())) - - self.compactify() - self.identify_headline() - print_verbose(6, "------>> After identify_headline:" + str(self.get_printed_repr())) - - self.compactify() - self.throw_away_last_headline() - print_verbose(6, "------>> After throw_away_last_headline:" + str(self.get_printed_repr())) - - self.compactify() - self.identify_overlapping_special_items() - print_verbose(6, "------>> After identify_overlapping_special_items:" + str(self.get_printed_repr())) - #raise ValueError('XXX') - - self.compactify() - self.identify_non_numeric_special_items() - print_verbose(6, "------>> After identify_non_numeric_special_items:" + str(self.get_printed_repr())) - - self.compactify() - self.recalc_geometry() - print_verbose(6, "------>> After recalc_geometry:" + str(self.get_printed_repr())) - - - self.compactify() + """ - self.special_idx = sorted(self.special_idx, key=lambda i: self.items[i].pos_y ) - self.throw_away_distant_special_items(page_width) - - print_verbose(6, "------>> After throw_away_distant_special_items:" + str(self.get_printed_repr())) - - # restore all items that are no longer part of that table - to_restore = list(set(bak_idx) - set(self.idx + [-1])) - print_verbose(3, "Restoring old items with idx: "+str(to_restore)) - for i in to_restore: - # was this item merged and the merged item is still used? 
- was_merged = False - for k in self.items[i].merged_list: - if k in self.idx: - was_merged = True - break - if(not was_merged): - print_verbose(6, '----> Old item '+str(i)+' was not merged => Restore') - self.items[i] = bak_items[i] - else: - print_verbose(6, '----> Old item '+str(i)+' was merged => Dont touch') - - - - - - self.compactify() - - print_verbose(6, "------>> After restoring old item:" + str(self.get_printed_repr())) - - print_verbose(3, "===============>>>>>>>>>>>>>>>> Cleanup done <<<<<<<<<<<< =====================") - + def throw_away_distant_special_items(self, page_width): + tmp = [] + for i in self.special_idx: + dist = Rect.distance(self.table_rect, self.items[i].get_rect()) / page_width + if dist <= DEFAULT_SPECIAL_ITEM_MAX_DIST and self.items[i].pos_y <= self.table_rect.y1: + tmp.append(i) + else: + print_verbose( + 5, + "Throw away special item: " + + str(self.items[i]) + + " from table with rect : " + + str(self.table_rect) + + " and distance: " + + str(dist), + ) + self.special_idx = sorted(tmp, key=lambda i: self.items[i].pos_y) + + def cleanup_table(self, page_width, paragraphs): + bak_items = deepcopy(self.items) + bak_idx = deepcopy(self.idx) + + num_cells = -1 + old_num_actual_items = -1 + old_num_rows = -1 + old_num_cols = -1 + + # global config.global_verbosity + # config.global_verbosity = 6 + + print_verbose(3, "Table before cleanup: " + str(self)) + + while True: + cur_num_actual_items = self.count_actual_items() + if ( + cur_num_actual_items == old_num_actual_items + and self.num_rows == old_num_rows + and self.num_cols == old_num_cols + ): + break # no more changes + + old_num_actual_items = cur_num_actual_items + old_num_rows = self.num_rows + old_num_cols = self.num_cols + + print_verbose(3, "--> Next cleanuptable iteration") + self.compactify() + print_verbose(6, "------>> After compactify:" + str(self.get_printed_repr())) + + self.recalc_geometry() + print_verbose(6, "------>> After recalc_geometry:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_non_connected_rows() + print_verbose(6, "------>> After throw_away_non_connected_rows:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_rows_after_new_header() + print_verbose(6, "------>> After throw_away_rows_after_new_header:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_non_connected_cols(page_width) + print_verbose(6, "------>> After throw_away_non_connected_cols:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_cols_at_next_paragraph(paragraphs) + print_verbose(6, "------>> After throw_away_cols_at_next_paragraph:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_cols_after_year_list() + print_verbose(6, "------>> After throw_away_cols_after_year_list:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_duplicate_looking_cols() + print_verbose(6, "------>> After throw_away_duplicate_looking_cols:" + str(self.get_printed_repr())) + + self.compactify() + self.merge_down_all_rows() + print_verbose(6, "------>> After merge_down_all_rows:" + str(self.get_printed_repr())) + + self.compactify() + self.merge_down_all_cols() + print_verbose(6, "------>> After merge_down_all_cols:" + str(self.get_printed_repr())) + + self.compactify() + self.identify_headline() + print_verbose(6, "------>> After identify_headline:" + str(self.get_printed_repr())) + + self.compactify() + self.throw_away_last_headline() + print_verbose(6, "------>> After throw_away_last_headline:" + 
str(self.get_printed_repr())) + + self.compactify() + self.identify_overlapping_special_items() + print_verbose(6, "------>> After identify_overlapping_special_items:" + str(self.get_printed_repr())) + # raise ValueError('XXX') + + self.compactify() + self.identify_non_numeric_special_items() + print_verbose(6, "------>> After identify_non_numeric_special_items:" + str(self.get_printed_repr())) + + self.compactify() + self.recalc_geometry() + print_verbose(6, "------>> After recalc_geometry:" + str(self.get_printed_repr())) + + self.compactify() + + self.special_idx = sorted(self.special_idx, key=lambda i: self.items[i].pos_y) + self.throw_away_distant_special_items(page_width) + + print_verbose(6, "------>> After throw_away_distant_special_items:" + str(self.get_printed_repr())) + + # restore all items that are no longer part of that table + to_restore = list(set(bak_idx) - set(self.idx + [-1])) + print_verbose(3, "Restoring old items with idx: " + str(to_restore)) + for i in to_restore: + # was this item merged and the merged item is still used? + was_merged = False + for k in self.items[i].merged_list: + if k in self.idx: + was_merged = True + break + if not was_merged: + print_verbose(6, "----> Old item " + str(i) + " was not merged => Restore") + self.items[i] = bak_items[i] + else: + print_verbose(6, "----> Old item " + str(i) + " was merged => Dont touch") + + self.compactify() - - - - """ + print_verbose(6, "------>> After restoring old item:" + str(self.get_printed_repr())) + + print_verbose(3, "===============>>>>>>>>>>>>>>>> Cleanup done <<<<<<<<<<<< =====================") + + """ TODO :doest really work, remove in the future def unfold_patched_numbers(self): def doit(strict_alignment): @@ -1498,87 +1497,90 @@ def doit(strict_alignment): break """ - - - def is_good_table(self): - neccessary_actual_items = 4 if not config.global_be_more_generous_with_good_tables else 2 - if(not(self.num_rows >= 2 and self.num_cols >= 2 and self.count_actual_items() >= neccessary_actual_items)): - print_verbose(7, "----->> bad, reason:1") - return False - - num_items = self.count_actual_items() - density = num_items / (len(self.idx)+0.00001) - - if(density < 0.2): - print_verbose(7, "----->> bad, reason:2") - return False # density less than threshold - - num_sp_items = 0 - for sp in self.special_idx: - if(self.items[sp].txt != ''): - num_sp_items += 1 - - if(num_sp_items > num_items * 0.33 and num_items > 50): - print_verbose(7, "----->> bad, reason:4") - return False # strange: too many sp. 
items => probably not a table - - cnt_numerics = 0 - cnt_weak_numerics = 0 - for i in self.idx: - if(i != -1): - txt = self.items[i].txt - if(Format_Analyzer.looks_numeric(txt) and not Format_Analyzer.looks_year(txt)): - cnt_numerics += 1 - cnt_weak_numerics += 1 - elif(Format_Analyzer.looks_weak_numeric(txt)): - cnt_weak_numerics +=1 - - print_verbose(7, "----->> reached end of is_good_table") - return (cnt_numerics > 3 and density > 0.6) or (cnt_numerics > 7 and density > 0.4) or cnt_numerics > 10 \ - or (cnt_weak_numerics > 3 and num_items > 5 and density > 0.4) \ - or (cnt_weak_numerics > 0 and num_items > 2 and density > 0.4 and config.global_be_more_generous_with_good_tables) - - - def categorize_as_table(self): - print_verbose(7, "--> Categorize as new table: "+str(self)) - for i in self.idx: - if(i!=-1): - self.items[i].category = CAT_TABLE_DATA - for i in self.headline_idx: - if(i!=-1): - self.items[i].category = CAT_TABLE_HEADLINE - for i in self.special_idx: - if(i!=-1): - self.items[i].category = CAT_TABLE_SPECIAL - - - def categorize_as_misc(self): - print_verbose(7, "--> Categorize as misc: "+str(self)) - for i in self.get_all_idx(): - if(i!=-1): - self.items[i].category = CAT_MISC - - - - def init_by_cols(self, p_idx, p_items): - self.items = p_items - self.idx = p_idx.copy() - self.idx = sorted(self.idx, key=lambda i: self.items[i].pos_y) - - self.num_cols = 1 - self.num_rows = len(self.idx) - - sum_align_pos_x = 0 - - for i in self.idx: - #self.rows.append(self.items[i].get_rect()) - self.marks.append(0) - sum_align_pos_x += self.items[i].get_aligned_pos_x() - - self.col_aligned_pos_x.append(sum_align_pos_x / len(self.idx)) - - """ + def is_good_table(self): + neccessary_actual_items = 4 if not config.global_be_more_generous_with_good_tables else 2 + if not (self.num_rows >= 2 and self.num_cols >= 2 and self.count_actual_items() >= neccessary_actual_items): + print_verbose(7, "----->> bad, reason:1") + return False + + num_items = self.count_actual_items() + density = num_items / (len(self.idx) + 0.00001) + + if density < 0.2: + print_verbose(7, "----->> bad, reason:2") + return False # density less than threshold + + num_sp_items = 0 + for sp in self.special_idx: + if self.items[sp].txt != "": + num_sp_items += 1 + + if num_sp_items > num_items * 0.33 and num_items > 50: + print_verbose(7, "----->> bad, reason:4") + return False # strange: too many sp. 
items => probably not a table + + cnt_numerics = 0 + cnt_weak_numerics = 0 + for i in self.idx: + if i != -1: + txt = self.items[i].txt + if Format_Analyzer.looks_numeric(txt) and not Format_Analyzer.looks_year(txt): + cnt_numerics += 1 + cnt_weak_numerics += 1 + elif Format_Analyzer.looks_weak_numeric(txt): + cnt_weak_numerics += 1 + + print_verbose(7, "----->> reached end of is_good_table") + return ( + (cnt_numerics > 3 and density > 0.6) + or (cnt_numerics > 7 and density > 0.4) + or cnt_numerics > 10 + or (cnt_weak_numerics > 3 and num_items > 5 and density > 0.4) + or ( + cnt_weak_numerics > 0 + and num_items > 2 + and density > 0.4 + and config.global_be_more_generous_with_good_tables + ) + ) + + def categorize_as_table(self): + print_verbose(7, "--> Categorize as new table: " + str(self)) + for i in self.idx: + if i != -1: + self.items[i].category = CAT_TABLE_DATA + for i in self.headline_idx: + if i != -1: + self.items[i].category = CAT_TABLE_HEADLINE + for i in self.special_idx: + if i != -1: + self.items[i].category = CAT_TABLE_SPECIAL + + def categorize_as_misc(self): + print_verbose(7, "--> Categorize as misc: " + str(self)) + for i in self.get_all_idx(): + if i != -1: + self.items[i].category = CAT_MISC + + def init_by_cols(self, p_idx, p_items): + self.items = p_items + self.idx = p_idx.copy() + self.idx = sorted(self.idx, key=lambda i: self.items[i].pos_y) + + self.num_cols = 1 + self.num_rows = len(self.idx) + + sum_align_pos_x = 0 + + for i in self.idx: + # self.rows.append(self.items[i].get_rect()) + self.marks.append(0) + sum_align_pos_x += self.items[i].get_aligned_pos_x() + + self.col_aligned_pos_x.append(sum_align_pos_x / len(self.idx)) + + """ col_rect = Rect(9999999, 9999999, -1, -1) for r in self.rows: col_rect.x0 = min(col_rect.x0, r.x0) @@ -1588,267 +1590,293 @@ def init_by_cols(self, p_idx, p_items): self.cols.append(col_rect) """ - - self.recalc_geometry() - - def find_top_marked_idx(self, mark): - res = -1 - pos_y = 9999999 - for i in range(len(self.idx)): - if(self.has_item_at_ix(i)): - cur_y = self.get_item_by_ix(i).pos_y - if(cur_y < pos_y and self.marks[i] == mark): - pos_y = cur_y - res = i - return res - - def find_left_marked_idx(self, mark): - res = -1 - pos_x = 9999999 - for i in range(len(self.idx)): - if(self.has_item_at_ix(i)): - cur_x = self.get_item_by_ix(i).pos_x - if(cur_x < pos_x and self.marks[i] == mark): - pos_x = cur_x - res = i - return res - - - - def find_marked_idx_at_y0(self, mark, id, y0, new_mark): - res = [] - for i in range(len(self.idx)): - if(self.has_item_at_ix(i) and self.marks[i] == mark): - if(self.get_item_by_ix(i).pos_y == y0): - res.append((id, i)) - self.marks[i] = new_mark - return res - - def find_marked_idx_between_y0_y1_at_col(self, mark, id, y0, y1, col, new_mark): - res = [] - for i in range(self.num_rows): - ix = self.get_ix(i, col) - if(self.has_item_at_ix(ix) and self.marks[ix] == mark): - r = self.get_item_by_ix(ix).get_rect() - if(r.y0 < y1 and r.y1 >= y0): - res.append((id, ix)) - self.marks[ix] = new_mark - return res - - - - - - @staticmethod - def merge(tab1, tab2, page_width): #note: tables must belong to same HTMLPage ! 
- - def cols_are_mergable(tab1, tab2, col1, col2): - c1_x0 = 9999999 - c1_x1 = -1 - - c2_x0 = 9999999 - c2_x1 = -1 - - for i in range(tab1.num_rows): - c1_x0 = min(c1_x0, tab1.get_item(i, col1).pos_x if tab1.has_item_at(i, col1) else 9999999) - c1_x1 = max(c1_x1, tab1.get_item(i, col1).pos_x + tab1.get_item(i, col1).width if tab1.has_item_at(i, col1) else -1) - - for i in range(tab2.num_rows): - c2_x0 = min(c2_x0, tab2.get_item(i, col2).pos_x if tab2.has_item_at(i, col2) else 9999999) - c2_x1 = max(c2_x1, tab2.get_item(i, col2).pos_x + tab2.get_item(i, col2).width if tab2.has_item_at(i, col2) else -1) - - #print("MERGE : !!"+str(c1_x0)+" "+str(c1_x1)+" <-> "+str(c2_x0)+" "+str(c2_x1)+" ") - - if(c1_x0 < c2_x0): - if(not c1_x1 > c2_x0): - return False - else: - if(not c2_x1 > c1_x0): - return False - - # make sure that now rows are overlapping - - row1 = 0 - row2 = 0 - while(row1 < tab1.num_rows and row2 < tab2.num_rows): - if(tab1.has_item_at(row1, col1) and tab2.has_item_at(row2, col2)): - it1 = tab1.get_item(row1, col1) - it2 = tab2.get_item(row2, col2) - if(min(it1.pos_y + it1.height, it2.pos_y + it2.height) - max(it1.pos_y, it2.pos_y) >= 0): - return False # overlap - if(it1.pos_y + it1.height < it2.pos_y + it2.height): - row1 += 1 - else: - row2 += 1 - elif(tab1.has_item_at(row1, col1)): - row2 += 1 - else: - row1 += 1 - - return True # no overlaps found - - - def find_all_new_columns(tab1, tab2, tmp_rows, threshold_px): - tab1.reset_marks() - tab2.reset_marks() - num_rows = len(tmp_rows) - - tmp_cols = [] - tab1_col = 0 - tab2_col = 0 - - while(tab1_col < tab1.num_cols or tab2_col < tab2.num_cols): - tmp_items = [] - # find leftmost col - print_verbose(5, "----> tab1_col=" + str(tab1_col) + ", tab2_col=" + str(tab2_col)) - tab1_col_x = tab1.col_aligned_pos_x[tab1_col] if tab1_col < tab1.num_cols else 9999999 - tab2_col_x = tab2.col_aligned_pos_x[tab2_col] if tab2_col < tab2.num_cols else 9999999 - use_tab1 = tab1_col_x <= tab2_col_x + threshold_px - use_tab2 = tab2_col_x <= tab1_col_x + threshold_px - if(use_tab1 and use_tab2 and not cols_are_mergable(tab1, tab2, tab1_col, tab2_col)): # TODO: and cols do not intersect by any rect - # cols cant be merged => use only first one - use_tab1 = tab1_col_x < tab2_col_x - use_tab2 = not use_tab1 - - print_verbose(5, "------> tab1_col_x=" + str(tab1_col_x) + ", tab2_col_x=" + str(tab2_col_x)+ ", use1/2="+str(use_tab1)+"/"+str(use_tab2)) - - - # insert items from that col(s) - for i in range(num_rows): - y0 = tmp_rows[i] - y1 = tmp_rows[i+1] if i < num_rows - 1 else 9999999 - - list_idx = [] - if(use_tab1): - list_idx += tab1.find_marked_idx_between_y0_y1_at_col(0, 1, y0, y1, tab1_col, 1) - if(use_tab2): - list_idx += tab2.find_marked_idx_between_y0_y1_at_col(0, 2, y0, y1, tab2_col, 1) - - if(len(list_idx) > 1): - list_idx = sorted(list_idx, key=lambda id_i: tab1.items[tab1.idx[id_i[1]]].pos_y if id_i[0] == 1 else tab2.items[tab2.idx[id_i[1]]].pos_y ) - for k in range(len(list_idx)-1): - id, ix = list_idx[k] - id1, ix1 = list_idx[k+1] - it = tab1.items[tab1.idx[ix]] if id == 1 else tab2.items[tab2.idx[ix]] - it1 = tab1.items[tab1.idx[ix1]] if id1 == 1 else tab2.items[tab2.idx[ix1]] - if(not it.is_mergable(it1)): - #we found two items, that can't be merged => split row, and try again - if(it.pos_y == it1.pos_y): - #very strange case! should normally never occurence. bad can happen due to bad pdf formatting - print_verbose(6, "------>>> Bad case! 
Must rearrange item") - it1.pos_y += it1.height * 0.0001 - print_verbose(5, "-----> Split neccessary: " + str(it) + " cant be merged with " +str(it1)) - print_verbose(5, "-----> Split is here: " + str(tmp_rows[0:i+1]) +" <-> " +str(it1.pos_y) + " <-> "+ str(tmp_rows[i+1:])) - tmp_rows = tmp_rows[0:i+1] + [it1.pos_y] + tmp_rows[i+1:] - return False, [], tmp_rows - - #everything can be merged => merge - for k in range(len(list_idx)-1, 0, -1): - id, ix = list_idx[k] - id1, ix1 = list_idx[k-1] - it = tab1.items[tab1.idx[ix]] if id == 1 else tab2.items[tab2.idx[ix]] - it1 = tab1.items[tab1.idx[ix1]] if id1 == 1 else tab2.items[tab2.idx[ix1]] - it1.merge(it) - - if(len(list_idx) == 0): - tmp_items.append(-1) - else: - id, ix = list_idx[0] - tmp_items.append(tab1.idx[ix] if id == 1 else tab2.idx[ix]) - - tmp_cols.append(tmp_items) - - # continue with next col - if(use_tab1): - tab1_col += 1 - if(use_tab2): - tab2_col += 1 - - - return True, tmp_cols, tmp_rows - - - - - if(tab1.items != tab2.items): - raise ValueError('tab1 and tab2 belong to different HTMLPages') - - tmp_idx = tab1.idx + tab2.idx - tmp_idx = list(filter(lambda ix: ix != -1, tmp_idx)) - if(len(tmp_idx) != len(set(tmp_idx))): - raise ValueError('tab1 '+str(tab1.idx)+' and tab2 '+str(tab2.idx)+' intersect') - - # Find all new rows - print_verbose(5, "--> Finding rows") - tab1.reset_marks() - tab2.reset_marks() - - tmp_rows = [] - - threshold_px = DEFAULT_VTHRESHOLD * page_width - - while(tab1.count_marks(0) > 0 or tab2.count_marks(0) > 0): - idx1 = tab1.find_top_marked_idx(0) - idx2 = tab2.find_top_marked_idx(0) - if(idx1 == -1 and idx2 ==-1): - #strange! this should never happen! - raise ValueError('idx1 and idx2 are both -1, but marks=0 exist!') - - if(idx2 == -1 or (idx1 != -1 and tab1.get_item_by_ix(idx1).pos_y Continue with idx1="+str(idx1)+",idx2="+str(idx2)+", r="+str(cur_rect)+", list_idx="+str(list_idx)) - - - min_y = 9999999 - for (id, i) in list_idx: - cur_y = tab1.items[tab1.idx[i]].pos_y if id == 1 else tab2.items[tab2.idx[i]].pos_y - if(cur_y < min_y): - min_y = cur_y - tmp_rows.append(min_y) #(list_idx, min_y)) - - - print_verbose(5, "----> Rows: "+ str(tmp_rows)) - # Find all new columns - print_verbose(3, "--> Rows found, continuing with columns") - finding_cols_done = False - tmp_cols = [] - while(not finding_cols_done): - print_verbose(7, "--> Next try") - finding_cols_done, tmp_cols, tmp_rows = find_all_new_columns(tab1, tab2, tmp_rows, threshold_px) - print_verbose(7, "----> New Rows: "+ str(tmp_rows)) - - # Build resulting table - print_verbose(3, "--> Columns found, now build final table") - - res = HTMLTable() - - res.items = tab1.items - res.num_rows = len(tmp_rows) - res.num_cols = len(tmp_cols) - res.idx = [-1] * (res.num_rows * res.num_cols) - res.marks = [0] * (res.num_rows * res.num_cols) - - for j in range(res.num_cols): - for i in range(res.num_rows): - res.idx[i * res.num_cols + j] = tmp_cols[j][i] - - res.recalc_geometry() - - for it in tab1.items: - it.recalc_geometry() - - return res - - - #TODO: Remove: - """ + self.recalc_geometry() + + def find_top_marked_idx(self, mark): + res = -1 + pos_y = 9999999 + for i in range(len(self.idx)): + if self.has_item_at_ix(i): + cur_y = self.get_item_by_ix(i).pos_y + if cur_y < pos_y and self.marks[i] == mark: + pos_y = cur_y + res = i + return res + + def find_left_marked_idx(self, mark): + res = -1 + pos_x = 9999999 + for i in range(len(self.idx)): + if self.has_item_at_ix(i): + cur_x = self.get_item_by_ix(i).pos_x + if cur_x < pos_x and self.marks[i] == 
mark: + pos_x = cur_x + res = i + return res + + def find_marked_idx_at_y0(self, mark, id, y0, new_mark): + res = [] + for i in range(len(self.idx)): + if self.has_item_at_ix(i) and self.marks[i] == mark: + if self.get_item_by_ix(i).pos_y == y0: + res.append((id, i)) + self.marks[i] = new_mark + return res + + def find_marked_idx_between_y0_y1_at_col(self, mark, id, y0, y1, col, new_mark): + res = [] + for i in range(self.num_rows): + ix = self.get_ix(i, col) + if self.has_item_at_ix(ix) and self.marks[ix] == mark: + r = self.get_item_by_ix(ix).get_rect() + if r.y0 < y1 and r.y1 >= y0: + res.append((id, ix)) + self.marks[ix] = new_mark + return res + + @staticmethod + def merge(tab1, tab2, page_width): # note: tables must belong to same HTMLPage ! + def cols_are_mergable(tab1, tab2, col1, col2): + c1_x0 = 9999999 + c1_x1 = -1 + + c2_x0 = 9999999 + c2_x1 = -1 + + for i in range(tab1.num_rows): + c1_x0 = min(c1_x0, tab1.get_item(i, col1).pos_x if tab1.has_item_at(i, col1) else 9999999) + c1_x1 = max( + c1_x1, + tab1.get_item(i, col1).pos_x + tab1.get_item(i, col1).width if tab1.has_item_at(i, col1) else -1, + ) + + for i in range(tab2.num_rows): + c2_x0 = min(c2_x0, tab2.get_item(i, col2).pos_x if tab2.has_item_at(i, col2) else 9999999) + c2_x1 = max( + c2_x1, + tab2.get_item(i, col2).pos_x + tab2.get_item(i, col2).width if tab2.has_item_at(i, col2) else -1, + ) + + # print("MERGE : !!"+str(c1_x0)+" "+str(c1_x1)+" <-> "+str(c2_x0)+" "+str(c2_x1)+" ") + + if c1_x0 < c2_x0: + if not c1_x1 > c2_x0: + return False + else: + if not c2_x1 > c1_x0: + return False + + # make sure that now rows are overlapping + + row1 = 0 + row2 = 0 + while row1 < tab1.num_rows and row2 < tab2.num_rows: + if tab1.has_item_at(row1, col1) and tab2.has_item_at(row2, col2): + it1 = tab1.get_item(row1, col1) + it2 = tab2.get_item(row2, col2) + if min(it1.pos_y + it1.height, it2.pos_y + it2.height) - max(it1.pos_y, it2.pos_y) >= 0: + return False # overlap + if it1.pos_y + it1.height < it2.pos_y + it2.height: + row1 += 1 + else: + row2 += 1 + elif tab1.has_item_at(row1, col1): + row2 += 1 + else: + row1 += 1 + + return True # no overlaps found + + def find_all_new_columns(tab1, tab2, tmp_rows, threshold_px): + tab1.reset_marks() + tab2.reset_marks() + num_rows = len(tmp_rows) + + tmp_cols = [] + tab1_col = 0 + tab2_col = 0 + + while tab1_col < tab1.num_cols or tab2_col < tab2.num_cols: + tmp_items = [] + # find leftmost col + print_verbose(5, "----> tab1_col=" + str(tab1_col) + ", tab2_col=" + str(tab2_col)) + tab1_col_x = tab1.col_aligned_pos_x[tab1_col] if tab1_col < tab1.num_cols else 9999999 + tab2_col_x = tab2.col_aligned_pos_x[tab2_col] if tab2_col < tab2.num_cols else 9999999 + use_tab1 = tab1_col_x <= tab2_col_x + threshold_px + use_tab2 = tab2_col_x <= tab1_col_x + threshold_px + if ( + use_tab1 and use_tab2 and not cols_are_mergable(tab1, tab2, tab1_col, tab2_col) + ): # TODO: and cols do not intersect by any rect + # cols cant be merged => use only first one + use_tab1 = tab1_col_x < tab2_col_x + use_tab2 = not use_tab1 + + print_verbose( + 5, + "------> tab1_col_x=" + + str(tab1_col_x) + + ", tab2_col_x=" + + str(tab2_col_x) + + ", use1/2=" + + str(use_tab1) + + "/" + + str(use_tab2), + ) + + # insert items from that col(s) + for i in range(num_rows): + y0 = tmp_rows[i] + y1 = tmp_rows[i + 1] if i < num_rows - 1 else 9999999 + + list_idx = [] + if use_tab1: + list_idx += tab1.find_marked_idx_between_y0_y1_at_col(0, 1, y0, y1, tab1_col, 1) + if use_tab2: + list_idx += 
tab2.find_marked_idx_between_y0_y1_at_col(0, 2, y0, y1, tab2_col, 1) + + if len(list_idx) > 1: + list_idx = sorted( + list_idx, + key=lambda id_i: tab1.items[tab1.idx[id_i[1]]].pos_y + if id_i[0] == 1 + else tab2.items[tab2.idx[id_i[1]]].pos_y, + ) + for k in range(len(list_idx) - 1): + id, ix = list_idx[k] + id1, ix1 = list_idx[k + 1] + it = tab1.items[tab1.idx[ix]] if id == 1 else tab2.items[tab2.idx[ix]] + it1 = tab1.items[tab1.idx[ix1]] if id1 == 1 else tab2.items[tab2.idx[ix1]] + if not it.is_mergable(it1): + # we found two items, that can't be merged => split row, and try again + if it.pos_y == it1.pos_y: + # very strange case! should normally never occurence. bad can happen due to bad pdf formatting + print_verbose(6, "------>>> Bad case! Must rearrange item") + it1.pos_y += it1.height * 0.0001 + print_verbose( + 5, "-----> Split neccessary: " + str(it) + " cant be merged with " + str(it1) + ) + print_verbose( + 5, + "-----> Split is here: " + + str(tmp_rows[0 : i + 1]) + + " <-> " + + str(it1.pos_y) + + " <-> " + + str(tmp_rows[i + 1 :]), + ) + tmp_rows = tmp_rows[0 : i + 1] + [it1.pos_y] + tmp_rows[i + 1 :] + return False, [], tmp_rows + + # everything can be merged => merge + for k in range(len(list_idx) - 1, 0, -1): + id, ix = list_idx[k] + id1, ix1 = list_idx[k - 1] + it = tab1.items[tab1.idx[ix]] if id == 1 else tab2.items[tab2.idx[ix]] + it1 = tab1.items[tab1.idx[ix1]] if id1 == 1 else tab2.items[tab2.idx[ix1]] + it1.merge(it) + + if len(list_idx) == 0: + tmp_items.append(-1) + else: + id, ix = list_idx[0] + tmp_items.append(tab1.idx[ix] if id == 1 else tab2.idx[ix]) + + tmp_cols.append(tmp_items) + + # continue with next col + if use_tab1: + tab1_col += 1 + if use_tab2: + tab2_col += 1 + + return True, tmp_cols, tmp_rows + + if tab1.items != tab2.items: + raise ValueError("tab1 and tab2 belong to different HTMLPages") + + tmp_idx = tab1.idx + tab2.idx + tmp_idx = list(filter(lambda ix: ix != -1, tmp_idx)) + if len(tmp_idx) != len(set(tmp_idx)): + raise ValueError("tab1 " + str(tab1.idx) + " and tab2 " + str(tab2.idx) + " intersect") + + # Find all new rows + print_verbose(5, "--> Finding rows") + tab1.reset_marks() + tab2.reset_marks() + + tmp_rows = [] + + threshold_px = DEFAULT_VTHRESHOLD * page_width + + while tab1.count_marks(0) > 0 or tab2.count_marks(0) > 0: + idx1 = tab1.find_top_marked_idx(0) + idx2 = tab2.find_top_marked_idx(0) + if idx1 == -1 and idx2 == -1: + # strange! this should never happen! 
+ raise ValueError("idx1 and idx2 are both -1, but marks=0 exist!") + + if idx2 == -1 or (idx1 != -1 and tab1.get_item_by_ix(idx1).pos_y < tab2.get_item_by_ix(idx2).pos_y): + cur_rect = tab1.get_item_by_ix(idx1).get_rect() + else: + cur_rect = tab2.get_item_by_ix(idx2).get_rect() + + list_idx = tab1.find_marked_idx_at_y0(0, 1, cur_rect.y0, 1) + list_idx += tab2.find_marked_idx_at_y0(0, 2, cur_rect.y0, 1) + + print_verbose( + 9, + "----> Continue with idx1=" + + str(idx1) + + ",idx2=" + + str(idx2) + + ", r=" + + str(cur_rect) + + ", list_idx=" + + str(list_idx), + ) + + min_y = 9999999 + for id, i in list_idx: + cur_y = tab1.items[tab1.idx[i]].pos_y if id == 1 else tab2.items[tab2.idx[i]].pos_y + if cur_y < min_y: + min_y = cur_y + tmp_rows.append(min_y) # (list_idx, min_y)) + + print_verbose(5, "----> Rows: " + str(tmp_rows)) + # Find all new columns + print_verbose(3, "--> Rows found, continuing with columns") + finding_cols_done = False + tmp_cols = [] + while not finding_cols_done: + print_verbose(7, "--> Next try") + finding_cols_done, tmp_cols, tmp_rows = find_all_new_columns(tab1, tab2, tmp_rows, threshold_px) + print_verbose(7, "----> New Rows: " + str(tmp_rows)) + + # Build resulting table + print_verbose(3, "--> Columns found, now build final table") + + res = HTMLTable() + + res.items = tab1.items + res.num_rows = len(tmp_rows) + res.num_cols = len(tmp_cols) + res.idx = [-1] * (res.num_rows * res.num_cols) + res.marks = [0] * (res.num_rows * res.num_cols) + + for j in range(res.num_cols): + for i in range(res.num_rows): + res.idx[i * res.num_cols + j] = tmp_cols[j][i] + + res.recalc_geometry() + + for it in tab1.items: + it.recalc_geometry() + + return res + + # TODO: Remove: + """ def get_printed_repr(self): COL_WIDTH = 10 @@ -1863,301 +1891,317 @@ def get_printed_repr(self): return res """ - - def is_compatible_with_existing_row(self, r0, it0): - min_y0 = 9999999 - max_y1 = -1 - max_height = 0 - for j in range(self.num_cols): - if(self.has_non_empty_item_at(r0, j)): - it = self.get_item(r0, j) - min_y0 = min(min_y0, it.pos_y) - max_y1 = max(max_y1, it.pos_y + it.height) - max_height = max(max_height, it.height) - - y0 = it0.pos_y - y1 = it0.pos_y + it0.height - if(y1 < min_y0): - return min_y0 - y1 < 0.5 * max_height - if(y0 > max_y1): - return y0 - max_y1 < 0.5 * max_height - return True - - - def force_special_items_into_table(self): - new_special_idx = [] - for idx in self.special_idx: - ix = self.find_nearest_cell_ix(self.items[idx]) - r, c = self.get_row_and_col_by_ix(ix) - if(not(self.has_non_empty_item_at(r, c)) and self.is_compatible_with_existing_row(r, self.items[idx])): - #free cell - self.idx[ix] = idx - else: - #no free cell. can we insert a new row? - if(self.is_row_insertion_possible(r, self.items[idx].pos_y)): - self.insert_row(r) - self.idx[self.get_ix(r+1, c)] = idx - self.recalc_geometry() - else: - #no empty space. 
leave as special item - new_special_idx.append(idx) - self.special_idx = new_special_idx - - self.recalc_geometry() - self.compactify() - - - - def is_non_overlapping_row_mergable(self, r0): # return True iff row r0 and r0+1 can be merged - if(r0 < 0 or r0 >= self.num_rows - 1): - raise ValueError('Rows r0='+str(r0)+' and r0+1 out of range.') - - - # only one row is allowed to contain numbers - n0 = False - n1 = False - for j in range(self.num_cols): - if(self.has_item_at(r0,j) and Format_Analyzer.looks_numeric(self.get_item(r0,j).txt)): - n0 = True - if(self.has_item_at(r0+1,j) and Format_Analyzer.looks_numeric(self.get_item(r0+1,j).txt)): - n1 = True - - if(n0 and n1): - print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1 have both numbers') - return False,0 # both rows contain numbers - - #if(config.global_table_merging_only_if_numbers_come_first and not n0): - # print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1, the first one has no numbers.') - # return False,0 # - - - y0_max = 0 - y1_min = 9999999 - - has_mergable_candidates = False - font_chars = "" - for j in range(self.num_cols): - both_filled = self.has_item_at(r0, j) and self.has_item_at(r0+1, j) - if(both_filled and not self.get_item(r0,j).is_weakly_mergable_after_reconnect(self.get_item(r0+1,j))): - print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1 are not mergable. Items:' \ - + str(self.get_item(r0,j)) + ' and ' + str(self.get_item(r0+1,j))) - return False,0 - if(both_filled): - has_mergable_candidates = True - font_chars = self.get_item(r0,j).get_font_characteristics() - if(self.get_item(r0,j).pos_y + self.get_item(r0,j).height*(3.0 if n0 else 2.1) < self.get_item(r0+1,j).pos_y): - print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1 are too far apart. Items:' \ - + str(self.get_item(r0,j)) + ' and ' + str(self.get_item(r0+1,j))) - return False,0 - cur_y0 = self.get_item(r0, j).pos_y + self.get_item(r0, j).height if self.has_item_at(r0, j) else 0 - cur_y1 = self.get_item(r0+1, j).pos_y if self.has_item_at(r0+1, j) else 9999999 - y0_max = max(y0_max, cur_y0) - y1_min = min(y1_min, cur_y1) - - if(not has_mergable_candidates): - print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1 have no mergable candidates') - return False,0 - - if(y0_max >= y1_min): - print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1 would overlap') - return False,0 - - # make sure, all same font characteristics - for j in range(self.num_cols): - if((self.has_item_at(r0, j) and self.get_item(r0,j).get_font_characteristics() != font_chars) \ - or (self.has_item_at(r0+1, j) and self.get_item(r0+1,j).get_font_characteristics() != font_chars)): - print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1 have different font chars') - return False,0 - - mod_dist = math.floor((y1_min-y0_max)*100.0)/100.0 * (0.66 if n0 else 1) - print_verbose(8, "---->>> For row="+str(r0)+", mod_dist="+str(mod_dist) + ", y0_max=" +str(y0_max) + ",y1_min="+str(y1_min)) - return True, mod_dist # TODO: maybe use only the distance, where both rows are filled? 
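
is_non_overlapping_row_mergable above decides whether two vertically adjacent rows may be merged: at most one of the rows may contain numbers, the filled cells must not overlap or sit too far apart, and all cells must share font characteristics; it returns a modified distance so the caller can merge the closest pair first. A simplified stand-alone sketch under those assumptions (font and per-column text checks omitted, helper names hypothetical):

import re

def looks_numeric(txt):
    return bool(re.search(r"\d", txt))

def rows_mergable(row_a, row_b):
    # row_a / row_b: lists of (txt, y0, y1) per column; None marks an empty cell.
    has_num_a = any(c and looks_numeric(c[0]) for c in row_a)
    has_num_b = any(c and looks_numeric(c[0]) for c in row_b)
    if has_num_a and has_num_b:
        return False, 0.0                     # both rows carry numbers
    gap_top = max((c[2] for c in row_a if c), default=0.0)
    gap_bottom = min((c[1] for c in row_b if c), default=float("inf"))
    if gap_bottom <= gap_top:
        return False, 0.0                     # rows would overlap vertically
    # Smaller modified distance = better candidate; a numbers-first upper row
    # gets a small bonus, mirroring the 0.66 factor in the method above.
    mod_dist = (gap_bottom - gap_top) * (0.66 if has_num_a else 1.0)
    return True, mod_dist
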
- - - - def merge_non_overlapping_rows_single(self): - # this will merge rows, even if both them contain text, by appending the text - merge_list = [] - min_dist = 9999999 - min_row = -1 - for i in range(self.num_rows-1): - is_merg, dist = self.is_non_overlapping_row_mergable(i) - if(is_merg): - merge_list.append(i) - if(dist> Found year lists at: ' +str(year_cols)) - - - yl = [] - for yc in year_cols: - yl.append((yc.c0, yc.c1)) - - yl = list(set(yl)) - - res = [] - for y in yl: - if(y[0] - 1 > 0): - cur_tab = deepcopy(self) - if(y[1]+1 < cur_tab.num_cols): - cur_tab.delete_cols(y[1]+1, cur_tab.num_cols) - cur_tab.delete_cols(0, y[0]-1) - print_verbose(6, "Found sub-table:\n"+str(cur_tab.get_printed_repr())) - res.append(cur_tab) - - return res - - - - - def save_to_csv(self, csv_file): - ctab = ConsoleTable(self.num_cols) - - for ix in range(self.num_rows * self.num_cols): - ctab.cells.append(self.items[self.idx[ix]].txt if self.has_item_at_ix(ix) else '') - - res = ctab.to_string(use_format = ConsoleTable.FORMAT_CSV) - - save_txt_to_file(res, csv_file) - - - - - def get_printed_repr(self): - COL_WIDTH = 10 - - special_printed = [False] * len(self.special_idx) - - res = '' - # table headline title - for i in self.headline_idx: - res += '===> ' + self.items[i].txt + '<===\n' - - - # headline - res += '\u2554' - for j in range(self.num_cols): - res += '\u2550' * (COL_WIDTH-1) - res += '\u2566' if j < self.num_cols - 1 else '\u2557\n' - - # content - for i in range(self.num_rows): - # frame line - if(i>0): - res += '\u2560' - for j in range(self.num_cols): - res += '\u2550'*(COL_WIDTH-1) - res += '\u256c' if j < self.num_cols -1 else '\u2563' - res += '\n' - - # content line - res += '\u2551' - for j in range(self.num_cols): - if(self.has_item_at(i,j)): - txt = self.get_item(i, j).txt.replace('\n', ' ') - res += str(txt)[:(COL_WIDTH-1)].ljust(COL_WIDTH-1, ' ') - else: - res +=' '.ljust(COL_WIDTH-1, ' ') - res += '\u2551' - - # special items - sp_ix = self.find_applying_special_item_ix(i) - if(sp_ix is not None and special_printed[sp_ix] == False): - special_printed[sp_ix] = True - res += ' * ' + self.items[self.special_idx[sp_ix]].txt[:COL_WIDTH] - - - - res += '\n' - - # footer line - res += '\u255a' - for j in range(self.num_cols): - res += '\u2550' * (COL_WIDTH-1) - res += '\u2569' if j < self.num_cols - 1 else '\u255d\n' - - - - return res - - - - - def __repr__(self): - res = 'Row-Dim: ' + str(self.rows) - res += '\nCol-Dim: ' +str(self.cols) - res += '\nSpecial-Items: ' - for k in self.special_idx: - res += str(self.items[k]) + ' ' - res += '\n' + self.get_printed_repr() - return res - - - - - - - + def is_compatible_with_existing_row(self, r0, it0): + min_y0 = 9999999 + max_y1 = -1 + max_height = 0 + for j in range(self.num_cols): + if self.has_non_empty_item_at(r0, j): + it = self.get_item(r0, j) + min_y0 = min(min_y0, it.pos_y) + max_y1 = max(max_y1, it.pos_y + it.height) + max_height = max(max_height, it.height) + + y0 = it0.pos_y + y1 = it0.pos_y + it0.height + if y1 < min_y0: + return min_y0 - y1 < 0.5 * max_height + if y0 > max_y1: + return y0 - max_y1 < 0.5 * max_height + return True + + def force_special_items_into_table(self): + new_special_idx = [] + for idx in self.special_idx: + ix = self.find_nearest_cell_ix(self.items[idx]) + r, c = self.get_row_and_col_by_ix(ix) + if not (self.has_non_empty_item_at(r, c)) and self.is_compatible_with_existing_row(r, self.items[idx]): + # free cell + self.idx[ix] = idx + else: + # no free cell. can we insert a new row? 
+ if self.is_row_insertion_possible(r, self.items[idx].pos_y): + self.insert_row(r) + self.idx[self.get_ix(r + 1, c)] = idx + self.recalc_geometry() + else: + # no empty space. leave as special item + new_special_idx.append(idx) + self.special_idx = new_special_idx + + self.recalc_geometry() + self.compactify() + + def is_non_overlapping_row_mergable(self, r0): # return True iff row r0 and r0+1 can be merged + if r0 < 0 or r0 >= self.num_rows - 1: + raise ValueError("Rows r0=" + str(r0) + " and r0+1 out of range.") + + # only one row is allowed to contain numbers + n0 = False + n1 = False + for j in range(self.num_cols): + if self.has_item_at(r0, j) and Format_Analyzer.looks_numeric(self.get_item(r0, j).txt): + n0 = True + if self.has_item_at(r0 + 1, j) and Format_Analyzer.looks_numeric(self.get_item(r0 + 1, j).txt): + n1 = True + + if n0 and n1: + print_verbose( + 8, "--->> is_non_overlapping_row_mergable: Rows r0=" + str(r0) + " and r0+1 have both numbers" + ) + return False, 0 # both rows contain numbers + + # if(config.global_table_merging_only_if_numbers_come_first and not n0): + # print_verbose(8, '--->> is_non_overlapping_row_mergable: Rows r0='+str(r0)+' and r0+1, the first one has no numbers.') + # return False,0 # + + y0_max = 0 + y1_min = 9999999 + + has_mergable_candidates = False + font_chars = "" + for j in range(self.num_cols): + both_filled = self.has_item_at(r0, j) and self.has_item_at(r0 + 1, j) + if both_filled and not self.get_item(r0, j).is_weakly_mergable_after_reconnect(self.get_item(r0 + 1, j)): + print_verbose( + 8, + "--->> is_non_overlapping_row_mergable: Rows r0=" + + str(r0) + + " and r0+1 are not mergable. Items:" + + str(self.get_item(r0, j)) + + " and " + + str(self.get_item(r0 + 1, j)), + ) + return False, 0 + if both_filled: + has_mergable_candidates = True + font_chars = self.get_item(r0, j).get_font_characteristics() + if ( + self.get_item(r0, j).pos_y + self.get_item(r0, j).height * (3.0 if n0 else 2.1) + < self.get_item(r0 + 1, j).pos_y + ): + print_verbose( + 8, + "--->> is_non_overlapping_row_mergable: Rows r0=" + + str(r0) + + " and r0+1 are too far apart. 
Items:" + + str(self.get_item(r0, j)) + + " and " + + str(self.get_item(r0 + 1, j)), + ) + return False, 0 + cur_y0 = self.get_item(r0, j).pos_y + self.get_item(r0, j).height if self.has_item_at(r0, j) else 0 + cur_y1 = self.get_item(r0 + 1, j).pos_y if self.has_item_at(r0 + 1, j) else 9999999 + y0_max = max(y0_max, cur_y0) + y1_min = min(y1_min, cur_y1) + + if not has_mergable_candidates: + print_verbose( + 8, "--->> is_non_overlapping_row_mergable: Rows r0=" + str(r0) + " and r0+1 have no mergable candidates" + ) + return False, 0 + + if y0_max >= y1_min: + print_verbose(8, "--->> is_non_overlapping_row_mergable: Rows r0=" + str(r0) + " and r0+1 would overlap") + return False, 0 + + # make sure, all same font characteristics + for j in range(self.num_cols): + if (self.has_item_at(r0, j) and self.get_item(r0, j).get_font_characteristics() != font_chars) or ( + self.has_item_at(r0 + 1, j) and self.get_item(r0 + 1, j).get_font_characteristics() != font_chars + ): + print_verbose( + 8, + "--->> is_non_overlapping_row_mergable: Rows r0=" + str(r0) + " and r0+1 have different font chars", + ) + return False, 0 + + mod_dist = math.floor((y1_min - y0_max) * 100.0) / 100.0 * (0.66 if n0 else 1) + print_verbose( + 8, + "---->>> For row=" + + str(r0) + + ", mod_dist=" + + str(mod_dist) + + ", y0_max=" + + str(y0_max) + + ",y1_min=" + + str(y1_min), + ) + return True, mod_dist # TODO: maybe use only the distance, where both rows are filled? + + def merge_non_overlapping_rows_single(self): + # this will merge rows, even if both them contain text, by appending the text + merge_list = [] + min_dist = 9999999 + min_row = -1 + for i in range(self.num_rows - 1): + is_merg, dist = self.is_non_overlapping_row_mergable(i) + if is_merg: + merge_list.append(i) + if dist < min_dist: + min_dist = min(min_dist, dist) + min_row = i + + print_verbose(7, "The following non-overlapping rows could be merged: " + str(merge_list)) + if len(merge_list) == 0: + return False # nothing was merged + + print_verbose(7, "We merge now " + str(min_row)) + + print_verbose(9, "Before merging: " + self.get_printed_repr()) + self.merge_rows(min_row, True) + print_verbose(9, "After merging rows " + str(min_row) + " and r+1: " + self.get_printed_repr()) + + self.recalc_geometry() + print_verbose(9, "After recalc geometry: " + self.get_printed_repr()) + + return True # we merged something! 
+ + def merge_non_overlapping_rows(self): + print_verbose(7, "merge_non_overlapping_rows, Table=" + self.get_printed_repr()) + while True: + if not self.merge_non_overlapping_rows_single(): + return + + def generate_sub_tables(self): # generates sub-tables that are helpful for the AnalyzerTable, based on year-columns + class YearCols: + r = None + c0 = None + c1 = None + + def __init__(self, r, c0): + self.r = r + self.c0 = c0 + + def __repr__(self): + return "(r=" + str(self.r) + ",c0=" + str(self.c0) + ",c1=" + str(self.c1) + ")" + + year_cols = [] + + for i in range(self.num_rows): + j = 0 + while j < self.num_cols - 1: + if ( + self.has_item_at(i, j) + and self.has_item_at(i, j + 1) + and Format_Analyzer.looks_year(self.get_item(i, j).txt) + and Format_Analyzer.looks_year(self.get_item(i, j + 1).txt) + and abs( + Format_Analyzer.to_year(self.get_item(i, j + 1).txt) + - Format_Analyzer.to_year(self.get_item(i, j).txt) + ) + == 1 + ): + dir = Format_Analyzer.to_year(self.get_item(i, j + 1).txt) - Format_Analyzer.to_year( + self.get_item(i, j).txt + ) + cur_year_cols = YearCols(i, j) + # find last year col + # print("Now at cell:"+str(i)+","+str(j)) + for j1 in range(j + 1, self.num_cols): + if ( + self.has_item_at(i, j1) + and Format_Analyzer.looks_year(self.get_item(i, j1).txt) + and Format_Analyzer.to_year(self.get_item(i, j1).txt) + - Format_Analyzer.to_year(self.get_item(i, j1 - 1).txt) + == dir + ): + cur_year_cols.c1 = j1 + else: + break + year_cols.append(cur_year_cols) + j = cur_year_cols.c1 + j += 1 + + print_verbose(6, "----->> Found year lists at: " + str(year_cols)) + + yl = [] + for yc in year_cols: + yl.append((yc.c0, yc.c1)) + + yl = list(set(yl)) + + res = [] + for y in yl: + if y[0] - 1 > 0: + cur_tab = deepcopy(self) + if y[1] + 1 < cur_tab.num_cols: + cur_tab.delete_cols(y[1] + 1, cur_tab.num_cols) + cur_tab.delete_cols(0, y[0] - 1) + print_verbose(6, "Found sub-table:\n" + str(cur_tab.get_printed_repr())) + res.append(cur_tab) + + return res + + def save_to_csv(self, csv_file): + ctab = ConsoleTable(self.num_cols) + + for ix in range(self.num_rows * self.num_cols): + ctab.cells.append(self.items[self.idx[ix]].txt if self.has_item_at_ix(ix) else "") + + res = ctab.to_string(use_format=ConsoleTable.FORMAT_CSV) + + save_txt_to_file(res, csv_file) + + def get_printed_repr(self): + COL_WIDTH = 10 + + special_printed = [False] * len(self.special_idx) + + res = "" + # table headline title + for i in self.headline_idx: + res += "===> " + self.items[i].txt + "<===\n" + + # headline + res += "\u2554" + for j in range(self.num_cols): + res += "\u2550" * (COL_WIDTH - 1) + res += "\u2566" if j < self.num_cols - 1 else "\u2557\n" + + # content + for i in range(self.num_rows): + # frame line + if i > 0: + res += "\u2560" + for j in range(self.num_cols): + res += "\u2550" * (COL_WIDTH - 1) + res += "\u256c" if j < self.num_cols - 1 else "\u2563" + res += "\n" + + # content line + res += "\u2551" + for j in range(self.num_cols): + if self.has_item_at(i, j): + txt = self.get_item(i, j).txt.replace("\n", " ") + res += str(txt)[: (COL_WIDTH - 1)].ljust(COL_WIDTH - 1, " ") + else: + res += " ".ljust(COL_WIDTH - 1, " ") + res += "\u2551" + + # special items + sp_ix = self.find_applying_special_item_ix(i) + if sp_ix is not None and special_printed[sp_ix] == False: + special_printed[sp_ix] = True + res += " * " + self.items[self.special_idx[sp_ix]].txt[:COL_WIDTH] + + res += "\n" + + # footer line + res += "\u255a" + for j in range(self.num_cols): + res += "\u2550" * (COL_WIDTH - 1) + res += 
"\u2569" if j < self.num_cols - 1 else "\u255d\n" + + return res + + def __repr__(self): + res = "Row-Dim: " + str(self.rows) + res += "\nCol-Dim: " + str(self.cols) + res += "\nSpecial-Items: " + for k in self.special_idx: + res += str(self.items[k]) + " " + res += "\n" + self.get_printed_repr() + return res diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLWord.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLWord.py index 3cedd4f..351e251 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLWord.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/HTMLWord.py @@ -12,13 +12,11 @@ class HTMLWord: - txt = None - rect = None - item_id = None # to which HTMLItem id does this word belong? - - - def __init__(self): - self.txt = '' - self.rect = Rect(99999,99999,-1,-1) - self.item_id = -1 + txt = None + rect = None + item_id = None # to which HTMLItem id does this word belong? + def __init__(self): + self.txt = "" + self.rect = Rect(99999, 99999, -1, -1) + self.item_id = -1 diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIMeasure.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIMeasure.py index 7456819..466c30e 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIMeasure.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIMeasure.py @@ -9,145 +9,153 @@ from globals import * from Format_Analyzer import * -class KPIMeasure: - kpi_id = None - kpi_name = None - src_file = None - src_path = None - company_name = None - page_num = None - item_ids = None - pos_x = None - pos_y = None - raw_txt = None - year = None - value = None - score = None - unit = None - match_type = None - tmp = None # for temporary values used by an Analyzer - - - def __init__(self): - self.kpi_id = -1 - self.kpi_name = '' - self.src_file = '' - self.src_path = '' - self.company_name = '' - self.page_num = -1 - self.item_ids = [] - self.pos_x = -1 - self.pos_y = -1 - self.raw_txt = '' - self.year = -1 - self.value = '' - self.score = -1 - self.unit = '' - self.match_type = '' - self.tmp = None - - def set_file_path(self, file_path): - fp = file_path.replace('\\', '/') - self.src_file = Format_Analyzer.extract_file_name(fp) - self.src_path = remove_trailing_slash(Format_Analyzer.extract_file_path(fp)[0]) - self.company_name = self.src_path[self.src_path.rfind('/')+1:] - - def __repr__(self): - return "" - - - - @staticmethod - def remove_all_years(lst): - res = [] - for it in lst: - it.year=-1 - res.append(it) - return res - - @staticmethod - def remove_duplicates(lst): # from the list of KPIMeasure "lst", remove all duplicates (same kpi_name and year), with less than best score - - #return lst # Dont remoe anything (for debug purposes only) - - keep = [False] * len(lst) - for i in range(len(lst)): - better_kpi_exists = False - for j in range(len(lst)): - if(i==j): - continue - if(lst[j].kpi_name == lst[i].kpi_name and lst[j].year == lst[i].year and (lst[j].score > lst[i].score or (lst[j].score == lst[i].score and j > i ))): - better_kpi_exists = True - break - keep[i] = not better_kpi_exists - - res = [] - for i in range(len(lst)): - if(keep[i]): - res.append(lst[i]) - - return res - - @staticmethod - def remove_bad_scores(lst, minimum_score): - - - #return lst # Dont remoe anything (for debug purposes only) - - res = [] - - max_score = {} - - for k in lst: - max_score[k.kpi_name] = k.score if not k.kpi_name in max_score else max(k.score, 
max_score[k.kpi_name]) - - for k in lst: - if(k.score >= minimum_score and k.score >= max_score[k.kpi_name] * 0.75 ): - res.append(k) - - return res - - @staticmethod - def remove_bad_years(lst, default_year): - - - # do we have entries with a year? - year_exist = [] - - for k in lst: - if(k.year != -1): - year_exist.append(k.kpi_name) - break - - - res = [] - year_exist = list(set(year_exist)) - - - - for k in lst: - if(k.year == -1): - if(not k.kpi_name in year_exist): - k.year = default_year - res.append(k) - - else: - res.append(k) - - return res - - - - - - - - - - - - - +class KPIMeasure: + kpi_id = None + kpi_name = None + src_file = None + src_path = None + company_name = None + page_num = None + item_ids = None + pos_x = None + pos_y = None + raw_txt = None + year = None + value = None + score = None + unit = None + match_type = None + tmp = None # for temporary values used by an Analyzer + + def __init__(self): + self.kpi_id = -1 + self.kpi_name = "" + self.src_file = "" + self.src_path = "" + self.company_name = "" + self.page_num = -1 + self.item_ids = [] + self.pos_x = -1 + self.pos_y = -1 + self.raw_txt = "" + self.year = -1 + self.value = "" + self.score = -1 + self.unit = "" + self.match_type = "" + self.tmp = None + + def set_file_path(self, file_path): + fp = file_path.replace("\\", "/") + self.src_file = Format_Analyzer.extract_file_name(fp) + self.src_path = remove_trailing_slash(Format_Analyzer.extract_file_path(fp)[0]) + self.company_name = self.src_path[self.src_path.rfind("/") + 1 :] + + def __repr__(self): + return ( + "" + ) + + @staticmethod + def remove_all_years(lst): + res = [] + for it in lst: + it.year = -1 + res.append(it) + return res + + @staticmethod + def remove_duplicates( + lst, + ): # from the list of KPIMeasure "lst", remove all duplicates (same kpi_name and year), with less than best score + # return lst # Dont remoe anything (for debug purposes only) + + keep = [False] * len(lst) + for i in range(len(lst)): + better_kpi_exists = False + for j in range(len(lst)): + if i == j: + continue + if ( + lst[j].kpi_name == lst[i].kpi_name + and lst[j].year == lst[i].year + and (lst[j].score > lst[i].score or (lst[j].score == lst[i].score and j > i)) + ): + better_kpi_exists = True + break + keep[i] = not better_kpi_exists + + res = [] + for i in range(len(lst)): + if keep[i]: + res.append(lst[i]) + + return res + + @staticmethod + def remove_bad_scores(lst, minimum_score): + # return lst # Dont remoe anything (for debug purposes only) + + res = [] + + max_score = {} + + for k in lst: + max_score[k.kpi_name] = k.score if not k.kpi_name in max_score else max(k.score, max_score[k.kpi_name]) + + for k in lst: + if k.score >= minimum_score and k.score >= max_score[k.kpi_name] * 0.75: + res.append(k) + + return res + + @staticmethod + def remove_bad_years(lst, default_year): + # do we have entries with a year? 
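+        # Note: the loop below stops at the first entry that carries a year, so
+        # year_exist holds at most one kpi_name. Entries without a year are then kept
+        # and given default_year only when their kpi_name is not in year_exist;
+        # otherwise they are dropped, while entries that already have a year are always kept.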
+ year_exist = [] + + for k in lst: + if k.year != -1: + year_exist.append(k.kpi_name) + break + + res = [] + year_exist = list(set(year_exist)) + + for k in lst: + if k.year == -1: + if not k.kpi_name in year_exist: + k.year = default_year + res.append(k) + + else: + res.append(k) + + return res diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIResultSet.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIResultSet.py index d270ae3..8d64e79 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIResultSet.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPIResultSet.py @@ -10,94 +10,85 @@ from KPIMeasure import * from ConsoleTable import * + class KPIResultSet: + kpimeasures = None + + def __init__(self, kpimeasures=[]): + self.kpimeasures = kpimeasures + + def extend(self, kpiresultset): + self.kpimeasures.extend(kpiresultset.kpimeasures) + + def to_ctab(self): + ctab = ConsoleTable(13) + ctab.cells.append("KPI_ID") + ctab.cells.append("KPI_NAME") + ctab.cells.append("SRC_FILE") + ctab.cells.append("PAGE_NUM") + ctab.cells.append("ITEM_IDS") + ctab.cells.append("POS_X") + ctab.cells.append("POS_Y") + ctab.cells.append("RAW_TXT") + ctab.cells.append("YEAR") + ctab.cells.append("VALUE") + ctab.cells.append("SCORE") + ctab.cells.append("UNIT") + ctab.cells.append("MATCH_TYPE") + + for k in self.kpimeasures: + ctab.cells.append(str(k.kpi_id)) + ctab.cells.append(str(k.kpi_name)) + ctab.cells.append(str(k.src_file)) + ctab.cells.append(str(k.page_num)) + ctab.cells.append(str(k.item_ids)) + ctab.cells.append(str(k.pos_x)) + ctab.cells.append(str(k.pos_y)) + ctab.cells.append(str(k.raw_txt)) + ctab.cells.append(str(k.year)) + ctab.cells.append(str(k.value)) + ctab.cells.append(str(k.score)) + ctab.cells.append(str(k.unit)) + ctab.cells.append(str(k.match_type)) + + return ctab + + def to_string(self, max_width, min_col_width): + ctab = self.to_ctab() + return ctab.to_string(max_width, min_col_width) + + def to_json(self): + jsonpickle.set_preferred_backend("json") + jsonpickle.set_encoder_options("json", sort_keys=True, indent=4) + data = jsonpickle.encode(self) + + return data + + def save_to_file(self, json_file): + data = self.to_json() + f = open(json_file, "w") + f.write(data) + f.close() + + def save_to_csv_file(self, csv_file): + ctab = self.to_ctab() + csv_str = ctab.to_string(use_format=ConsoleTable.FORMAT_CSV) + + f = open(csv_file, "w", encoding="utf-8") + f.write(csv_str) + f.close() + + @staticmethod + def load_from_json(data): + obj = jsonpickle.decode(data) + return obj + + @staticmethod + def load_from_file(json_file): + f = open(json_file, "r") + data = f.read() + f.close() + return KPIResultSet.load_from_json(data) - kpimeasures = None - - def __init__(self, kpimeasures = []): - self.kpimeasures = kpimeasures - - def extend(self, kpiresultset): - self.kpimeasures.extend(kpiresultset.kpimeasures) - - def to_ctab(self): - ctab = ConsoleTable(13) - ctab.cells.append('KPI_ID') - ctab.cells.append('KPI_NAME') - ctab.cells.append('SRC_FILE') - ctab.cells.append('PAGE_NUM') - ctab.cells.append('ITEM_IDS') - ctab.cells.append('POS_X') - ctab.cells.append('POS_Y') - ctab.cells.append('RAW_TXT') - ctab.cells.append('YEAR') - ctab.cells.append('VALUE') - ctab.cells.append('SCORE') - ctab.cells.append('UNIT') - ctab.cells.append('MATCH_TYPE') - - for k in self.kpimeasures: - ctab.cells.append(str(k.kpi_id )) - ctab.cells.append(str(k.kpi_name )) - ctab.cells.append(str(k.src_file )) - 
ctab.cells.append(str(k.page_num )) - ctab.cells.append(str(k.item_ids )) - ctab.cells.append(str(k.pos_x )) - ctab.cells.append(str(k.pos_y )) - ctab.cells.append(str(k.raw_txt )) - ctab.cells.append(str(k.year )) - ctab.cells.append(str(k.value )) - ctab.cells.append(str(k.score )) - ctab.cells.append(str(k.unit )) - ctab.cells.append(str(k.match_type )) - - return ctab - - def to_string(self, max_width, min_col_width): - ctab = self.to_ctab() - return ctab.to_string(max_width, min_col_width) - - - - def to_json(self): - jsonpickle.set_preferred_backend('json') - jsonpickle.set_encoder_options('json', sort_keys=True, indent=4) - data = jsonpickle.encode(self) - - return data - - def save_to_file(self, json_file): - data = self.to_json() - f = open(json_file, "w") - f.write(data) - f.close() - - def save_to_csv_file(self, csv_file): - ctab = self.to_ctab() - csv_str = ctab.to_string(use_format = ConsoleTable.FORMAT_CSV) - - f = open(csv_file, "w", encoding="utf-8") - f.write(csv_str) - f.close() - - - - @staticmethod - def load_from_json(data): - obj = jsonpickle.decode(data) - return obj - - @staticmethod - def load_from_file(json_file): - f = open(json_file, "r") - data = f.read() - f.close() - return KPIResultSet.load_from_json(data) - - - - def __repr__(self): - - return self.to_string(120, 5) - - \ No newline at end of file + def __repr__(self): + return self.to_string(120, 5) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPISpecs.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPISpecs.py index bf3790a..c230622 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPISpecs.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/KPISpecs.py @@ -12,358 +12,407 @@ # Matching modes: -MATCHING_MUST_INCLUDE = 0 # no match, if not included -MATCHING_MUST_INCLUDE_EACH_NODE = 4 # must be includd in each node, otherwise no match. Note: this is even more strict then MATCHING_MUST_INCLUDE +MATCHING_MUST_INCLUDE = 0 # no match, if not included +MATCHING_MUST_INCLUDE_EACH_NODE = ( + 4 # must be includd in each node, otherwise no match. Note: this is even more strict then MATCHING_MUST_INCLUDE +) -MATCHING_MAY_INCLUDE = 1 # should be included, but inclusion is not neccessary, altough AT LEAST ONE such item must be included. can also have a negative score -MATCHING_CAN_INCLUDE = 2 # should be included, but inclusion is not neccessary. can also have a negative score -MATCHING_MUST_EXCLUDE = 3 # no match, if included +MATCHING_MAY_INCLUDE = 1 # should be included, but inclusion is not neccessary, altough AT LEAST ONE such item must be included. can also have a negative score +MATCHING_CAN_INCLUDE = 2 # should be included, but inclusion is not neccessary. 
can also have a negative score +MATCHING_MUST_EXCLUDE = 3 # no match, if included # Distance modes -DISTANCE_EUCLIDIAN = 0 # euclidian distance -DISTANCE_MOD_EUCLID = 1 # euclidian distance, but with modification such that orthogonal aligned objects are given a smaller distance (thus prefering table-like alignments) -DISTANCE_MOD_EUCLID_UP_LEFT = 2 # like 1, but we prefer looking upwards and to the left, which conforms to the typical strucutre of a table -DISTANCE_MOD_EUCLID_UP_ONLY = 3 # like 1, but we strictly enforce only looking upwards (else, score=0) +DISTANCE_EUCLIDIAN = 0 # euclidian distance +DISTANCE_MOD_EUCLID = 1 # euclidian distance, but with modification such that orthogonal aligned objects are given a smaller distance (thus prefering table-like alignments) +DISTANCE_MOD_EUCLID_UP_LEFT = ( + 2 # like 1, but we prefer looking upwards and to the left, which conforms to the typical strucutre of a table +) +DISTANCE_MOD_EUCLID_UP_ONLY = 3 # like 1, but we strictly enforce only looking upwards (else, score=0) # Percentage Matching -VALUE_PERCENTAGE_DONT_CARE = 0 -VALUE_PERCENTAGE_MUST = 1 -VALUE_PERCENTAGE_MUST_NOT = 2 +VALUE_PERCENTAGE_DONT_CARE = 0 +VALUE_PERCENTAGE_MUST = 1 +VALUE_PERCENTAGE_MUST_NOT = 2 class KPISpecs: - # This class contains specifications for one KPI that should be extracted - - class DescRegExMatch: # regex matcher for description of KPI - pattern_raw = None - pattern_regex = None - score = None # must always be > 0 - matching_mode = None - score_decay = None # Should be between 0 and 1. It determines how fast the score decays, when we traverse into more distant nodes. 1=No decay, 0=Full decay after first node - case_sensitive = None - multi_match_decay = None # If a pattern is hit multiple times, the score will decay after each hit. 1=No decay, 0=Full decay after first node - letter_decay = None # Generally, we prefer shorter texts over longer ones, because they contain less distractive garbage. So for each letter, the score decays. 1=No decay, 0 =Immediate, full decay - letter_decay_disregard = None # this number of letters will not be affected by decay at all - count_if_matched = None # if this is TRUE (default), then a match will be counted, in cases where a single match is needed - allow_matching_against_concat_txt = None # if this is TRUE, then we will try to match against a concatenation of all nodes. Default: FALSE - - def __init__(self, pattern_raw, score, matching_mode, score_decay, case_sensitive, multi_match_decay, letter_decay_hl, letter_decay_disregard = 0, count_if_matched = True, allow_matching_against_concat_txt = False): #specify half-life of letter decay! 
0 =No decay - self.pattern_raw = pattern_raw - self.score = score - self.matching_mode = matching_mode - self.score_decay = score_decay - self.case_sensitive = case_sensitive - self.pattern_regex = re.compile(pattern_raw) # if case_sensitive else pattern_raw.lower()) # Note: using lower-case here would destory regexp patterns like \S - self.multi_match_decay = multi_match_decay - self.letter_decay = 0.5 ** (1.0 / letter_decay_hl) if letter_decay_hl > 0 else 1 - self.letter_decay_disregard = letter_decay_disregard - self.count_if_matched = count_if_matched - self.allow_matching_against_concat_txt = allow_matching_against_concat_txt - - def match_single_node(self, txt): - return True if self.pattern_regex.match(txt if self.case_sensitive else txt.lower()) else False #b/c regexp matcher returns not just a boolean value - - def match_nodes(self, txt_nodes): #check if nodes are matched by this, and if yes return True togehter with score - matched = False - final_score = 0 - num_hits = 0 - concat_match = False - concat_final_score = 0 - - if(self.allow_matching_against_concat_txt): - concat_txt = ' '.join(txt_nodes) - if(self.match_single_node(concat_txt)): - concat_final_score = self.score * (self.letter_decay ** max(len(Format_Analyzer.cleanup_text(concat_txt)) - self.letter_decay_disregard, 0) ) - concat_match = True - - for i in range(len(txt_nodes)): - if(self.match_single_node(txt_nodes[i])): - if(self.matching_mode == MATCHING_MUST_EXCLUDE): - return False, -1 # we matched something that must not be included - matched = True - #print_verbose(7, '..............txt="' + str(txt_nodes[i])+ '", len_txt=' + str(len(Format_Analyzer.cleanup_text(txt_nodes[i]))) + ', len_disr= ' + str(self.letter_decay_disregard)) - final_score += self.score * (self.score_decay ** i) * (self.multi_match_decay ** num_hits) * \ - (self.letter_decay ** max(len(Format_Analyzer.cleanup_text(txt_nodes[i])) - self.letter_decay_disregard, 0) ) - num_hits += 1 - else: - if(self.matching_mode == MATCHING_MUST_INCLUDE_EACH_NODE): - if(not concat_match): - return False, 0 - - - if(self.matching_mode == MATCHING_MUST_INCLUDE and not matched and not concat_match): - return False, 0 # something must be included was never matched - - - return True, max(concat_final_score, final_score) - - - class GeneralRegExMatch: # regex matcher for value or unit of KPI - pattern_raw = None - pattern_regex = None - case_sensitive = None - - def __init__(self, pattern_raw, case_sensitive): - self.pattern_raw = pattern_raw - self.case_sensitive = case_sensitive - self.pattern_regex = re.compile(pattern_raw if case_sensitive else pattern_raw.upper()) - - def match(self, txt): - return True if self.pattern_regex.match(txt if self.case_sensitive else txt.upper()) else False #b/c regexp matcher returns not just a boolean value - - class AnywhereRegExMatch: #regex matcher for text anywhere on the page near the KPI - general_match = None - distance_mode = None - score = None # must always be > 0 - matching_mode = None - score_decay = None # Should be between 0 and 1. It determines how fast the score decays, when we reach more distant items. 1=No decay, 0=Full decay after 1 px - multi_match_decay = None # If a pattern is hit multiple times, the score will decay after each hit. 1=No decay, 0=Full decay after first hit - letter_decay = None # Generally, we prefer shorter texts over longer ones, because they contain less distractive garbage. So for each letter, the score decays. 
1=No decay, 0 =Immediate, full decay - letter_decay_disregard = None # this number of letters will not be affected by decay at all - - - - def __init__(self, general_match, distance_mode, score, matching_mode, score_decay, multi_match_decay, letter_decay_hl, letter_decay_disregard = 0): - self.general_match = general_match - self.distance_mode = distance_mode - self.score = score - self.matching_mode = matching_mode - self.score_decay = score_decay - self.multi_match_decay = multi_match_decay - self.letter_decay = 0.5 ** (1.0 / letter_decay_hl) if letter_decay_hl > 0 else 1 - self.letter_decay_disregard = letter_decay_disregard - - - def calc_distance(self, a, b, threshold): - if(self.distance_mode==DISTANCE_EUCLIDIAN): - return ((b[0] - a[0])**2 + (b[1] - a[1])**2)**0.5 - if(self.distance_mode==DISTANCE_MOD_EUCLID): - penalty = 1 - if(a[1] < b[1] - threshold): - penalty = 50 # reference_point text below basepoint - if(a[0] < b[0] - threshold): - penalty = 50 if penalty==1 else 90 # reference_point text right of basepoint - dx = abs(b[0] - a[0]) - dy = abs(b[1] - a[1]) - if(dx > dy): - dx *= 0.01 #by Lei - else: - dy *= 0.01 #by Lei - return penalty * (dx*dx+dy*dy)**0.5 - if(self.distance_mode==DISTANCE_MOD_EUCLID_UP_ONLY): - penalty = 1 - if(a[1] < b[1]): - return -1 # reference_point text below basepoint - if(a[0] < b[0] - threshold): - penalty = 50 if penalty==1 else 90 # reference_point text right of basepoint - dx = abs(b[0] - a[0]) - dy = abs(b[1] - a[1]) - if(dx > dy): - dx *= 0.01 #by Lei - else: - dy *= 0.01 #by Lei - return penalty * (dx*dx+dy*dy)**0.5 - - return None # not implemented - - - def match(self, htmlpage, cur_item_idx): - taken = [False] * len(htmlpage.items) - matches = [] # list of (idx, txt, score_base) - base_rect = htmlpage.items[cur_item_idx].get_rect() - base_point = ((base_rect.x0 + base_rect.x1) * 0.5, (base_rect.y0 + base_rect.y1) * 0.5) - page_diag = (htmlpage.page_width**2 + htmlpage.page_height**2)**0.5 - page_threshold = page_diag * 0.0007 - - - for i in range(len(htmlpage.items)): - if(taken[i]): - continue - idx_list = htmlpage.explode_item_by_idx(i) - # mark as taken - for j in idx_list: - taken[j] = True - txt = htmlpage.explode_item(i) - if(self.general_match.match(txt)): - rect = Rect(9999999, 9999999, -1, -1) - for j in idx_list: - rect.grow(htmlpage.items[j].get_rect()) - reference_point = ((rect.x0 + rect.x1) * 0.5, (rect.y0 + rect.y1) * 0.5) - - dist = self.calc_distance(base_point, reference_point, page_threshold) - if(dist==-1): - continue - dist_exp = dist / (0.1 * page_diag) - - score_base = self.score * (self.score_decay ** dist_exp ) * (self.letter_decay ** max(len(Format_Analyzer.cleanup_text(txt)) - self.letter_decay_disregard, 0) ) - print_verbose(9,'..........txt:'+str(txt)+' has dist_exp='+str(dist_exp)+' and score='+str(score_base)) - - matches.append((i, txt, score_base)) - - matches = sorted(matches, key=lambda m: - m[2]) # sort desc by score_base - - if(len(matches) > 0): - print_verbose(8, 'AnywhereRegExMatch.match of item ' + str(htmlpage.items[cur_item_idx]) + ' matches with: ' + str(matches)) - - - - final_score = 0 - num_hits = 0 - for i in range(len(matches)): - if(self.matching_mode == MATCHING_MUST_EXCLUDE): - return False, -1 # we matched something that must not be included - - - final_score += matches[i][2] * (self.multi_match_decay ** num_hits) - num_hits += 1 - - if(self.matching_mode in (MATCHING_MUST_INCLUDE, MATCHING_MUST_INCLUDE_EACH_NODE) and len(matches) == 0 ): - return False, 0 # something must be included was 
never matched - - return True, final_score - - - - - - - kpi_id = None - kpi_name = None - desc_regex_match_list = None - value_regex_match_list = None - unit_regex_match_list = None - value_must_be_numeric = None - value_percentage_match = None - value_must_be_year = None - anywhere_regex_match_list = None - minimum_score = None - minimum_score_desc_regex = None - #value_preference = None # 1=all values are equally peferable; >1= prefer greater values; <1= prefer smaller values (not yet implemented and probably not neccessary) - - - - def __init__(self): - self.kpi_id = -1 - self.kpi_name = '' - self.desc_regex_match_list = [] - self.value_regex_match_list = [] - self.unit_regex_match_list = [] - self.value_must_be_numeric = False - self.value_must_be_year = False - self.value_percentage_match = VALUE_PERCENTAGE_DONT_CARE - self.anywhere_regex_match_list = [] - self.minimum_score = 0 - self.minimum_score_desc_regex = 0 - #self.value_preference = 1.0 - - - def has_unit(self): - return len(self.unit_regex_match_list) > 0 - - - def match_nodes(self, desc_nodes): #check if nodes are matched by this, and if yes return True togehter with score - final_score = 0 - at_least_one_match = False - bad_match = False - min_score = 0 - - - for d in self.desc_regex_match_list: - match, score = d.match_nodes(desc_nodes) - print_verbose(7,'..... matching "'+d.pattern_raw+'" => match,score='+str(match)+','+str(score)) - if(not match): - # must included, but not included. or must excluded, but included - bad_match = True - min_score = min(min_score, score) - - if(d.matching_mode in (MATCHING_MAY_INCLUDE, MATCHING_MUST_INCLUDE, MATCHING_MUST_INCLUDE_EACH_NODE) and d.count_if_matched and match and score > 0): - print_verbose(9,'............. at least one match here!') - at_least_one_match = True - final_score += score - - if(bad_match): - return False, min_score #0 = if no match, -1 if must-excluded item was matched - - if(not at_least_one_match): - return False, 0 - - return final_score >= self.minimum_score_desc_regex and final_score > 0, final_score - - - def match_unit(self, unit_str): - for u in self.unit_regex_match_list: - if(not u.match(unit_str)): - return False - return True - - - - def match_value(self, val_str): # check if extracted value is a match - if(self.value_must_be_numeric and (val_str == '' or not Format_Analyzer.looks_numeric(val_str))): - return False - - if(self.value_percentage_match == VALUE_PERCENTAGE_MUST): - if(not Format_Analyzer.looks_percentage(val_str)): - return False - - if(self.value_percentage_match == VALUE_PERCENTAGE_MUST_NOT): - if(Format_Analyzer.looks_percentage(val_str)): - return False - - if(self.value_must_be_year and not Format_Analyzer.looks_year(val_str)): - return False # this is not a year! - - for v in self.value_regex_match_list: - if(not v.match(val_str)): - return False - return True - - - def match_anywhere_on_page(self, htmlpage, cur_item_idx): - - if(len(self.anywhere_regex_match_list) == 0): - return True, 0 - - final_score = 0 - at_least_one_match = False - at_least_one_match_needed = False - bad_match = False - min_score = 0 - - - for d in self.anywhere_regex_match_list: - if(d.matching_mode in (MATCHING_MAY_INCLUDE, MATCHING_MUST_INCLUDE)): - at_least_one_match_needed = True - - match, score = d.match(htmlpage, cur_item_idx) - print_verbose(7,'..... matching anywhere "'+d.general_match.pattern_raw+'" => match,score='+str(match)+','+str(score)) - if(not match): - # must included, but not included. 
or must excluded, but included - bad_match = True - min_score = min(min_score, score) - - if(d.matching_mode in (MATCHING_MAY_INCLUDE, MATCHING_MUST_INCLUDE)): - at_least_one_match = True - final_score += score - - if(bad_match): - print_verbose(9,'........... bad_match') - return False, min_score #0 = if no match, -1 if must-excluded item was matched - - if(not at_least_one_match and at_least_one_match_needed): - print_verbose(9,'........... not at least one needed match found') - return False, 0 - - return final_score > 0, final_score - - - - - - def extract_value(self, val_str): - #for now just return the input - #converting to standardized numbers could be done here - return val_str - - \ No newline at end of file + # This class contains specifications for one KPI that should be extracted + + class DescRegExMatch: # regex matcher for description of KPI + pattern_raw = None + pattern_regex = None + score = None # must always be > 0 + matching_mode = None + score_decay = None # Should be between 0 and 1. It determines how fast the score decays, when we traverse into more distant nodes. 1=No decay, 0=Full decay after first node + case_sensitive = None + multi_match_decay = None # If a pattern is hit multiple times, the score will decay after each hit. 1=No decay, 0=Full decay after first node + letter_decay = None # Generally, we prefer shorter texts over longer ones, because they contain less distractive garbage. So for each letter, the score decays. 1=No decay, 0 =Immediate, full decay + letter_decay_disregard = None # this number of letters will not be affected by decay at all + count_if_matched = ( + None # if this is TRUE (default), then a match will be counted, in cases where a single match is needed + ) + allow_matching_against_concat_txt = ( + None # if this is TRUE, then we will try to match against a concatenation of all nodes. Default: FALSE + ) + + def __init__( + self, + pattern_raw, + score, + matching_mode, + score_decay, + case_sensitive, + multi_match_decay, + letter_decay_hl, + letter_decay_disregard=0, + count_if_matched=True, + allow_matching_against_concat_txt=False, + ): # specify half-life of letter decay! 
0 =No decay + self.pattern_raw = pattern_raw + self.score = score + self.matching_mode = matching_mode + self.score_decay = score_decay + self.case_sensitive = case_sensitive + self.pattern_regex = re.compile( + pattern_raw + ) # if case_sensitive else pattern_raw.lower()) # Note: using lower-case here would destory regexp patterns like \S + self.multi_match_decay = multi_match_decay + self.letter_decay = 0.5 ** (1.0 / letter_decay_hl) if letter_decay_hl > 0 else 1 + self.letter_decay_disregard = letter_decay_disregard + self.count_if_matched = count_if_matched + self.allow_matching_against_concat_txt = allow_matching_against_concat_txt + + def match_single_node(self, txt): + return ( + True if self.pattern_regex.match(txt if self.case_sensitive else txt.lower()) else False + ) # b/c regexp matcher returns not just a boolean value + + def match_nodes( + self, txt_nodes + ): # check if nodes are matched by this, and if yes return True togehter with score + matched = False + final_score = 0 + num_hits = 0 + concat_match = False + concat_final_score = 0 + + if self.allow_matching_against_concat_txt: + concat_txt = " ".join(txt_nodes) + if self.match_single_node(concat_txt): + concat_final_score = self.score * ( + self.letter_decay + ** max(len(Format_Analyzer.cleanup_text(concat_txt)) - self.letter_decay_disregard, 0) + ) + concat_match = True + + for i in range(len(txt_nodes)): + if self.match_single_node(txt_nodes[i]): + if self.matching_mode == MATCHING_MUST_EXCLUDE: + return False, -1 # we matched something that must not be included + matched = True + # print_verbose(7, '..............txt="' + str(txt_nodes[i])+ '", len_txt=' + str(len(Format_Analyzer.cleanup_text(txt_nodes[i]))) + ', len_disr= ' + str(self.letter_decay_disregard)) + final_score += ( + self.score + * (self.score_decay**i) + * (self.multi_match_decay**num_hits) + * ( + self.letter_decay + ** max(len(Format_Analyzer.cleanup_text(txt_nodes[i])) - self.letter_decay_disregard, 0) + ) + ) + num_hits += 1 + else: + if self.matching_mode == MATCHING_MUST_INCLUDE_EACH_NODE: + if not concat_match: + return False, 0 + + if self.matching_mode == MATCHING_MUST_INCLUDE and not matched and not concat_match: + return False, 0 # something must be included was never matched + + return True, max(concat_final_score, final_score) + + class GeneralRegExMatch: # regex matcher for value or unit of KPI + pattern_raw = None + pattern_regex = None + case_sensitive = None + + def __init__(self, pattern_raw, case_sensitive): + self.pattern_raw = pattern_raw + self.case_sensitive = case_sensitive + self.pattern_regex = re.compile(pattern_raw if case_sensitive else pattern_raw.upper()) + + def match(self, txt): + return ( + True if self.pattern_regex.match(txt if self.case_sensitive else txt.upper()) else False + ) # b/c regexp matcher returns not just a boolean value + + class AnywhereRegExMatch: # regex matcher for text anywhere on the page near the KPI + general_match = None + distance_mode = None + score = None # must always be > 0 + matching_mode = None + score_decay = None # Should be between 0 and 1. It determines how fast the score decays, when we reach more distant items. 1=No decay, 0=Full decay after 1 px + multi_match_decay = None # If a pattern is hit multiple times, the score will decay after each hit. 1=No decay, 0=Full decay after first hit + letter_decay = None # Generally, we prefer shorter texts over longer ones, because they contain less distractive garbage. So for each letter, the score decays. 
1=No decay, 0 =Immediate, full decay + letter_decay_disregard = None # this number of letters will not be affected by decay at all + + def __init__( + self, + general_match, + distance_mode, + score, + matching_mode, + score_decay, + multi_match_decay, + letter_decay_hl, + letter_decay_disregard=0, + ): + self.general_match = general_match + self.distance_mode = distance_mode + self.score = score + self.matching_mode = matching_mode + self.score_decay = score_decay + self.multi_match_decay = multi_match_decay + self.letter_decay = 0.5 ** (1.0 / letter_decay_hl) if letter_decay_hl > 0 else 1 + self.letter_decay_disregard = letter_decay_disregard + + def calc_distance(self, a, b, threshold): + if self.distance_mode == DISTANCE_EUCLIDIAN: + return ((b[0] - a[0]) ** 2 + (b[1] - a[1]) ** 2) ** 0.5 + if self.distance_mode == DISTANCE_MOD_EUCLID: + penalty = 1 + if a[1] < b[1] - threshold: + penalty = 50 # reference_point text below basepoint + if a[0] < b[0] - threshold: + penalty = 50 if penalty == 1 else 90 # reference_point text right of basepoint + dx = abs(b[0] - a[0]) + dy = abs(b[1] - a[1]) + if dx > dy: + dx *= 0.01 # by Lei + else: + dy *= 0.01 # by Lei + return penalty * (dx * dx + dy * dy) ** 0.5 + if self.distance_mode == DISTANCE_MOD_EUCLID_UP_ONLY: + penalty = 1 + if a[1] < b[1]: + return -1 # reference_point text below basepoint + if a[0] < b[0] - threshold: + penalty = 50 if penalty == 1 else 90 # reference_point text right of basepoint + dx = abs(b[0] - a[0]) + dy = abs(b[1] - a[1]) + if dx > dy: + dx *= 0.01 # by Lei + else: + dy *= 0.01 # by Lei + return penalty * (dx * dx + dy * dy) ** 0.5 + + return None # not implemented + + def match(self, htmlpage, cur_item_idx): + taken = [False] * len(htmlpage.items) + matches = [] # list of (idx, txt, score_base) + base_rect = htmlpage.items[cur_item_idx].get_rect() + base_point = ((base_rect.x0 + base_rect.x1) * 0.5, (base_rect.y0 + base_rect.y1) * 0.5) + page_diag = (htmlpage.page_width**2 + htmlpage.page_height**2) ** 0.5 + page_threshold = page_diag * 0.0007 + + for i in range(len(htmlpage.items)): + if taken[i]: + continue + idx_list = htmlpage.explode_item_by_idx(i) + # mark as taken + for j in idx_list: + taken[j] = True + txt = htmlpage.explode_item(i) + if self.general_match.match(txt): + rect = Rect(9999999, 9999999, -1, -1) + for j in idx_list: + rect.grow(htmlpage.items[j].get_rect()) + reference_point = ((rect.x0 + rect.x1) * 0.5, (rect.y0 + rect.y1) * 0.5) + + dist = self.calc_distance(base_point, reference_point, page_threshold) + if dist == -1: + continue + dist_exp = dist / (0.1 * page_diag) + + score_base = ( + self.score + * (self.score_decay**dist_exp) + * ( + self.letter_decay + ** max(len(Format_Analyzer.cleanup_text(txt)) - self.letter_decay_disregard, 0) + ) + ) + print_verbose( + 9, + "..........txt:" + + str(txt) + + " has dist_exp=" + + str(dist_exp) + + " and score=" + + str(score_base), + ) + + matches.append((i, txt, score_base)) + + matches = sorted(matches, key=lambda m: -m[2]) # sort desc by score_base + + if len(matches) > 0: + print_verbose( + 8, + "AnywhereRegExMatch.match of item " + + str(htmlpage.items[cur_item_idx]) + + " matches with: " + + str(matches), + ) + + final_score = 0 + num_hits = 0 + for i in range(len(matches)): + if self.matching_mode == MATCHING_MUST_EXCLUDE: + return False, -1 # we matched something that must not be included + + final_score += matches[i][2] * (self.multi_match_decay**num_hits) + num_hits += 1 + + if self.matching_mode in (MATCHING_MUST_INCLUDE, 
MATCHING_MUST_INCLUDE_EACH_NODE) and len(matches) == 0: + return False, 0 # something must be included was never matched + + return True, final_score + + kpi_id = None + kpi_name = None + desc_regex_match_list = None + value_regex_match_list = None + unit_regex_match_list = None + value_must_be_numeric = None + value_percentage_match = None + value_must_be_year = None + anywhere_regex_match_list = None + minimum_score = None + minimum_score_desc_regex = None + # value_preference = None # 1=all values are equally peferable; >1= prefer greater values; <1= prefer smaller values (not yet implemented and probably not neccessary) + + def __init__(self): + self.kpi_id = -1 + self.kpi_name = "" + self.desc_regex_match_list = [] + self.value_regex_match_list = [] + self.unit_regex_match_list = [] + self.value_must_be_numeric = False + self.value_must_be_year = False + self.value_percentage_match = VALUE_PERCENTAGE_DONT_CARE + self.anywhere_regex_match_list = [] + self.minimum_score = 0 + self.minimum_score_desc_regex = 0 + # self.value_preference = 1.0 + + def has_unit(self): + return len(self.unit_regex_match_list) > 0 + + def match_nodes(self, desc_nodes): # check if nodes are matched by this, and if yes return True togehter with score + final_score = 0 + at_least_one_match = False + bad_match = False + min_score = 0 + + for d in self.desc_regex_match_list: + match, score = d.match_nodes(desc_nodes) + print_verbose(7, '..... matching "' + d.pattern_raw + '" => match,score=' + str(match) + "," + str(score)) + if not match: + # must included, but not included. or must excluded, but included + bad_match = True + min_score = min(min_score, score) + + if ( + d.matching_mode in (MATCHING_MAY_INCLUDE, MATCHING_MUST_INCLUDE, MATCHING_MUST_INCLUDE_EACH_NODE) + and d.count_if_matched + and match + and score > 0 + ): + print_verbose(9, "............. at least one match here!") + at_least_one_match = True + final_score += score + + if bad_match: + return False, min_score # 0 = if no match, -1 if must-excluded item was matched + + if not at_least_one_match: + return False, 0 + + return final_score >= self.minimum_score_desc_regex and final_score > 0, final_score + + def match_unit(self, unit_str): + for u in self.unit_regex_match_list: + if not u.match(unit_str): + return False + return True + + def match_value(self, val_str): # check if extracted value is a match + if self.value_must_be_numeric and (val_str == "" or not Format_Analyzer.looks_numeric(val_str)): + return False + + if self.value_percentage_match == VALUE_PERCENTAGE_MUST: + if not Format_Analyzer.looks_percentage(val_str): + return False + + if self.value_percentage_match == VALUE_PERCENTAGE_MUST_NOT: + if Format_Analyzer.looks_percentage(val_str): + return False + + if self.value_must_be_year and not Format_Analyzer.looks_year(val_str): + return False # this is not a year! + + for v in self.value_regex_match_list: + if not v.match(val_str): + return False + return True + + def match_anywhere_on_page(self, htmlpage, cur_item_idx): + if len(self.anywhere_regex_match_list) == 0: + return True, 0 + + final_score = 0 + at_least_one_match = False + at_least_one_match_needed = False + bad_match = False + min_score = 0 + + for d in self.anywhere_regex_match_list: + if d.matching_mode in (MATCHING_MAY_INCLUDE, MATCHING_MUST_INCLUDE): + at_least_one_match_needed = True + + match, score = d.match(htmlpage, cur_item_idx) + print_verbose( + 7, + '..... 
matching anywhere "' + + d.general_match.pattern_raw + + '" => match,score=' + + str(match) + + "," + + str(score), + ) + if not match: + # must included, but not included. or must excluded, but included + bad_match = True + min_score = min(min_score, score) + + if d.matching_mode in (MATCHING_MAY_INCLUDE, MATCHING_MUST_INCLUDE): + at_least_one_match = True + final_score += score + + if bad_match: + print_verbose(9, "........... bad_match") + return False, min_score # 0 = if no match, -1 if must-excluded item was matched + + if not at_least_one_match and at_least_one_match_needed: + print_verbose(9, "........... not at least one needed match found") + return False, 0 + + return final_score > 0, final_score + + def extract_value(self, val_str): + # for now just return the input + # converting to standardized numbers could be done here + return val_str diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Rect.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Rect.py index 7edfadf..eb9db49 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Rect.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/Rect.py @@ -9,82 +9,77 @@ class Rect: - x0 = None - x1 = None - y0 = None - y1 = None + x0 = None + x1 = None + y0 = None + y1 = None + def __init__(self): + self.x0 = 0 + self.x1 = 0 + self.y0 = 0 + self.y1 = 0 - def __init__(self): - self.x0 = 0 - self.x1 = 0 - self.y0 = 0 - self.y1 = 0 + def __init__(self, x0_, y0_, x1_, y1_): + self.x0 = x0_ + self.x1 = x1_ + self.y0 = y0_ + self.y1 = y1_ - def __init__(self, x0_, y0_, x1_, y1_): - self.x0 = x0_ - self.x1 = x1_ - self.y0 = y0_ - self.y1 = y1_ + def get_width(self): + return self.x1 - self.x0 - - def get_width(self): - return self.x1 - self.x0 - - def get_height(self): - return self.y1 - self.y0 - - def get_area(self): - return self.get_width() * self.get_height() - - def grow(self, rect): - self.x0 = min(self.x0, rect.x0) - self.y0 = min(self.y0, rect.y0) - self.x1 = max(self.x1, rect.x1) - self.y1 = max(self.y1, rect.y1) - - def get_center(self): - return (self.x1+self.x0) * 0.5, (self.y1+self.y0) * 0.5 - - @staticmethod - def raw_rect_distance(x1, y1, x1b, y1b, x2, y2, x2b, y2b): - #see: https://stackoverflow.com/questions/4978323/how-to-calculate-distance-between-two-rectangles-context-a-game-in-lua - left = x2b < x1 - right = x1b < x2 - bottom = y2b < y1 - top = y1b < y2 - if top and left: - return dist(x1, y1b, x2b, y2) - elif left and bottom: - return dist(x1, y1, x2b, y2b) - elif bottom and right: - return dist(x1b, y1, x2, y2b) - elif right and top: - return dist(x1b, y1b, x2, y2) - elif left: - return x1 - x2b - elif right: - return x2 - x1b - elif bottom: - return y1 - y2b - elif top: - return y2 - y1b - else: # rectangles intersect - return 0. 
- - - @staticmethod - def calc_intersection_area(r1, r2): - return max(0, min(r1.x1, r2.x1) - max(r1.x0, r2.x0)) * max(0, min(r1.y1, r2.y1) - max(r1.y0, r2.y0)) - - @staticmethod - def distance(r1, r2): - return Rect.raw_rect_distance(r1.x0, r1.y0, r1.x1, r1.y1, r2.x0, r2.y0, r2.x1, r2.y1) - - def get_coordinates(self): - return (self.x0,self.y0) - - def __repr__(self): - return "" + def get_height(self): + return self.y1 - self.y0 + def get_area(self): + return self.get_width() * self.get_height() + def grow(self, rect): + self.x0 = min(self.x0, rect.x0) + self.y0 = min(self.y0, rect.y0) + self.x1 = max(self.x1, rect.x1) + self.y1 = max(self.y1, rect.y1) + + def get_center(self): + return (self.x1 + self.x0) * 0.5, (self.y1 + self.y0) * 0.5 + + @staticmethod + def raw_rect_distance(x1, y1, x1b, y1b, x2, y2, x2b, y2b): + # see: https://stackoverflow.com/questions/4978323/how-to-calculate-distance-between-two-rectangles-context-a-game-in-lua + left = x2b < x1 + right = x1b < x2 + bottom = y2b < y1 + top = y1b < y2 + if top and left: + return dist(x1, y1b, x2b, y2) + elif left and bottom: + return dist(x1, y1, x2b, y2b) + elif bottom and right: + return dist(x1b, y1, x2, y2b) + elif right and top: + return dist(x1b, y1b, x2, y2) + elif left: + return x1 - x2b + elif right: + return x2 - x1b + elif bottom: + return y1 - y2b + elif top: + return y2 - y1b + else: # rectangles intersect + return 0.0 + + @staticmethod + def calc_intersection_area(r1, r2): + return max(0, min(r1.x1, r2.x1) - max(r1.x0, r2.x0)) * max(0, min(r1.y1, r2.y1) - max(r1.y0, r2.y0)) + + @staticmethod + def distance(r1, r2): + return Rect.raw_rect_distance(r1.x0, r1.y0, r1.x1, r1.y1, r2.x0, r2.y0, r2.x1, r2.y1) + + def get_coordinates(self): + return (self.x0, self.y0) + + def __repr__(self): + return "" diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestData.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestData.py index 21dcde5..75cdb65 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestData.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestData.py @@ -12,275 +12,247 @@ class TestData: + samples = None - samples = None - - SRC_FILE_FORMAT_AUTO = 0 - SRC_FILE_FORMAT_OLD = 1 - SRC_FILE_FORMAT_NEW = 2 - - def __init__(self): - self.samples = [] - - def filter_kpis(self, by_kpi_id = None, by_data_type = None, by_source_file = None, by_has_fixed_source_file = False): - samples_new = [] - for s in self.samples: - keep = True - if(by_kpi_id is not None and s.data_kpi_id not in by_kpi_id): - keep = False - if(by_data_type is not None and s.data_data_type not in by_data_type): - keep = False - if(by_has_fixed_source_file and s.fixed_source_file is None): - keep = False - if(by_source_file is not None and s.data_source_file not in by_source_file): - keep = False - - if(keep): - samples_new.append(s) - - self.samples = samples_new - - - def get_pdf_list(self): - res = [] - for s in self.samples: - res.append(s.data_source_file) - res = list(set(res)) - res = sorted(res, key=lambda s: s.lower()) - return res - - def get_fixed_pdf_list(self): - res = [] - for s in self.samples: - res.append(s.fixed_source_file) - res = list(set(res)) - res = sorted(res, key=lambda s: s.lower()) - return res - - - def fix_file_names(self, fix_list): - for i in range(len(self.samples)): - for f in fix_list: - if(self.samples[i].data_source_file == f[0]): - self.samples[i].fixed_source_file = f[1] - break - - - - def load_from_csv(self, src_file_path, 
src_file_format = SRC_FILE_FORMAT_AUTO): - - - raw_data = '' - - def read_next_cell(p): - p0 = -1 - p1 = -1 - p2 = -1 - #print("====>> p = "+str(p)) - if(raw_data[p:(p+4)] == '"[""'): - p0 = p + 4 - p1 = raw_data.find('""]"', p+1) - p2 = p1 + 4 - elif(raw_data[p] == '"'): - p0 = p+1 - p_cur = p0 - while(True): - p1 = raw_data.find('"', p_cur) - if(raw_data[p1+1]!='"'): - break - p_cur = p1 + 2 - - p2 = p1 + 1 - else: - p0 = p - p2_a = raw_data.find(',' if src_file_format == TestData.SRC_FILE_FORMAT_OLD else ';' ,p) - p2_b = raw_data.find('\n',p) - if(p2_a == -1): - p2 = p2_b - elif(p2_b == -1): - p2 = p2_a - else: - p2 = min(p2_a, p2_b) - - p1 = p2 - #print("===>> p1="+str(p1)) - - if(p1 == -1 or raw_data[p2] not in (',' if src_file_format == TestData.SRC_FILE_FORMAT_OLD else ';', '\n')): - raise ValueError('No cell delimiter detected after position ' +str(p) + ' at "'+raw_data[p:p+20]+'..."') - - - cell_data = raw_data[p0:p1].replace('\n', ' ') - #print("===>>>" + cell_data) - - return cell_data, p2+1, raw_data[p2] == '\n' - - def read_next_row(p, n): - res = [] - for i in range(n): - cell_data, p, is_at_end = read_next_cell(p) - if(i==n-1): - if(not is_at_end): - raise ValueError('Row has not ended after position ' +str(p) + ' at "'+raw_data[p:p+20]+'..."') - else: - if(is_at_end): - raise ValueError('Row has ended too early after position ' +str(p) + ' at "'+raw_data[p:p+20]+'..."') - res.append(cell_data) - - #print('==>> next row starts at pos '+str(p) + ' at "'+raw_data[p:p+20]+'..."') - return res, p - - - if(src_file_format == TestData.SRC_FILE_FORMAT_AUTO): - try: - #try old format: - print_verbose(2, 'Trying old csv format') - return self.load_from_csv(src_file_path, TestData.SRC_FILE_FORMAT_OLD) - except ValueError: - #try new format: - print_verbose(2, 'Trying new csv format') - return self.load_from_csv(src_file_path, TestData.SRC_FILE_FORMAT_NEW) - - - - self.samples = [] - - with open(src_file_path, errors='ignore', encoding="ascii") as f: - data_lines = f.readlines() - - - #print(len(data_lines)) - - for i in range(len(data_lines)): - data_lines[i] = data_lines[i].replace('\n', '') - - raw_data = '\n'.join(data_lines[1:]) + '\n' - - # current format in sample csv file (old-format): - # Number,Sector,Unit,answer,"comments, questions",company,data_type,irrelevant_paragraphs,kpi_id,relevant_paragraphs,sector,source_file,source_page,year - - # and for new format: - # Number;company;source_file;source_page;kpi_id;year;answer;data_type;relevant_paragraphs;annotator;sector;comments - - p = 0 - - while(p < len(raw_data)): - - if(src_file_format == TestData.SRC_FILE_FORMAT_OLD): - # parse next row - row_data, p = read_next_row(p, 14) - #print(row_data) - - year = Format_Analyzer.to_int_number(row_data[13], 4) - if(not Format_Analyzer.looks_year(str(year))): - raise ValueError('Found invalid year "' +str(year) +'" at row ' + str(row_data)) - - sample = TestDataSample() - sample.data_number = Format_Analyzer.to_int_number(row_data[0]) #0 - sample.data_sector = row_data[1] #'' - sample.data_unit = row_data[2] #'' - sample.data_answer = row_data[3] #'' - sample.data_comments_questions = row_data[4] #'' - sample.data_company = row_data[5] #'' - sample.data_data_type = row_data[6] #'' - sample.data_irrelevant_paragraphs = row_data[7] #'' - sample.data_kpi_id = Format_Analyzer.to_int_number(row_data[8]) #0 - sample.data_relevant_paragraphs = row_data[9] #'' - sample.data_sector = row_data[10] #'' - sample.data_source_file = row_data[11] #'' - sample.data_source_page = 
Format_Analyzer.to_int_number(row_data[12]) #0 - sample.data_year = year #0 - - self.samples.append(sample) - - if(src_file_format == TestData.SRC_FILE_FORMAT_NEW): - # parse next row - row_data, p = read_next_row(p, 12) - #print(row_data) - - year = Format_Analyzer.to_int_number(row_data[5], 4) - if(not Format_Analyzer.looks_year(str(year))): - raise ValueError('Found invalid year "' +str(year) +'" at row ' + str(row_data)) - - sample = TestDataSample() - sample.data_number = Format_Analyzer.to_int_number(row_data[0]) #0 - sample.data_sector = row_data[10] #'' - sample.data_unit = 'N/A' - sample.data_answer = row_data[6] #'' - sample.data_comments_questions = row_data[11] #'' - sample.data_company = row_data[1] #'' - sample.data_data_type = row_data[7] #'' - sample.data_irrelevant_paragraphs = 'N/A' - sample.data_kpi_id = Format_Analyzer.to_float_number(row_data[4]) #0 - sample.data_relevant_paragraphs = row_data[8] #'' - sample.data_source_file = row_data[2] #'' - sample.data_source_page = Format_Analyzer.to_int_number(row_data[3]) #0 - sample.data_year = year #0 - - self.samples.append(sample) - - def generate_dummy_test_data(self, pdf_folder, filter = '*'): - def ext(f): - res = [f] - res.extend(Format_Analyzer.extract_file_path(f)) - return res - - file_paths = glob.glob(pdf_folder + '/**/' + filter + '.pdf', recursive=True) - file_paths = [ext(f.replace('\\','/')) for f in file_paths] #unixize all file paths - - cnt = 0 - for f in file_paths: - fname = f[2] + '.' + f[3] - - if(fname != Format_Analyzer.cleanup_filename(fname)): - print("Warning: Bad filename: '" + fname + "' - this file will be skipped") - continue - - - sample = TestDataSample() - sample.data_number = cnt - sample.data_sector = 'N/A' - sample.data_unit = 'N/A' - sample.data_answer = 'N/A' - sample.data_comments_questions = 'N/A' - sample.data_company = 'N/A' - sample.data_data_type = 'N/A' - sample.data_irrelevant_paragraphs = 'N/A' - sample.data_kpi_id = 0 - sample.data_relevant_paragraphs = 'N/A' - sample.data_sector = 'N/A' - sample.data_source_file = fname - sample.fixed_source_file = fname - sample.data_source_page = 0 - sample.data_year = 1900 - - self.samples.append(sample) - - cnt += 1 - - DataImportExport.save_info_file_contents(file_paths) - - - - - def save_to_csv(self, dst_file_path): - save_txt_to_file(TestDataSample.samples_to_csv(self.samples), dst_file_path) - - - - - - - - - - def __repr__(self): - return TestDataSample.samples_to_string(self.samples) - - - - - - - - - \ No newline at end of file + SRC_FILE_FORMAT_AUTO = 0 + SRC_FILE_FORMAT_OLD = 1 + SRC_FILE_FORMAT_NEW = 2 + + def __init__(self): + self.samples = [] + + def filter_kpis(self, by_kpi_id=None, by_data_type=None, by_source_file=None, by_has_fixed_source_file=False): + samples_new = [] + for s in self.samples: + keep = True + if by_kpi_id is not None and s.data_kpi_id not in by_kpi_id: + keep = False + if by_data_type is not None and s.data_data_type not in by_data_type: + keep = False + if by_has_fixed_source_file and s.fixed_source_file is None: + keep = False + if by_source_file is not None and s.data_source_file not in by_source_file: + keep = False + + if keep: + samples_new.append(s) + + self.samples = samples_new + + def get_pdf_list(self): + res = [] + for s in self.samples: + res.append(s.data_source_file) + res = list(set(res)) + res = sorted(res, key=lambda s: s.lower()) + return res + + def get_fixed_pdf_list(self): + res = [] + for s in self.samples: + res.append(s.fixed_source_file) + res = list(set(res)) + res = 
sorted(res, key=lambda s: s.lower()) + return res + + def fix_file_names(self, fix_list): + for i in range(len(self.samples)): + for f in fix_list: + if self.samples[i].data_source_file == f[0]: + self.samples[i].fixed_source_file = f[1] + break + + def load_from_csv(self, src_file_path, src_file_format=SRC_FILE_FORMAT_AUTO): + raw_data = "" + + def read_next_cell(p): + p0 = -1 + p1 = -1 + p2 = -1 + # print("====>> p = "+str(p)) + if raw_data[p : (p + 4)] == '"[""': + p0 = p + 4 + p1 = raw_data.find('""]"', p + 1) + p2 = p1 + 4 + elif raw_data[p] == '"': + p0 = p + 1 + p_cur = p0 + while True: + p1 = raw_data.find('"', p_cur) + if raw_data[p1 + 1] != '"': + break + p_cur = p1 + 2 + + p2 = p1 + 1 + else: + p0 = p + p2_a = raw_data.find("," if src_file_format == TestData.SRC_FILE_FORMAT_OLD else ";", p) + p2_b = raw_data.find("\n", p) + if p2_a == -1: + p2 = p2_b + elif p2_b == -1: + p2 = p2_a + else: + p2 = min(p2_a, p2_b) + + p1 = p2 + # print("===>> p1="+str(p1)) + + if p1 == -1 or raw_data[p2] not in ("," if src_file_format == TestData.SRC_FILE_FORMAT_OLD else ";", "\n"): + raise ValueError( + "No cell delimiter detected after position " + str(p) + ' at "' + raw_data[p : p + 20] + '..."' + ) + + cell_data = raw_data[p0:p1].replace("\n", " ") + # print("===>>>" + cell_data) + + return cell_data, p2 + 1, raw_data[p2] == "\n" + + def read_next_row(p, n): + res = [] + for i in range(n): + cell_data, p, is_at_end = read_next_cell(p) + if i == n - 1: + if not is_at_end: + raise ValueError( + "Row has not ended after position " + str(p) + ' at "' + raw_data[p : p + 20] + '..."' + ) + else: + if is_at_end: + raise ValueError( + "Row has ended too early after position " + str(p) + ' at "' + raw_data[p : p + 20] + '..."' + ) + res.append(cell_data) + + # print('==>> next row starts at pos '+str(p) + ' at "'+raw_data[p:p+20]+'..."') + return res, p + + if src_file_format == TestData.SRC_FILE_FORMAT_AUTO: + try: + # try old format: + print_verbose(2, "Trying old csv format") + return self.load_from_csv(src_file_path, TestData.SRC_FILE_FORMAT_OLD) + except ValueError: + # try new format: + print_verbose(2, "Trying new csv format") + return self.load_from_csv(src_file_path, TestData.SRC_FILE_FORMAT_NEW) + + self.samples = [] + + with open(src_file_path, errors="ignore", encoding="ascii") as f: + data_lines = f.readlines() + + # print(len(data_lines)) + + for i in range(len(data_lines)): + data_lines[i] = data_lines[i].replace("\n", "") + + raw_data = "\n".join(data_lines[1:]) + "\n" + + # current format in sample csv file (old-format): + # Number,Sector,Unit,answer,"comments, questions",company,data_type,irrelevant_paragraphs,kpi_id,relevant_paragraphs,sector,source_file,source_page,year + + # and for new format: + # Number;company;source_file;source_page;kpi_id;year;answer;data_type;relevant_paragraphs;annotator;sector;comments + + p = 0 + + while p < len(raw_data): + if src_file_format == TestData.SRC_FILE_FORMAT_OLD: + # parse next row + row_data, p = read_next_row(p, 14) + # print(row_data) + + year = Format_Analyzer.to_int_number(row_data[13], 4) + if not Format_Analyzer.looks_year(str(year)): + raise ValueError('Found invalid year "' + str(year) + '" at row ' + str(row_data)) + + sample = TestDataSample() + sample.data_number = Format_Analyzer.to_int_number(row_data[0]) # 0 + sample.data_sector = row_data[1] #'' + sample.data_unit = row_data[2] #'' + sample.data_answer = row_data[3] #'' + sample.data_comments_questions = row_data[4] #'' + sample.data_company = row_data[5] #'' + 
sample.data_data_type = row_data[6] #'' + sample.data_irrelevant_paragraphs = row_data[7] #'' + sample.data_kpi_id = Format_Analyzer.to_int_number(row_data[8]) # 0 + sample.data_relevant_paragraphs = row_data[9] #'' + sample.data_sector = row_data[10] #'' + sample.data_source_file = row_data[11] #'' + sample.data_source_page = Format_Analyzer.to_int_number(row_data[12]) # 0 + sample.data_year = year # 0 + + self.samples.append(sample) + + if src_file_format == TestData.SRC_FILE_FORMAT_NEW: + # parse next row + row_data, p = read_next_row(p, 12) + # print(row_data) + + year = Format_Analyzer.to_int_number(row_data[5], 4) + if not Format_Analyzer.looks_year(str(year)): + raise ValueError('Found invalid year "' + str(year) + '" at row ' + str(row_data)) + + sample = TestDataSample() + sample.data_number = Format_Analyzer.to_int_number(row_data[0]) # 0 + sample.data_sector = row_data[10] #'' + sample.data_unit = "N/A" + sample.data_answer = row_data[6] #'' + sample.data_comments_questions = row_data[11] #'' + sample.data_company = row_data[1] #'' + sample.data_data_type = row_data[7] #'' + sample.data_irrelevant_paragraphs = "N/A" + sample.data_kpi_id = Format_Analyzer.to_float_number(row_data[4]) # 0 + sample.data_relevant_paragraphs = row_data[8] #'' + sample.data_source_file = row_data[2] #'' + sample.data_source_page = Format_Analyzer.to_int_number(row_data[3]) # 0 + sample.data_year = year # 0 + + self.samples.append(sample) + + def generate_dummy_test_data(self, pdf_folder, filter="*"): + def ext(f): + res = [f] + res.extend(Format_Analyzer.extract_file_path(f)) + return res + + file_paths = glob.glob(pdf_folder + "/**/" + filter + ".pdf", recursive=True) + file_paths = [ext(f.replace("\\", "/")) for f in file_paths] # unixize all file paths + + cnt = 0 + for f in file_paths: + fname = f[2] + "." 
+ f[3] + + if fname != Format_Analyzer.cleanup_filename(fname): + print("Warning: Bad filename: '" + fname + "' - this file will be skipped") + continue + + sample = TestDataSample() + sample.data_number = cnt + sample.data_sector = "N/A" + sample.data_unit = "N/A" + sample.data_answer = "N/A" + sample.data_comments_questions = "N/A" + sample.data_company = "N/A" + sample.data_data_type = "N/A" + sample.data_irrelevant_paragraphs = "N/A" + sample.data_kpi_id = 0 + sample.data_relevant_paragraphs = "N/A" + sample.data_sector = "N/A" + sample.data_source_file = fname + sample.fixed_source_file = fname + sample.data_source_page = 0 + sample.data_year = 1900 + + self.samples.append(sample) + + cnt += 1 + + DataImportExport.save_info_file_contents(file_paths) + + def save_to_csv(self, dst_file_path): + save_txt_to_file(TestDataSample.samples_to_csv(self.samples), dst_file_path) + + def __repr__(self): + return TestDataSample.samples_to_string(self.samples) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestDataSample.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestDataSample.py index 4fe0654..26a68f7 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestDataSample.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestDataSample.py @@ -8,114 +8,106 @@ from ConsoleTable import * from Format_Analyzer import * + class TestDataSample: + # current format in sample csv file: + # Number,Sector,Unit,answer,"comments, questions",company,data_type,irrelevant_paragraphs,kpi_id,relevant_paragraphs,sector,source_file,source_page,year + + data_number = None + data_sector = None + data_unit = None + data_answer = None + data_comments_questions = None + data_company = None + data_data_type = None + data_irrelevant_paragraphs = None + data_kpi_id = None + data_relevant_paragraphs = None + data_sector = None + data_source_file = None + data_source_page = None + data_year = None + fixed_source_file = None + + def __init__(self): + self.data_number = 0 + self.data_sector = "" + self.data_unit = "" + self.data_answer = "" + self.data_comments_questions = "" + self.data_company = "" + self.data_data_type = "" + self.data_irrelevant_paragraphs = "" + self.data_kpi_id = 0 + self.data_relevant_paragraphs = "" + self.data_sector = "" + self.data_source_file = "" + self.data_source_page = 0 + self.data_year = 0 + self.fixed_source_file = None + + @staticmethod + def samples_to_string(lst, max_width=140, min_col_width=5): + ctab = ConsoleTable(14) + ctab.cells.append("NUMBER") + ctab.cells.append("SECTOR") + ctab.cells.append("UNIT") + ctab.cells.append("ANSWER") + ctab.cells.append("COMMENTS") + ctab.cells.append("COMPANY") + ctab.cells.append("DATA_TYPE") + ctab.cells.append("IRREL_PARAG") + ctab.cells.append("KPI_ID") + ctab.cells.append("RELEV_PARAG") + ctab.cells.append("SECTOR") + ctab.cells.append("SOURCE_FILE") + ctab.cells.append("SOURCE_PAGE") + ctab.cells.append("YEAR") + + for k in lst: + ctab.cells.append(str(k.data_number)) + ctab.cells.append(str(k.data_sector)) + ctab.cells.append(str(k.data_unit)) + ctab.cells.append(str(k.data_answer)) + ctab.cells.append(str(k.data_comments_questions)) + ctab.cells.append(str(k.data_company)) + ctab.cells.append(str(k.data_data_type)) + ctab.cells.append(str(k.data_irrelevant_paragraphs)) + ctab.cells.append(str(k.data_kpi_id)) + ctab.cells.append(str(k.data_relevant_paragraphs)) + ctab.cells.append(str(k.data_sector)) + ctab.cells.append(str(k.data_source_file)) + 
ctab.cells.append(str(k.data_source_page)) + ctab.cells.append(str(k.data_year)) + + return ctab.to_string(max_width, min_col_width) + + @staticmethod + def samples_to_csv(lst): + def escape(txt): + txt = txt.replace("\n", "") + txt = txt.replace("\r", "") + txt = txt.replace('"', '""') + return '"' + Format_Analyzer.trim_whitespaces(txt) + '"' - # current format in sample csv file: - # Number,Sector,Unit,answer,"comments, questions",company,data_type,irrelevant_paragraphs,kpi_id,relevant_paragraphs,sector,source_file,source_page,year - - data_number = None - data_sector = None - data_unit = None - data_answer = None - data_comments_questions = None - data_company = None - data_data_type = None - data_irrelevant_paragraphs = None - data_kpi_id = None - data_relevant_paragraphs = None - data_sector = None - data_source_file = None - data_source_page = None - data_year = None - fixed_source_file = None - - def __init__(self): - self.data_number = 0 - self.data_sector = '' - self.data_unit = '' - self.data_answer = '' - self.data_comments_questions = '' - self.data_company = '' - self.data_data_type = '' - self.data_irrelevant_paragraphs = '' - self.data_kpi_id = 0 - self.data_relevant_paragraphs = '' - self.data_sector = '' - self.data_source_file = '' - self.data_source_page = 0 - self.data_year = 0 - self.fixed_source_file = None - - @staticmethod - def samples_to_string(lst, max_width=140, min_col_width=5): - ctab = ConsoleTable(14) - ctab.cells.append('NUMBER') - ctab.cells.append('SECTOR') - ctab.cells.append('UNIT') - ctab.cells.append('ANSWER') - ctab.cells.append('COMMENTS') - ctab.cells.append('COMPANY') - ctab.cells.append('DATA_TYPE') - ctab.cells.append('IRREL_PARAG') - ctab.cells.append('KPI_ID') - ctab.cells.append('RELEV_PARAG') - ctab.cells.append('SECTOR') - ctab.cells.append('SOURCE_FILE') - ctab.cells.append('SOURCE_PAGE') - ctab.cells.append('YEAR') + res = "" + for k in lst: + res += escape(str(k.data_number)) + ";" + res += escape(str(k.data_sector)) + ";" + res += escape(str(k.data_unit)) + ";" + res += escape(str(k.data_answer)) + ";" + res += escape(str(k.data_comments_questions)) + ";" + res += escape(str(k.data_company)) + ";" + res += escape(str(k.data_data_type)) + ";" + res += escape(str(k.data_irrelevant_paragraphs)) + ";" + res += escape(str(k.data_kpi_id)) + ";" + res += escape(str(k.data_relevant_paragraphs)) + ";" + res += escape(str(k.data_sector)) + ";" + res += escape(str(k.data_source_file)) + ";" + res += escape(str(k.data_source_page)) + ";" + res += escape(str(k.data_year)) + "\n" - - for k in lst: - ctab.cells.append(str(k.data_number )) - ctab.cells.append(str(k.data_sector )) - ctab.cells.append(str(k.data_unit )) - ctab.cells.append(str(k.data_answer )) - ctab.cells.append(str(k.data_comments_questions )) - ctab.cells.append(str(k.data_company )) - ctab.cells.append(str(k.data_data_type )) - ctab.cells.append(str(k.data_irrelevant_paragraphs )) - ctab.cells.append(str(k.data_kpi_id )) - ctab.cells.append(str(k.data_relevant_paragraphs )) - ctab.cells.append(str(k.data_sector )) - ctab.cells.append(str(k.data_source_file )) - ctab.cells.append(str(k.data_source_page )) - ctab.cells.append(str(k.data_year )) + return res - - return ctab.to_string(max_width, min_col_width) - - - @staticmethod - def samples_to_csv(lst): - def escape(txt): - txt = txt.replace("\n", "") - txt = txt.replace("\r", "") - txt = txt.replace('"', '""') - return '"' + Format_Analyzer.trim_whitespaces(txt) + '"' - - res = "" - for k in lst: - res += escape(str(k.data_number )) 
+ ";" - res += escape(str(k.data_sector )) + ";" - res += escape(str(k.data_unit )) + ";" - res += escape(str(k.data_answer )) + ";" - res += escape(str(k.data_comments_questions )) + ";" - res += escape(str(k.data_company )) + ";" - res += escape(str(k.data_data_type )) + ";" - res += escape(str(k.data_irrelevant_paragraphs )) + ";" - res += escape(str(k.data_kpi_id )) + ";" - res += escape(str(k.data_relevant_paragraphs )) + ";" - res += escape(str(k.data_sector )) + ";" - res += escape(str(k.data_source_file )) + ";" - res += escape(str(k.data_source_page )) + ";" - res += escape(str(k.data_year )) + "\n" - - return res - - - - - def __repr__(self): - - return TestDataSample.samples_to_string([self]) - \ No newline at end of file + def __repr__(self): + return TestDataSample.samples_to_string([self]) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestEvaluation.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestEvaluation.py index 16187c8..c9b3ddf 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestEvaluation.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/TestEvaluation.py @@ -10,187 +10,180 @@ from KPIResultSet import * from ConsoleTable import * + class TestEvaluation: + EVAL_TRUE_POSITIVE = 0 + EVAL_FALSE_POSITIVE = 1 + EVAL_TRUE_NEGATIVE = 2 + EVAL_FALSE_NEGATIVE = 3 + + class TestEvalSample: + kpispec = None + kpimeasure = None + test_sample = None + year = None + pdf_file_name = None + + def __init__(self, kpispec, kpimeasure, test_sample, year, pdf_file_name): + self.kpispec = kpispec + self.kpimeasure = kpimeasure + self.test_sample = test_sample + self.year = year + self.pdf_file_name = pdf_file_name + + def get_true_value(self): + return None if self.test_sample is None else Format_Analyzer.to_float_number(self.test_sample.data_answer) + + def get_extracted_value(self): + return None if self.kpimeasure is None else Format_Analyzer.to_float_number(self.kpimeasure.value) + + def eval(self): + if self.kpimeasure is not None and self.test_sample is not None: + if abs(self.get_extracted_value() - self.get_true_value()) < 0.0001: + return TestEvaluation.EVAL_TRUE_POSITIVE + return TestEvaluation.EVAL_FALSE_POSITIVE + + if self.test_sample is not None: + return TestEvaluation.EVAL_FALSE_NEGATIVE + + if self.kpimeasure is not None: + return TestEvaluation.EVAL_FALSE_POSITIVE + + return TestEvaluation.EVAL_TRUE_NEGATIVE + + def eval_to_str(self): + eval_id = self.eval() + if eval_id == TestEvaluation.EVAL_TRUE_POSITIVE: + return "True Positive" + if eval_id == TestEvaluation.EVAL_FALSE_POSITIVE: + return "False Positive" + if eval_id == TestEvaluation.EVAL_TRUE_NEGATIVE: + return "True Negative" + if eval_id == TestEvaluation.EVAL_FALSE_NEGATIVE: + return "False Negative" + return "Unknown" + + eval_samples = None + num_true_positive = None + num_false_positive = None + num_true_negative = None + num_false_negative = None + measure_precision = None + measure_recall = None + + def __init__(self): + self.eval_samples = [] + self.num_true_positive = 0 + self.num_false_positive = 0 + self.num_true_negative = 0 + self.num_false_negative = 0 + self.measure_precision = 0.0 + self.measure_recall = 0.0 + + def do_evaluations(self): + self.num_true_positive = 0 + self.num_false_positive = 0 + self.num_true_negative = 0 + self.num_false_negative = 0 + for e in self.eval_samples: + eval_id = e.eval() + if eval_id == TestEvaluation.EVAL_TRUE_POSITIVE: + self.num_true_positive += 1 + if eval_id == 
TestEvaluation.EVAL_FALSE_POSITIVE: + self.num_false_positive += 1 + if eval_id == TestEvaluation.EVAL_TRUE_NEGATIVE: + self.num_true_negative += 1 + if eval_id == TestEvaluation.EVAL_FALSE_NEGATIVE: + self.num_false_negative += 1 + + if self.num_true_positive > 0: + self.measure_precision = self.num_true_positive / float(self.num_true_positive + self.num_false_positive) + self.measure_recall = self.num_true_positive / float(self.num_true_positive + self.num_false_negative) + else: + self.measure_precision = 0.0 + self.measure_recall = 0.0 + + def to_string(self, max_width, min_col_width, format): + ctab = ConsoleTable(7) + ctab.cells.append("KPI_ID") + ctab.cells.append("KPI_NAME") + ctab.cells.append("PDF_FILE") + ctab.cells.append("YEAR") + ctab.cells.append("TRUE VALUE") + ctab.cells.append("EXTRACTED VALUE") + ctab.cells.append("CLASSIFICATION") + + for e in self.eval_samples: + ctab.cells.append(str(e.kpispec.kpi_id)) + ctab.cells.append(str(e.kpispec.kpi_name)) + ctab.cells.append(str(e.pdf_file_name)) + ctab.cells.append(str(e.year)) + ctab.cells.append(str(e.get_true_value())) + ctab.cells.append(str(e.get_extracted_value())) + ctab.cells.append(e.eval_to_str().upper()) + + res = ctab.to_string(max_width, min_col_width, format) + + res += "\nSUMMARY:\n" + res += "True Positives : " + str(self.num_true_positive) + "\n" + res += "False Positives : " + str(self.num_false_positive) + "\n" + res += "True Negatives : " + str(self.num_true_negative) + "\n" + res += "False Negatives : " + str(self.num_false_negative) + "\n" + res += "Precision : " + str(self.measure_precision) + "\n" + res += "Recall : " + str(self.measure_recall) + "\n" + + return res + + def __repr__(self): + return self.to_string(120, 5, ConsoleTable.FORMAT_CSV) + + @staticmethod + def generate_evaluation(kpispecs, kpiresults, test_data): + pdf_file_names = test_data.get_fixed_pdf_list() + + res = TestEvaluation() + + for kpispec in kpispecs: + print_verbose(1, "Evaluating KPI: kpi_id=" + str(kpispec.kpi_id) + ', kpi_name="' + kpispec.kpi_name + '"') + for pdf_file_name in pdf_file_names: + print_verbose(1, '--->> Evaluating PDF = "' + pdf_file_name + '"') + # Find values in test data samples for this kpi/pdf: + for s in test_data.samples: + if s.data_kpi_id == kpispec.kpi_id and s.fixed_source_file == pdf_file_name: + # match (True KPI exists in pdf) + cur_eval_sample = None + # are there any matches in our results? + for k in kpiresults.kpimeasures: + if k.kpi_id == kpispec.kpi_id and k.src_file == pdf_file_name and k.year == s.data_year: + # yes (Extracted KPI exists) + cur_eval_sample = TestEvaluation.TestEvalSample(kpispec, k, s, k.year, pdf_file_name) + break + if cur_eval_sample is None: + # no + cur_eval_sample = TestEvaluation.TestEvalSample( + kpispec, None, s, s.data_year, pdf_file_name + ) + res.eval_samples.append(cur_eval_sample) + + # Any unmatched kpi results (ie. extracted KPIs) left? 
+ for k in kpiresults.kpimeasures: + if k.src_file != pdf_file_name: + # print('skip: ' +str(k)) + continue + found = False + for e in res.eval_samples: + if ( + e.kpimeasure is not None + and k.kpi_id == e.kpispec.kpi_id + and k.year == e.year + and e.kpimeasure.src_file == pdf_file_name + ): + found = True + break + if not found: + # unmatched + cur_eval_sample = TestEvaluation.TestEvalSample(kpispec, k, None, k.year, pdf_file_name) + res.eval_samples.append(cur_eval_sample) - EVAL_TRUE_POSITIVE = 0 - EVAL_FALSE_POSITIVE = 1 - EVAL_TRUE_NEGATIVE = 2 - EVAL_FALSE_NEGATIVE = 3 - - class TestEvalSample: - kpispec = None - kpimeasure = None - test_sample = None - year = None - pdf_file_name = None - - def __init__(self, kpispec, kpimeasure, test_sample, year, pdf_file_name): - self.kpispec = kpispec - self.kpimeasure = kpimeasure - self.test_sample = test_sample - self.year = year - self.pdf_file_name = pdf_file_name - - def get_true_value(self): - return None if self.test_sample is None else Format_Analyzer.to_float_number(self.test_sample.data_answer) - - def get_extracted_value(self): - return None if self.kpimeasure is None else Format_Analyzer.to_float_number(self.kpimeasure.value) - - def eval(self): - if(self.kpimeasure is not None and self.test_sample is not None): - if(abs(self.get_extracted_value() - self.get_true_value()) < 0.0001): - return TestEvaluation.EVAL_TRUE_POSITIVE - return TestEvaluation.EVAL_FALSE_POSITIVE - - if(self.test_sample is not None): - return TestEvaluation.EVAL_FALSE_NEGATIVE - - if(self.kpimeasure is not None): - return TestEvaluation.EVAL_FALSE_POSITIVE - - return TestEvaluation.EVAL_TRUE_NEGATIVE - - - def eval_to_str(self): - eval_id = self.eval() - if(eval_id == TestEvaluation.EVAL_TRUE_POSITIVE): - return "True Positive" - if(eval_id == TestEvaluation.EVAL_FALSE_POSITIVE): - return "False Positive" - if(eval_id == TestEvaluation.EVAL_TRUE_NEGATIVE): - return "True Negative" - if(eval_id == TestEvaluation.EVAL_FALSE_NEGATIVE): - return "False Negative" - return "Unknown" - - - - eval_samples = None - num_true_positive = None - num_false_positive = None - num_true_negative = None - num_false_negative = None - measure_precision = None - measure_recall = None - - - def __init__(self): - self.eval_samples = [] - self.num_true_positive = 0 - self.num_false_positive = 0 - self.num_true_negative = 0 - self.num_false_negative = 0 - self.measure_precision = 0.0 - self.measure_recall = 0.0 - - def do_evaluations(self): - self.num_true_positive = 0 - self.num_false_positive = 0 - self.num_true_negative = 0 - self.num_false_negative = 0 - for e in self.eval_samples: - eval_id = e.eval() - if(eval_id == TestEvaluation.EVAL_TRUE_POSITIVE): - self.num_true_positive += 1 - if(eval_id == TestEvaluation.EVAL_FALSE_POSITIVE): - self.num_false_positive += 1 - if(eval_id == TestEvaluation.EVAL_TRUE_NEGATIVE): - self.num_true_negative += 1 - if(eval_id == TestEvaluation.EVAL_FALSE_NEGATIVE): - self.num_false_negative += 1 - - if(self.num_true_positive>0): - self.measure_precision = self.num_true_positive / float(self.num_true_positive + self.num_false_positive) - self.measure_recall = self.num_true_positive / float(self.num_true_positive + self.num_false_negative) - else: - self.measure_precision = 0.0 - self.measure_recall = 0.0 - - def to_string(self, max_width, min_col_width, format): - ctab = ConsoleTable(7) - ctab.cells.append('KPI_ID') - ctab.cells.append('KPI_NAME') - ctab.cells.append('PDF_FILE') - ctab.cells.append('YEAR') - ctab.cells.append('TRUE VALUE') - 
ctab.cells.append('EXTRACTED VALUE') - ctab.cells.append('CLASSIFICATION') - - for e in self.eval_samples: - ctab.cells.append(str(e.kpispec.kpi_id)) - ctab.cells.append(str(e.kpispec.kpi_name)) - ctab.cells.append(str(e.pdf_file_name)) - ctab.cells.append(str(e.year)) - ctab.cells.append(str(e.get_true_value())) - ctab.cells.append(str(e.get_extracted_value())) - ctab.cells.append(e.eval_to_str().upper()) - - res = ctab.to_string(max_width, min_col_width, format) - - res += "\nSUMMARY:\n" - res += "True Positives : " + str(self.num_true_positive) + "\n" - res += "False Positives : " + str(self.num_false_positive) + "\n" - res += "True Negatives : " + str(self.num_true_negative) + "\n" - res += "False Negatives : " + str(self.num_false_negative) + "\n" - res += "Precision : " + str(self.measure_precision) + "\n" - res += "Recall : " + str(self.measure_recall) + "\n" - - - return res - - def __repr__(self): - - return self.to_string(120, 5, ConsoleTable.FORMAT_CSV) - - - @staticmethod - def generate_evaluation(kpispecs, kpiresults, test_data): - pdf_file_names = test_data.get_fixed_pdf_list() - - res = TestEvaluation() - - for kpispec in kpispecs: - print_verbose(1, 'Evaluating KPI: kpi_id=' +str(kpispec.kpi_id) + ', kpi_name="' + kpispec.kpi_name + '"') - for pdf_file_name in pdf_file_names: - print_verbose(1, '--->> Evaluating PDF = "' + pdf_file_name + '"') - # Find values in test data samples for this kpi/pdf: - for s in test_data.samples: - if(s.data_kpi_id == kpispec.kpi_id and s.fixed_source_file == pdf_file_name): - #match (True KPI exists in pdf) - cur_eval_sample = None - #are there any matches in our results? - for k in kpiresults.kpimeasures: - if(k.kpi_id == kpispec.kpi_id and k.src_file == pdf_file_name and k.year == s.data_year): - #yes (Extracted KPI exists) - cur_eval_sample = TestEvaluation.TestEvalSample(kpispec, k, s, k.year, pdf_file_name) - break - if(cur_eval_sample is None): - #no - cur_eval_sample = TestEvaluation.TestEvalSample(kpispec, None, s, s.data_year, pdf_file_name) - res.eval_samples.append(cur_eval_sample) - - # Any unmatched kpi results (ie. extracted KPIs) left? - for k in kpiresults.kpimeasures: - if(k.src_file != pdf_file_name): - #print('skip: ' +str(k)) - continue - found = False - for e in res.eval_samples: - if(e.kpimeasure is not None and k.kpi_id == e.kpispec.kpi_id and k.year == e.year and e.kpimeasure.src_file == pdf_file_name): - found = True - break - if(not found): - #unmatched - cur_eval_sample = TestEvaluation.TestEvalSample(kpispec, k, None, k.year, pdf_file_name) - res.eval_samples.append(cur_eval_sample) - - res.do_evaluations() - return res - - - - - - - \ No newline at end of file + res.do_evaluations() + return res diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/config.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/config.py index c9e55bf..5202620 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/config.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/config.py @@ -6,33 +6,41 @@ # ============================================================================================================================ -global_verbosity = 1 ### TODO: Change verbosity here. verbosity 6-8 are good values for debugging without too much output ### +global_verbosity = ( + 1 ### TODO: Change verbosity here. 
verbosity 6-8 are good values for debugging without too much output ### +) global_exec_folder = r"./" # ***** gas and oil pdfs ***** : -global_raw_pdf_folder = r"raw_pdf/" -#global_raw_pdf_folder = os.path.expanduser('~') + r"/raw_pdf/" +global_raw_pdf_folder = r"raw_pdf/" +# global_raw_pdf_folder = os.path.expanduser('~') + r"/raw_pdf/" -global_working_folder = r"work_dir/" +global_working_folder = r"work_dir/" global_output_folder = r"output/" -global_kpi_spec_path = "" # if set, then command line argument will be ignored; example: "kpispec.txt" +global_kpi_spec_path = "" # if set, then command line argument will be ignored; example: "kpispec.txt" global_rendering_font_override = r"default_font.otf" -global_approx_font_name =r"default_font.otf" # use this font as approximation -global_max_identify_complex_items_timeout = 0.5 # seconds +global_approx_font_name = r"default_font.otf" # use this font as approximation +global_max_identify_complex_items_timeout = 0.5 # seconds global_force_special_items_into_table = True -global_row_connection_threshold = 10.0 #default=5 . If there is empty space for that many times the previous row height, we will consider this as two distinct tables -global_be_more_generous_with_good_tables = True # default=False. If true, we will consider some tables as good that normally considered bad +global_row_connection_threshold = 10.0 # default=5 . If there is empty space for that many times the previous row height, we will consider this as two distinct tables +global_be_more_generous_with_good_tables = ( + True # default=False. If true, we will consider some tables as good that normally considered bad +) -global_table_merge_non_overlapping_rows = True #default: Fale. If true, system will try to merge non-overlapping rows that probably belong to the same cell -#global_table_merging_only_if_numbers_come_first = True # default: False. If true, system will only merge such rows, where the first rows contains numbers (not used, doesnt really work) +global_table_merge_non_overlapping_rows = ( + True # default: Fale. If true, system will try to merge non-overlapping rows that probably belong to the same cell +) +# global_table_merging_only_if_numbers_come_first = True # default: False. If true, system will only merge such rows, where the first rows contains numbers (not used, doesnt really work) -global_html_encoding = "utf-8" # default: "ascii" +global_html_encoding = "utf-8" # default: "ascii" -global_ignore_all_years = False # default: False. Set it to true to ignore all years for every KPI (this is used for CDP reports) +global_ignore_all_years = ( + False # default: False. Set it to true to ignore all years for every KPI (this is used for CDP reports) +) -global_analyze_multiple_pages_at_one = False # default: False. Set it to True, to additionally search for KPIs on multiple (currently: 2) subsequent pages at once. \ No newline at end of file +global_analyze_multiple_pages_at_one = False # default: False. Set it to True, to additionally search for KPIs on multiple (currently: 2) subsequent pages at once. 
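Note for reviewers: the options above are plain module-level globals that callers set on the config module before a run; main.py further down in this patch fills global_raw_pdf_folder, global_working_folder, global_output_folder and global_verbosity from command-line arguments. A minimal usage sketch, not part of the patch itself; the paths are placeholders and the flat imports assume the rule_based_pipeline module layout shown here:

    import config                      # rule_based_pipeline flat-module layout assumed
    from globals import print_verbose  # compares its first argument against config.global_verbosity

    # Placeholder paths; adjust to the local checkout. main.py normally sets these from CLI arguments.
    config.global_verbosity = 3                  # 6-8 = detailed debugging output, 0 = silent
    config.global_raw_pdf_folder = "raw_pdf/"
    config.global_working_folder = "work_dir/"
    config.global_output_folder = "output/"

    print_verbose(1, "reading PDFs from " + config.global_raw_pdf_folder)

Because print_verbose() in globals.py simply compares the requested verbosity against config.global_verbosity, raising the value surfaces progressively more diagnostic output without changing any pipeline code.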
diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/globals.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/globals.py index 3e205d2..b540520 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/globals.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/globals.py @@ -23,114 +23,129 @@ import math -ALIGN_DEFAULT = 0 -ALIGN_LEFT = 1 -ALIGN_RIGHT = 2 -ALIGN_CENTER = 3 # New-27.06.2022 - - -#Text categories -CAT_DEFAULT = 0 -CAT_RUNNING_TEXT = 1 -CAT_HEADLINE = 2 -CAT_OTHER_TEXT = 3 -CAT_TABLE_DATA = 4 -CAT_TABLE_HEADLINE = 5 -CAT_TABLE_SPECIAL = 6 #e.g., annonations -CAT_MISC = 7 #probably belongs to figures -CAT_FOOTER = 8 #bugfix 26.07.2022 -CAT_FOOTNOTE = 9 #new categorie for finding footnotes +ALIGN_DEFAULT = 0 +ALIGN_LEFT = 1 +ALIGN_RIGHT = 2 +ALIGN_CENTER = 3 # New-27.06.2022 + + +# Text categories +CAT_DEFAULT = 0 +CAT_RUNNING_TEXT = 1 +CAT_HEADLINE = 2 +CAT_OTHER_TEXT = 3 +CAT_TABLE_DATA = 4 +CAT_TABLE_HEADLINE = 5 +CAT_TABLE_SPECIAL = 6 # e.g., annonations +CAT_MISC = 7 # probably belongs to figures +CAT_FOOTER = 8 # bugfix 26.07.2022 +CAT_FOOTNOTE = 9 # new categorie for finding footnotes # Other constants -DEFAULT_VTHRESHOLD = 15.0 / 609.0 #609px is sample page width -DEFAULT_SPECIAL_ITEM_MAX_DIST = 15.0 / 609.0 #609px is sample page width -DEFAULT_HTHROWAWAY_DIST = 0.3 -DEFAULT_SPECIAL_ITEM_CUTOFF_DIST = 15.0 / 609.9 #609px is sample page width -DEFAULT_FLYSPECK_HEIGHT = 3.0 / 841.0 #841.0 is sampe page height +DEFAULT_VTHRESHOLD = 15.0 / 609.0 # 609px is sample page width +DEFAULT_SPECIAL_ITEM_MAX_DIST = 15.0 / 609.0 # 609px is sample page width +DEFAULT_HTHROWAWAY_DIST = 0.3 +DEFAULT_SPECIAL_ITEM_CUTOFF_DIST = 15.0 / 609.9 # 609px is sample page width +DEFAULT_FLYSPECK_HEIGHT = 3.0 / 841.0 # 841.0 is sampe page height # Rendering options -RENDERING_USE_CLUSTER_COLORS = False +RENDERING_USE_CLUSTER_COLORS = False # Workaround for redirecting PRINT to STDOUT (use it, if you get an error message) -if sys.stdout.encoding != 'utf-8': - sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict') -if sys.stderr.encoding != 'utf-8': - sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict') - - +if sys.stdout.encoding != "utf-8": + sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer, "strict") +if sys.stderr.encoding != "utf-8": + sys.stderr = codecs.getwriter("utf-8")(sys.stderr.buffer, "strict") + + def wait_for_user(): - input("Press Enter to continue...") - - -def print_big(txt, do_wait = True): - if(config.global_verbosity == 0): - return - if(do_wait): - wait_for_user() - print("=======================================") - print(txt.upper()) - print("=======================================") - if(do_wait): - wait_for_user() - - - + input("Press Enter to continue...") + + +def print_big(txt, do_wait=True): + if config.global_verbosity == 0: + return + if do_wait: + wait_for_user() + print("=======================================") + print(txt.upper()) + print("=======================================") + if do_wait: + wait_for_user() + + def print_verbose(verbosity, txt): - if(verbosity <= config.global_verbosity): - print(str(txt)) - + if verbosity <= config.global_verbosity: + print(str(txt)) + + def print_subset(verbosity, list, subset): - for s in subset: - print_verbose(verbosity, list[s]) - + for s in subset: + print_verbose(verbosity, list[s]) + def file_exists(fname): - return os.path.isfile(fname) - + return os.path.isfile(fname) + + def get_num_of_files(pattern): - return 
len(glob.glob(pattern)) - + return len(glob.glob(pattern)) + + def remove_trailing_slash(s): - if(s.endswith('/') or s.endswith('\\')): - s = s[:-1] - return s - -def remove_bad_chars(s, c): #removed all occurences of c in s - res = s - for x in c: - res = res.replace(x, '') - return res - -def dist(x1,y1,x2,y2): - return ((x1-x2)*(x1-x2) + (y1-y2)*(y1-y2))**0.5 - -def get_text_width(text, font): - size = font.getsize(text) - return size[0] + if s.endswith("/") or s.endswith("\\"): + s = s[:-1] + return s + + +def remove_bad_chars(s, c): # removed all occurences of c in s + res = s + for x in c: + res = res.replace(x, "") + return res + + +def dist(x1, y1, x2, y2): + return ((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)) ** 0.5 + + +def get_text_width(text, font): + size = font.getsize(text) + return size[0] + def get_html_out_dir(fname): - fname = '/'+ fname.replace('\\','/') - return config.global_working_folder + r'html/' + fname[(fname.rfind(r'/')+1):] + r'.html_dir' - - - + fname = "/" + fname.replace("\\", "/") + return config.global_working_folder + r"html/" + fname[(fname.rfind(r"/") + 1) :] + r".html_dir" + + def analyze_pdf(fname): - pdf_to_html(fname , get_html_out_dir(fname)) - + pdf_to_html(fname, get_html_out_dir(fname)) + + def save_txt_to_file(txt, fname): - with open(fname, "w", encoding="utf-8") as text_file: - text_file.write(txt) - - -def hsv_to_rgba(h, s, v): #h,s,v in [0,1], result r,g,b,a in [0,256) - if s == 0.0: return (v, v, v) - i = int(h*6.) # XXX assume int() truncates! - f = (h*6.)-i; p,q,t = v*(1.-s), v*(1.-s*f), v*(1.-s*(1.-f)); i%=6 - if i == 0: return (int(255*v), int(255*t), int(255*p), 255) - if i == 1: return (int(255*q), int(255*v), int(255*p), 255) - if i == 2: return (int(255*p), int(255*v), int(255*t), 255) - if i == 3: return (int(255*p), int(255*q), int(255*v), 255) - if i == 4: return (int(255*t), int(255*p), int(255*v), 255) - if i == 5: return (int(255*v), int(255*p), int(255*q), 255) \ No newline at end of file + with open(fname, "w", encoding="utf-8") as text_file: + text_file.write(txt) + + +def hsv_to_rgba(h, s, v): # h,s,v in [0,1], result r,g,b,a in [0,256) + if s == 0.0: + return (v, v, v) + i = int(h * 6.0) # XXX assume int() truncates! 
+ f = (h * 6.0) - i + p, q, t = v * (1.0 - s), v * (1.0 - s * f), v * (1.0 - s * (1.0 - f)) + i %= 6 + if i == 0: + return (int(255 * v), int(255 * t), int(255 * p), 255) + if i == 1: + return (int(255 * q), int(255 * v), int(255 * p), 255) + if i == 2: + return (int(255 * p), int(255 * v), int(255 * t), 255) + if i == 3: + return (int(255 * p), int(255 * q), int(255 * v), 255) + if i == 4: + return (int(255 * t), int(255 * p), int(255 * v), 255) + if i == 5: + return (int(255 * v), int(255 * p), int(255 * q), 255) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main.py index 1dd61cf..be44513 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main.py @@ -15,208 +15,200 @@ from KPIResultSet import * from TestData import * from DataImportExport import * -from test import * #only for testing / debugging purpose +from test import * # only for testing / debugging purpose import config - - - - - def generate_dummy_test_data(): + test_data = TestData() + test_data.generate_dummy_test_data(config.global_raw_pdf_folder, "*") + # print("DATA-SET:") + # print(test_data) + + return test_data + + +def analyze_pdf( + pdf_file, + kpis, + default_year, + info_file_contents, + wildcard_restrict_page="*", + force_pdf_convert=False, + force_parse_pdf=False, + assume_conversion_done=False, + do_wait=False, +): + print_verbose(1, "Analyzing PDF: " + str(pdf_file)) + + guess_year = Format_Analyzer.extract_year_from_text(pdf_file) + if guess_year is None: + guess_year = default_year + + htmldir_path = get_html_out_dir(pdf_file) + os.makedirs(htmldir_path, exist_ok=True) + + reload_neccessary = True + + if not assume_conversion_done: + # convert pdf to html + print_big("Convert PDF to HTML", do_wait) + if force_pdf_convert or not file_exists(htmldir_path + "/index.html"): + HTMLDirectory.convert_pdf_to_html(pdf_file, info_file_contents) + + # return KPIResultSet()# STOP after converting PDF files (dont continue with analysis) + + # parse and create json and png + print_big("Convert HTML to JSON and PNG", do_wait) + # print(htmldir_path) + dir = HTMLDirectory() + if force_parse_pdf or get_num_of_files(htmldir_path + "/jpage*.json") != get_num_of_files( + htmldir_path + "/page*.html" + ): + dir.parse_html_directory(get_html_out_dir(pdf_file), "page*.html") # ! page* + dir.render_to_png(htmldir_path, htmldir_path) + dir.save_to_dir(htmldir_path) + # exit() #TODO: Remove this!!!!! 
+ if wildcard_restrict_page == "*": + reload_neccessary = False + + # return KPIResultSet()# STOP after parsing HTML files (dont continue with analysis) + + # load json files + print_big("Load from JSON", do_wait) + if reload_neccessary: + dir = HTMLDirectory() + dir.load_from_dir(htmldir_path, "jpage" + str(wildcard_restrict_page) + ".json") + + # analyze + print_big("Analyze Pages", do_wait) + ana = AnalyzerDirectory(dir, guess_year) + # kpis = test_prepare_kpispecs() + # print(kpis) + + kpiresults = KPIResultSet(ana.find_multiple_kpis(kpis)) + + print_big("FINAL RESULT FOR: " + str(pdf_file.upper()), do_wait=False) + print_verbose(1, kpiresults) + + return kpiresults + - test_data = TestData() - test_data.generate_dummy_test_data(config.global_raw_pdf_folder, '*') - #print("DATA-SET:") - #print(test_data) - - return test_data - -def analyze_pdf(pdf_file, kpis, default_year, info_file_contents, wildcard_restrict_page='*', force_pdf_convert=False, force_parse_pdf=False, assume_conversion_done=False, do_wait=False): - - print_verbose(1, "Analyzing PDF: " +str(pdf_file)) - - guess_year = Format_Analyzer.extract_year_from_text(pdf_file) - if(guess_year is None): - guess_year = default_year - - - htmldir_path = get_html_out_dir(pdf_file) - os.makedirs(htmldir_path, exist_ok=True) - - - reload_neccessary = True - - - if(not assume_conversion_done): - #convert pdf to html - print_big("Convert PDF to HTML", do_wait) - if(force_pdf_convert or not file_exists(htmldir_path+'/index.html')): - HTMLDirectory.convert_pdf_to_html(pdf_file, info_file_contents) - - - #return KPIResultSet()# STOP after converting PDF files (dont continue with analysis) - - - # parse and create json and png - print_big("Convert HTML to JSON and PNG", do_wait) - #print(htmldir_path) - dir = HTMLDirectory() - if(force_parse_pdf or get_num_of_files(htmldir_path+'/jpage*.json') != get_num_of_files(htmldir_path+'/page*.html') ): - dir.parse_html_directory(get_html_out_dir(pdf_file), 'page*.html') # ! page* - dir.render_to_png(htmldir_path, htmldir_path) - dir.save_to_dir(htmldir_path) - #exit() #TODO: Remove this!!!!! 
- if(wildcard_restrict_page == '*'): - reload_neccessary = False - - #return KPIResultSet()# STOP after parsing HTML files (dont continue with analysis) - - # load json files - print_big("Load from JSON", do_wait) - if(reload_neccessary): - dir = HTMLDirectory() - dir.load_from_dir(htmldir_path, 'jpage' + str(wildcard_restrict_page) + '.json') - - - # analyze - print_big("Analyze Pages", do_wait) - ana = AnalyzerDirectory(dir, guess_year) - #kpis = test_prepare_kpispecs() - #print(kpis) - - kpiresults = KPIResultSet(ana.find_multiple_kpis(kpis)) - - - print_big("FINAL RESULT FOR: "+ str(pdf_file.upper()), do_wait = False) - print_verbose(1, kpiresults) - - return kpiresults - - - def get_input_variable(val, desc): - if val is None: - val = input(desc) - - if(val is None or val==""): - print("This must not be empty") - sys.exit(0) - - return val - - -def main(): + if val is None: + val = input(desc) - DEFAULT_YEAR = 2019 - - parser = argparse.ArgumentParser(description='Rule-based KPI extraction') - # Add the arguments - parser.add_argument('--raw_pdf_folder', - type=str, - default=None, - help='Folder where PDFs are stored') - parser.add_argument('--working_folder', - type=str, - default=None, - help='Folder where working files are stored') - parser.add_argument('--output_folder', - type=str, - default=None, - help='Folder where output is stored') - parser.add_argument('--verbosity', - type=int, - default=1, - help='Verbosity level (0=shut up)') - args = parser.parse_args() - config.global_raw_pdf_folder = remove_trailing_slash(get_input_variable(args.raw_pdf_folder, "What is the raw pdf folder?")).replace('\\', '/') + r'/' - config.global_working_folder = remove_trailing_slash(get_input_variable(args.working_folder, "What is the working folder?")).replace('\\', '/') + r'/' - config.global_output_folder = remove_trailing_slash(get_input_variable(args.output_folder, "What is the output folder?")).replace('\\', '/') + r'/' - config.global_verbosity = args.verbosity - - os.makedirs(config.global_working_folder, exist_ok=True) - os.makedirs(config.global_output_folder, exist_ok=True) - - # fix config.global_exec_folder and config.global_rendering_font_override - path = '' - try: - path = globals()['_dh'][0] - except KeyError: - path = os.path.dirname(os.path.realpath(__file__)) - path = remove_trailing_slash(path).replace('\\', '/') - - config.global_exec_folder = path+ r'/' - config.global_rendering_font_override = path + r'/' + config.global_rendering_font_override - - print_verbose(1, "Using config.global_exec_folder=" + config.global_exec_folder) - print_verbose(1, "Using config.global_raw_pdf_folder=" + config.global_raw_pdf_folder) - print_verbose(1, "Using config.global_working_folder=" + config.global_working_folder) - print_verbose(1, "Using config.global_output_folder=" + config.global_output_folder) - print_verbose(1, "Using config.global_verbosity=" + str(config.global_verbosity)) - print_verbose(5, "Using config.global_rendering_font_override=" + config.global_rendering_font_override) - - #test_data = load_test_data(r'test_data/aggregated_complete_samples_new.csv') - test_data = generate_dummy_test_data() - - # For debugging, save csv: - #test_data.save_to_csv(r'test_data/test_output_new.csv') - #return - - - - - # Filter PDF - #test_data.filter_kpis(by_source_file = ['PGE_Corporation_CDP_Climate_Change_Questionnaire_2021.pdf']) - - - print_big("Data-set", False) - print_verbose(1, test_data) - - - pdfs = test_data.get_fixed_pdf_list() - - print_verbose(1, 'Related (fixed) PDFs: ' 
+ str(pdfs) + ', in total : ' +str(len(pdfs))) - #return ### TODO: Uncomment this line, to return immediately, after PDF list has been shown. ### - - - - kpis = test_prepare_kpispecs() # TODO: In the future, KPI specs should be loaded from "nicer" implemented source, e.g., JSON file definiton - - overall_kpiresults = KPIResultSet() - - info_file_contents = DataImportExport.load_info_file_contents(remove_trailing_slash(config.global_working_folder) + '/info.json') - - time_start = time.time() - - for pdf in pdfs: - kpiresults = KPIResultSet(kpimeasures = []) - cur_kpiresults = analyze_pdf(config.global_raw_pdf_folder + pdf, kpis, DEFAULT_YEAR, info_file_contents, wildcard_restrict_page='*', assume_conversion_done=False, force_parse_pdf=False) ### TODO: Modify * in order to analyze specfic page, e.g.: *00042 ### - kpiresults.extend(cur_kpiresults) - overall_kpiresults.extend(cur_kpiresults) - kpiresults.save_to_csv_file(config.global_output_folder + pdf + r'.csv') - print_verbose(1, "RESULT FOR " + pdf) - print_verbose(1, kpiresults) - - - - time_finish = time.time() - - print_big("FINAL OVERALL-RESULT", do_wait = False) - print_verbose(1, overall_kpiresults) - - #overall_kpiresults.save_to_file(config.global_output_folder + r'kpiresults_test_tmp.json') - overall_kpiresults.save_to_csv_file(config.global_output_folder + r'kpiresults_tmp.csv') - - - total_time = time_finish - time_start - print_verbose(1, "Total run-time: " + str(total_time) + " sec ( " + str(total_time / max(len(pdfs), 1)) + " sec per PDF)") - - - - - -main() + if val is None or val == "": + print("This must not be empty") + sys.exit(0) + return val +def main(): + DEFAULT_YEAR = 2019 + + parser = argparse.ArgumentParser(description="Rule-based KPI extraction") + # Add the arguments + parser.add_argument("--raw_pdf_folder", type=str, default=None, help="Folder where PDFs are stored") + parser.add_argument("--working_folder", type=str, default=None, help="Folder where working files are stored") + parser.add_argument("--output_folder", type=str, default=None, help="Folder where output is stored") + parser.add_argument("--verbosity", type=int, default=1, help="Verbosity level (0=shut up)") + args = parser.parse_args() + config.global_raw_pdf_folder = ( + remove_trailing_slash(get_input_variable(args.raw_pdf_folder, "What is the raw pdf folder?")).replace("\\", "/") + + r"/" + ) + config.global_working_folder = ( + remove_trailing_slash(get_input_variable(args.working_folder, "What is the working folder?")).replace("\\", "/") + + r"/" + ) + config.global_output_folder = ( + remove_trailing_slash(get_input_variable(args.output_folder, "What is the output folder?")).replace("\\", "/") + + r"/" + ) + config.global_verbosity = args.verbosity + + os.makedirs(config.global_working_folder, exist_ok=True) + os.makedirs(config.global_output_folder, exist_ok=True) + + # fix config.global_exec_folder and config.global_rendering_font_override + path = "" + try: + path = globals()["_dh"][0] + except KeyError: + path = os.path.dirname(os.path.realpath(__file__)) + path = remove_trailing_slash(path).replace("\\", "/") + + config.global_exec_folder = path + r"/" + config.global_rendering_font_override = path + r"/" + config.global_rendering_font_override + + print_verbose(1, "Using config.global_exec_folder=" + config.global_exec_folder) + print_verbose(1, "Using config.global_raw_pdf_folder=" + config.global_raw_pdf_folder) + print_verbose(1, "Using config.global_working_folder=" + config.global_working_folder) + print_verbose(1, "Using 
config.global_output_folder=" + config.global_output_folder) + print_verbose(1, "Using config.global_verbosity=" + str(config.global_verbosity)) + print_verbose(5, "Using config.global_rendering_font_override=" + config.global_rendering_font_override) + + # test_data = load_test_data(r'test_data/aggregated_complete_samples_new.csv') + test_data = generate_dummy_test_data() + + # For debugging, save csv: + # test_data.save_to_csv(r'test_data/test_output_new.csv') + # return + + # Filter PDF + # test_data.filter_kpis(by_source_file = ['PGE_Corporation_CDP_Climate_Change_Questionnaire_2021.pdf']) + + print_big("Data-set", False) + print_verbose(1, test_data) + + pdfs = test_data.get_fixed_pdf_list() + + print_verbose(1, "Related (fixed) PDFs: " + str(pdfs) + ", in total : " + str(len(pdfs))) + # return ### TODO: Uncomment this line, to return immediately, after PDF list has been shown. ### + + kpis = ( + test_prepare_kpispecs() + ) # TODO: In the future, KPI specs should be loaded from "nicer" implemented source, e.g., JSON file definiton + + overall_kpiresults = KPIResultSet() + + info_file_contents = DataImportExport.load_info_file_contents( + remove_trailing_slash(config.global_working_folder) + "/info.json" + ) + + time_start = time.time() + + for pdf in pdfs: + kpiresults = KPIResultSet(kpimeasures=[]) + cur_kpiresults = analyze_pdf( + config.global_raw_pdf_folder + pdf, + kpis, + DEFAULT_YEAR, + info_file_contents, + wildcard_restrict_page="*", + assume_conversion_done=False, + force_parse_pdf=False, + ) ### TODO: Modify * in order to analyze specfic page, e.g.: *00042 ### + kpiresults.extend(cur_kpiresults) + overall_kpiresults.extend(cur_kpiresults) + kpiresults.save_to_csv_file(config.global_output_folder + pdf + r".csv") + print_verbose(1, "RESULT FOR " + pdf) + print_verbose(1, kpiresults) + + time_finish = time.time() + + print_big("FINAL OVERALL-RESULT", do_wait=False) + print_verbose(1, overall_kpiresults) + + # overall_kpiresults.save_to_file(config.global_output_folder + r'kpiresults_test_tmp.json') + overall_kpiresults.save_to_csv_file(config.global_output_folder + r"kpiresults_tmp.csv") + + total_time = time_finish - time_start + print_verbose( + 1, "Total run-time: " + str(total_time) + " sec ( " + str(total_time / max(len(pdfs), 1)) + " sec per PDF)" + ) + + +main() diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main_find_xy.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main_find_xy.py index 5952ce8..f103a54 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main_find_xy.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/main_find_xy.py @@ -1,7 +1,7 @@ # ============================================================================================================================ # PDF_Analyzer # File : main_find_xy.py -# Author : Lei Deng (D87HMXV) - reference mian.py +# Author : Lei Deng (D87HMXV) - reference mian.py # Date : 12.10.2022 # ============================================================================================================================ @@ -15,224 +15,234 @@ import config import re import pandas as pd + pd.options.mode.chained_assignment = None # default='warn' def generate_dummy_test_data(): + test_data = TestData() + test_data.generate_dummy_test_data(config.global_raw_pdf_folder, "*") + + return test_data + + +def analyze_pdf( + pdf_file, + pageNum, + txt, + info_file_contents, + force_pdf_convert=False, + force_parse_pdf=False, + assume_conversion_done=False, + 
do_wait=False, +): + print_verbose(1, "Analyzing PDF: " + str(pdf_file)) + + htmldir_path = get_html_out_dir(pdf_file) # get pdf + os.makedirs(htmldir_path, exist_ok=True) # ceate directory recursively + + reload_neccessary = True # why use this? + + dir = HTMLDirectory() + if not assume_conversion_done: + # convert pdf to html + print_big("Convert PDF to HTML", do_wait) + if force_pdf_convert or not file_exists(htmldir_path + "/index.html"): + HTMLDirectory.convert_pdf_to_html(pdf_file, info_file_contents) # 04.11.2022 + + # parse and create json and png + print_big("Convert HTML to JSON and PNG", do_wait) + + # dir = HTMLDirectory() + if force_parse_pdf or get_num_of_files(htmldir_path + "/jpage*.json") != get_num_of_files( + htmldir_path + "/page*.html" + ): + dir.parse_html_directory(get_html_out_dir(pdf_file), "page*.html") # ! page* + dir.render_to_png(htmldir_path, htmldir_path) + dir.save_to_dir(htmldir_path) + + # load json files + print_big("Load from JSON", do_wait) + if reload_neccessary: + dir = HTMLDirectory() + dir.load_from_dir(htmldir_path, "jpage*" + str(pageNum) + ".json") + + # get coordinates + print_big("get coordinates", do_wait) + res = [] + index = None + this_id = 0 + for p in dir.htmlpages: + # print_verbose(2, p.page_num) + if p.page_num == pageNum: + for i in p.items: + contxt = concat_Nitem(i, p) + print_verbose(2, "\n\ncontxt:") + print_verbose(2, contxt) + try: + print_verbose(2, "looking for: " + txt) + index = contxt.index(txt) # get the index substing's first letter + wordIndex = len( + contxt[:index].strip().split() + ) # get str before substring's 1st letter, split it by word,length = index of substring's 1st word + print_verbose(2, "wordIndex:") + print_verbose(2, wordIndex) + # print_verbose(2, "input txt:" + txt) + res.extend( + ( + i.words[wordIndex].rect.get_coordinates()[0] / p.page_width, + i.words[wordIndex].rect.get_coordinates()[1] / p.page_height, + ) + ) + print_verbose(2, res) + ##### TODO: Also compare the paragraph from the CSV, in order to get the best result if we have multiple matches !!! + return res + # if(i.txt == txt): + # res.extend(i.get_coordis()) #result[158.27,115.37] + except ValueError: + print_verbose(2, "substring not found") + except IndexError: + print_verbose(2, "list index out of range") + return res + + +def concat_Nitem(item, page): + res = item.txt + cur_item = item + while cur_item.next_id != -1: + for i in page.items: + if i.this_id == cur_item.next_id: # identify the next line of item + res += " " + i.txt + cur_item = i # update condition of while + break + return res.replace("\n", " ") - test_data = TestData() - test_data.generate_dummy_test_data(config.global_raw_pdf_folder, '*') - - return test_data - -def analyze_pdf(pdf_file, pageNum, txt, info_file_contents, force_pdf_convert=False, force_parse_pdf=False, assume_conversion_done=False, do_wait=False): - - print_verbose(1, "Analyzing PDF: " +str(pdf_file)) - - htmldir_path = get_html_out_dir(pdf_file)#get pdf - os.makedirs(htmldir_path, exist_ok=True) #ceate directory recursively - - reload_neccessary = True # why use this? 
- - dir = HTMLDirectory() - if(not assume_conversion_done): - #convert pdf to html - print_big("Convert PDF to HTML", do_wait) - if(force_pdf_convert or not file_exists(htmldir_path+'/index.html')): - HTMLDirectory.convert_pdf_to_html(pdf_file, info_file_contents) #04.11.2022 - - # parse and create json and png - print_big("Convert HTML to JSON and PNG", do_wait) - - #dir = HTMLDirectory() - if(force_parse_pdf or get_num_of_files(htmldir_path+'/jpage*.json') != get_num_of_files(htmldir_path+'/page*.html') ): - dir.parse_html_directory(get_html_out_dir(pdf_file), 'page*.html') # ! page* - dir.render_to_png(htmldir_path, htmldir_path) - dir.save_to_dir(htmldir_path) - - - - # load json files - print_big("Load from JSON", do_wait) - if(reload_neccessary): - dir = HTMLDirectory() - dir.load_from_dir(htmldir_path, 'jpage*' + str(pageNum) + '.json') - - - # get coordinates - print_big("get coordinates", do_wait) - res = [] - index = None - this_id = 0 - for p in dir.htmlpages: - #print_verbose(2, p.page_num) - if(p.page_num == pageNum): - for i in p.items: - contxt = concat_Nitem(i,p) - print_verbose(2, "\n\ncontxt:") - print_verbose(2, contxt) - try: - print_verbose(2, "looking for: " + txt) - index = contxt.index(txt) #get the index substing's first letter - wordIndex = len(contxt[:index].strip().split()) #get str before substring's 1st letter, split it by word,length = index of substring's 1st word - print_verbose(2, "wordIndex:" ) - print_verbose(2, wordIndex) - #print_verbose(2, "input txt:" + txt) - res.extend((i.words[wordIndex].rect.get_coordinates()[0]/p.page_width,i.words[wordIndex].rect.get_coordinates()[1]/p.page_height)) - print_verbose(2, res) - ##### TODO: Also compare the paragraph from the CSV, in order to get the best result if we have multiple matches !!! 
- return res - #if(i.txt == txt): - #res.extend(i.get_coordis()) #result[158.27,115.37] - except ValueError: - print_verbose(2, "substring not found") - except IndexError: - print_verbose(2, "list index out of range") - return res - -def concat_Nitem(item,page): - res = item.txt - cur_item = item - while(cur_item.next_id != -1): - for i in page.items: - if(i.this_id == cur_item.next_id): #identify the next line of item - res += ' ' + i.txt - cur_item = i #update condition of while - break - return res.replace('\n', ' ') def get_input_variable(val, desc): - if val is None: - val = input(desc) + if val is None: + val = input(desc) - if(val is None or val==""): - print_verbose(0, "This must not be empty") - sys.exit(0) - - return val - + if val is None or val == "": + print_verbose(0, "This must not be empty") + sys.exit(0) -# def read_csv(csv): - # with open(csv) as csvFile: #open file - # reader = csv.reader(csvFile) #buil csv reader - # rows = [row for row in reader] - # data = np.array(rows) #convert data format from list to array - -def modify_csv(csv, info_file_contents): - csvPD = pd.read_csv(csv, encoding='utf-8') #Building a csv reader - coordis = None - for c in range(len(csvPD)): #check columns - #print(str(c) + str(csvPD['PDF_NAME'][c])) - if str(csvPD['POS_X'][c])=="nan" or str(csvPD['POS_Y'][c]) == "nan": - coordis = analyze_pdf(config.global_raw_pdf_folder + str(csvPD['PDF_NAME'][c]), int(csvPD['PAGE'][c]), str(csvPD['ANSWER_RAW'][c]), info_file_contents, assume_conversion_done=False, force_parse_pdf=False) - print_verbose(2, "coord:") - print_verbose(2, coordis) - if(len(coordis)>0): - csvPD['POS_X'][c] = coordis[0] - csvPD['POS_Y'][c] = coordis[1] - #df = pd.DataFrame(coordis, encoding = 'utf-8-sig') #initical data as dataframe - csvPD.to_csv(csv, index=False) - print_verbose(2, csvPD.to_csv(csv, index=False)) + return val +# def read_csv(csv): +# with open(csv) as csvFile: #open file +# reader = csv.reader(csvFile) #buil csv reader +# rows = [row for row in reader] +# data = np.array(rows) #convert data format from list to array -def main(): +def modify_csv(csv, info_file_contents): + csvPD = pd.read_csv(csv, encoding="utf-8") # Building a csv reader + coordis = None + for c in range(len(csvPD)): # check columns + # print(str(c) + str(csvPD['PDF_NAME'][c])) + if str(csvPD["POS_X"][c]) == "nan" or str(csvPD["POS_Y"][c]) == "nan": + coordis = analyze_pdf( + config.global_raw_pdf_folder + str(csvPD["PDF_NAME"][c]), + int(csvPD["PAGE"][c]), + str(csvPD["ANSWER_RAW"][c]), + info_file_contents, + assume_conversion_done=False, + force_parse_pdf=False, + ) + print_verbose(2, "coord:") + print_verbose(2, coordis) + if len(coordis) > 0: + csvPD["POS_X"][c] = coordis[0] + csvPD["POS_Y"][c] = coordis[1] + # df = pd.DataFrame(coordis, encoding = 'utf-8-sig') #initical data as dataframe + csvPD.to_csv(csv, index=False) + print_verbose(2, csvPD.to_csv(csv, index=False)) - #parse input parameters - - parser = argparse.ArgumentParser(description='coordinates extraction') - - parser.add_argument('--raw_pdf_folder', - type=str, - default=None, - help='Folder where PDFs are stored') - parser.add_argument('--working_folder', - type=str, - default=None, - help='Folder where working files are stored') - parser.add_argument('--pdf_name', - type=str, - default=None, - help='name of pdf which you want to check') - parser.add_argument('--csv_name', - type=str, - default=None, - help='name of csv file') - # parser.add_argument('--page_number', - # type=int, - # default=None, - # help='in which page of 
pdf you want to find the coordinates of text') - # parser.add_argument('--text', - # type=str, - # default=None, - # help='for which text you want to find the coordinates') - parser.add_argument('--output_folder', - type=str, - default=None, - help='Folder where output is stored') - parser.add_argument('--verbosity', - type=int, - default=1, - help='Verbosity level (0=shut up)') - - args = parser.parse_args() - config.global_raw_pdf_folder = remove_trailing_slash(get_input_variable(args.raw_pdf_folder, "What is the raw pdf folder?")).replace('\\', '/') + r'/' - config.global_working_folder = remove_trailing_slash(get_input_variable(args.working_folder, "What is the working folder?")).replace('\\', '/') + r'/' - config.global_pdf_name = get_input_variable(args.pdf_name, "Which pdf do you want to check?") - config.global_csv_name = get_input_variable(args.csv_name, "Which csv file do you want to check?") - #config.global_page_number = get_input_variable(args.page_number, "Which page do you want to check?") - #config.global_text = get_input_variable(args.text, "For which text do you want to find the x, y coordinates?") - config.global_output_folder = remove_trailing_slash(get_input_variable(args.output_folder, "What is the output folder?")).replace('\\', '/') + r'/' - config.global_verbosity = args.verbosity - - os.makedirs(config.global_working_folder, exist_ok=True) - os.makedirs(config.global_output_folder, exist_ok=True) - - # fix config.global_exec_folder and config.global_rendering_font_override - path = '' - try: - path = globals()['_dh'][0] - except KeyError: - path = os.path.dirname(os.path.realpath(__file__)) - path = remove_trailing_slash(path).replace('\\', '/') - - config.global_exec_folder = path+ r'/' - config.global_rendering_font_override = path + r'/' + config.global_rendering_font_override - - print_verbose(1, "Using config.global_exec_folder=" + config.global_exec_folder) - print_verbose(1, "Using config.global_raw_pdf_folder=" + config.global_raw_pdf_folder) - print_verbose(1, "Using config.global_working_folder=" + config.global_working_folder) - print_verbose(1, "Using config.global_output_folder=" + config.global_output_folder) - print_verbose(1, "Using config.global_verbosity=" + str(config.global_verbosity)) - print_verbose(5, "Using config.global_rendering_font_override=" + config.global_rendering_font_override) - - test_data = generate_dummy_test_data() - - print_big("Data-set", False) - print_verbose(1, test_data) - - - - - info_file_contents = DataImportExport.load_info_file_contents(remove_trailing_slash(config.global_working_folder) + '/info.json') - - time_start = time.time() - - - coordisresults = [] - #cur_coordisresults = analyze_pdf(config.global_raw_pdf_folder + config.global_pdf_name, config.global_page_number, config.global_text, info_file_contents, assume_conversion_done=False, force_parse_pdf=False) #analyse input data - modify_csv(config.global_csv_name, info_file_contents) - #print(cur_coordisresults) #debugging empty, so analyze_pdf method didn't work as expected - #coordisresults.extend(cur_coordisresults) - #print(coordisresults) #debugging empty - print_verbose(1, "RESULT FOR " + config.global_pdf_name) - print_verbose(1, coordisresults) - - - time_finish = time.time() - - total_time = time_finish - time_start - -main() +def main(): + # parse input parameters + + parser = argparse.ArgumentParser(description="coordinates extraction") + + parser.add_argument("--raw_pdf_folder", type=str, default=None, help="Folder where PDFs are stored") + 
parser.add_argument("--working_folder", type=str, default=None, help="Folder where working files are stored") + parser.add_argument("--pdf_name", type=str, default=None, help="name of pdf which you want to check") + parser.add_argument("--csv_name", type=str, default=None, help="name of csv file") + # parser.add_argument('--page_number', + # type=int, + # default=None, + # help='in which page of pdf you want to find the coordinates of text') + # parser.add_argument('--text', + # type=str, + # default=None, + # help='for which text you want to find the coordinates') + parser.add_argument("--output_folder", type=str, default=None, help="Folder where output is stored") + parser.add_argument("--verbosity", type=int, default=1, help="Verbosity level (0=shut up)") + + args = parser.parse_args() + config.global_raw_pdf_folder = ( + remove_trailing_slash(get_input_variable(args.raw_pdf_folder, "What is the raw pdf folder?")).replace("\\", "/") + + r"/" + ) + config.global_working_folder = ( + remove_trailing_slash(get_input_variable(args.working_folder, "What is the working folder?")).replace("\\", "/") + + r"/" + ) + config.global_pdf_name = get_input_variable(args.pdf_name, "Which pdf do you want to check?") + config.global_csv_name = get_input_variable(args.csv_name, "Which csv file do you want to check?") + # config.global_page_number = get_input_variable(args.page_number, "Which page do you want to check?") + # config.global_text = get_input_variable(args.text, "For which text do you want to find the x, y coordinates?") + config.global_output_folder = ( + remove_trailing_slash(get_input_variable(args.output_folder, "What is the output folder?")).replace("\\", "/") + + r"/" + ) + config.global_verbosity = args.verbosity + + os.makedirs(config.global_working_folder, exist_ok=True) + os.makedirs(config.global_output_folder, exist_ok=True) + + # fix config.global_exec_folder and config.global_rendering_font_override + path = "" + try: + path = globals()["_dh"][0] + except KeyError: + path = os.path.dirname(os.path.realpath(__file__)) + path = remove_trailing_slash(path).replace("\\", "/") + + config.global_exec_folder = path + r"/" + config.global_rendering_font_override = path + r"/" + config.global_rendering_font_override + + print_verbose(1, "Using config.global_exec_folder=" + config.global_exec_folder) + print_verbose(1, "Using config.global_raw_pdf_folder=" + config.global_raw_pdf_folder) + print_verbose(1, "Using config.global_working_folder=" + config.global_working_folder) + print_verbose(1, "Using config.global_output_folder=" + config.global_output_folder) + print_verbose(1, "Using config.global_verbosity=" + str(config.global_verbosity)) + print_verbose(5, "Using config.global_rendering_font_override=" + config.global_rendering_font_override) + + test_data = generate_dummy_test_data() + + print_big("Data-set", False) + print_verbose(1, test_data) + + info_file_contents = DataImportExport.load_info_file_contents( + remove_trailing_slash(config.global_working_folder) + "/info.json" + ) + + time_start = time.time() + + coordisresults = [] + # cur_coordisresults = analyze_pdf(config.global_raw_pdf_folder + config.global_pdf_name, config.global_page_number, config.global_text, info_file_contents, assume_conversion_done=False, force_parse_pdf=False) #analyse input data + modify_csv(config.global_csv_name, info_file_contents) + # print(cur_coordisresults) #debugging empty, so analyze_pdf method didn't work as expected + # coordisresults.extend(cur_coordisresults) + # print(coordisresults) 
#debugging empty + print_verbose(1, "RESULT FOR " + config.global_pdf_name) + print_verbose(1, coordisresults) + + time_finish = time.time() + + total_time = time_finish - time_start + + +main() diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/rb_server.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/rb_server.py index fee8a88..1f5394b 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/rb_server.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/rb_server.py @@ -17,42 +17,49 @@ def create_directory(directory_name): if os.path.isfile(file_path): os.unlink(file_path) except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) + print("Failed to delete %s. Reason: %s" % (file_path, e)) def run_rb_int(raw_pdf_folder, working_folder, output_folder, verbosity): - cmd = 'python3 /app/code/rule_based_pipeline/rule_based_pipeline/main.py' + \ - ' --raw_pdf_folder "' + raw_pdf_folder + '"' + \ - ' --working_folder "' + working_folder + '"' + \ - ' --output_folder "' + output_folder + '"' + \ - ' --verbosity ' + str(verbosity) + cmd = ( + "python3 /app/code/rule_based_pipeline/rule_based_pipeline/main.py" + + ' --raw_pdf_folder "' + + raw_pdf_folder + + '"' + + ' --working_folder "' + + working_folder + + '"' + + ' --output_folder "' + + output_folder + + '"' + + " --verbosity " + + str(verbosity) + ) print("Running command: " + cmd) os.system(cmd) def run_rb(project_name, verbosity, s3_usage, s3_settings): - base = r'/app/data/' + project_name - raw_pdf_folder = base + r'/interim/pdfs/' - working_folder = base + r'/interim/rb/work' - output_folder = base + r'/output/KPI_EXTRACTION/rb' + base = r"/app/data/" + project_name + raw_pdf_folder = base + r"/interim/pdfs/" + working_folder = base + r"/interim/rb/work" + output_folder = base + r"/output/KPI_EXTRACTION/rb" if s3_usage: s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) create_directory(base) create_directory(raw_pdf_folder) create_directory(working_folder) create_directory(output_folder) - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/pdfs/inference', - raw_pdf_folder) + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/pdfs/inference", raw_pdf_folder) run_rb_int(raw_pdf_folder, working_folder, output_folder, verbosity) if s3_usage: - s3c_main.upload_files_in_dir_to_prefix(output_folder, - project_prefix + '/output/KPI_EXTRACTION/rb') + s3c_main.upload_files_in_dir_to_prefix(output_folder, project_prefix + "/output/KPI_EXTRACTION/rb") return True @@ -64,13 +71,13 @@ def liveness(): @app.route("/run") def run(): try: - args = json.loads(request.args['payload']) - project_name = args['project_name'] + args = json.loads(request.args["payload"]) + project_name = args["project_name"] s3_settings = None - if 
args['s3_usage']: + if args["s3_usage"]: s3_settings = args["s3_settings"] - verbosity = int(args['verbosity']) - run_rb(project_name, verbosity, args['s3_usage'], s3_settings) + verbosity = int(args["verbosity"]) + run_rb(project_name, verbosity, args["s3_usage"], s3_settings) return Response(response={}, status=200) except Exception as e: m = traceback.format_exc() @@ -79,52 +86,63 @@ def run(): @app.route("/run_xy_ml") def run_xy_ml(): - args = json.loads(request.args['payload']) - project_name = args['project_name'] - base = r'/app/data/' + project_name - raw_pdf_folder = base + r'/interim/pdfs/' - working_folder = base + r'/interim/rb/work' - output_folder = base + r'/output/KPI_EXTRACTION/joined_ml_rb' - csv_path = output_folder + '/' + args['csv_name'] - if args['s3_usage']: + args = json.loads(request.args["payload"]) + project_name = args["project_name"] + base = r"/app/data/" + project_name + raw_pdf_folder = base + r"/interim/pdfs/" + working_folder = base + r"/interim/rb/work" + output_folder = base + r"/output/KPI_EXTRACTION/joined_ml_rb" + csv_path = output_folder + "/" + args["csv_name"] + if args["s3_usage"]: s3_settings = args["s3_settings"] s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) create_directory(base) create_directory(raw_pdf_folder) create_directory(working_folder) create_directory(output_folder) - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/pdfs/inference', - raw_pdf_folder) - s3c_main.download_file_from_s3(csv_path, project_prefix + '/output/KPI_EXTRACTION/joined_ml_rb', args['csv_name']) - cmd = 'python3 /app/code/rule_based_pipeline/rule_based_pipeline/main_find_xy.py' + \ - ' --raw_pdf_folder "' + raw_pdf_folder + '"' + \ - ' --working_folder "' + working_folder + '"' + \ - ' --pdf_name "' + args['pdf_name'] + '"' + \ - ' --csv_name "' + csv_path + '"' + \ - ' --output_folder "' + output_folder + '"' + \ - ' --verbosity ' + str(args['verbosity']) + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/pdfs/inference", raw_pdf_folder) + s3c_main.download_file_from_s3( + csv_path, project_prefix + "/output/KPI_EXTRACTION/joined_ml_rb", args["csv_name"] + ) + cmd = ( + "python3 /app/code/rule_based_pipeline/rule_based_pipeline/main_find_xy.py" + + ' --raw_pdf_folder "' + + raw_pdf_folder + + '"' + + ' --working_folder "' + + working_folder + + '"' + + ' --pdf_name "' + + args["pdf_name"] + + '"' + + ' --csv_name "' + + csv_path + + '"' + + ' --output_folder "' + + output_folder + + '"' + + " --verbosity " + + str(args["verbosity"]) + ) print("Running command: " + cmd) os.system(cmd) - if args['s3_usage']: - s3c_main.upload_file_to_s3(filepath=csv_path, - s3_prefix=project_prefix + '/output/KPI_EXTRACTION/joined_ml_rb', - s3_key=args['csv_name']) + if args["s3_usage"]: + s3c_main.upload_file_to_s3( + filepath=csv_path, 
s3_prefix=project_prefix + "/output/KPI_EXTRACTION/joined_ml_rb", s3_key=args["csv_name"] + ) return Response(response={}, status=200) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='rb server') + parser = argparse.ArgumentParser(description="rb server") # Add the arguments - parser.add_argument('--port', - type=int, - default=8000, - help='port to use for the infer server') + parser.add_argument("--port", type=int, default=8000, help="port to use for the infer server") args = parser.parse_args() port = args.port app.run(host="0.0.0.0", port=port) diff --git a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/test.py b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/test.py index 8a736e6..6c0bfc8 100644 --- a/data_extractor/code/rule_based_pipeline/rule_based_pipeline/test.py +++ b/data_extractor/code/rule_based_pipeline/rule_based_pipeline/test.py @@ -18,285 +18,949 @@ from TestData import * from DataImportExport import * from TestEvaluation import * -from test import * #only for testing / debugging purpose +from test import * # only for testing / debugging purpose + def test(pdf_file, wildcard): - - dir = HTMLDirectory() - dir.parse_html_directory(get_html_out_dir(pdf_file), r'page' + str(wildcard) + '.html') - dir.render_to_png(get_html_out_dir(pdf_file), get_html_out_dir(pdf_file)) - dir.save_to_dir(get_html_out_dir(pdf_file)) + dir = HTMLDirectory() + dir.parse_html_directory(get_html_out_dir(pdf_file), r"page" + str(wildcard) + ".html") + dir.render_to_png(get_html_out_dir(pdf_file), get_html_out_dir(pdf_file)) + dir.save_to_dir(get_html_out_dir(pdf_file)) def test_convert_pdf(pdf_file): - HTMLDirectory.convert_pdf_to_html(pdf_file) - - dir = HTMLDirectory() - dir.parse_html_directory(get_html_out_dir(pdf_file), r'page*.html') - dir.save_to_dir(get_html_out_dir(pdf_file)) - - -def test_load_json(pdf_file, wildcard): - - dir = HTMLDirectory() - dir.load_from_dir(get_html_out_dir(pdf_file), 'jpage' + str(wildcard) + '.json') - #dir.render_to_png(get_html_out_dir(pdf_file), get_html_out_dir(pdf_file)) - - return dir - -def test_save_json(dir): - dir.save_to_dir(get_html_out_dir(pdf_file)) - - -def test_print_all_clusters(htmldir): - for p in htmldir.htmlpages: - print(p.clusters_text) - - - -# -# Only used for initial testing -# -def test_prepare_kpispecs(): - # TODO: This should be read from JSON files, but for now we can simply define it in code: - - - - - def prepare_kpi_2_0_provable_plus_probable_reserves(): - # KPI 2.0 = proved plus probable reserves - - kpi = KPISpecs() - kpi.kpi_id = 2.0 - kpi.kpi_name = 'Proven or probable reserves (Total hydrocarbons)' - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*prov.*develop.*undevelop.*reserv.*',score=12000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('proven develop or undevelop reserv'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*prov.*reserv.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('proven reserves'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*hydrocarbon.*', score=15000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('total hydrocarbon'))) - 
kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=2500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0.9, letter_decay_hl = 10, letter_decay_disregard = len('total'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*(prov|prob).*reserv', score=5000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('total proven reserves'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*reserv', score=4000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('total reserves'))) - - #TODO: Check if we should add P50, like for KPI 2.1! - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*2p[\s]*reserv.*', score=4000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = len('2p reserves'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*prov.*probab.*', score=3000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = len('proved and proable'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='(.*(prov.*prob|prob.*prov).*|^((?!pro(b|v)).)*$)', score=1, matching_mode=MATCHING_MUST_INCLUDE_EACH_NODE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = True)) - - - - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(boe|barrel.*oil|(b|m)illion.*barrel).*',case_sensitive=False)) - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*exploration.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('exploration'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*upstream.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('upstream'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='^balance sheet.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MUST_EXCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = 0)) - - #kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*(prov|prob).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 8)) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*((^|[^a-z])prov|(^|[^a-z])2p($|[^a-z])).*',case_sensitive=False), 
distance_mode = DISTANCE_MOD_EUCLID, score = 500, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('prov'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*((^|[^a-z])prob|(^|[^a-z])2p($|[^a-z])).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 500, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('prob'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*((^|[^a-z])pro(b|v).*(^|[^a-z])pro(b|v)|(^|[^a-z])2p($|[^a-z])).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2500, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 10, letter_decay_disregard = len('prov prob'))) - - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - return kpi - - - def prepare_kpi_2_1_provable_reserves(): - # KPI 2.1 = proved reserves - kpi = KPISpecs() - kpi.kpi_id = 2.1 - kpi.kpi_name = 'Proven reserves (Total hydrocarbons)' - - # - # TODO : Add kpi description here (similar to the procedure above!) - # - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])prov.*develop.*undevelop.*reserv.*',score=12000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('proven develop or undevelop reserv'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])prov.*reserv.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('proven reserves'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])prov.*reserv.*oil.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])prov.*reserv.*(oil.*gas|gas.*oil).*',score=15000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*hydrocarbon.*', score=1500, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('total hydrocarbon'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0.9, letter_decay_hl = 10, letter_decay_disregard = len('total'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*prov.*reserv', score=5000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('total proven reserves'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*reserv', score=4000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, 
multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('total reserves'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*1p[\s]*(reserv|.*p90).*', score=4000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = len('1p reserves'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])prov.*', score=3000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = len('proved and proable'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])sec.*',score=1000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('proven reserves sec'))) - - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='^((?!prob).)*$', score=1, matching_mode=MATCHING_MUST_INCLUDE_EACH_NODE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = True)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*prms.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = True)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*change.*(pro|reserv).*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - - - - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(boe|barrel.*oil|(b|m)illion.*barrel).*',case_sensitive=False)) - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*exploration.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('exploration'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*upstream.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('upstream'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='^balance sheet.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MUST_EXCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = 0)) + HTMLDirectory.convert_pdf_to_html(pdf_file) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*((^|[^a-z])prov|(^|[^a-z])1p($|[^a-z])).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 500, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = 
len('prov'))) - - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - return kpi + dir = HTMLDirectory() + dir.parse_html_directory(get_html_out_dir(pdf_file), r"page*.html") + dir.save_to_dir(get_html_out_dir(pdf_file)) - - - def prepare_kpi_3_production(): - # KPI 3 = production - kpi = KPISpecs() - kpi.kpi_id = 3 - kpi.kpi_name = 'Production' - - # - # TODO : Add kpi description here (similar to the procedure above!) - # - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*total.*production.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*hydrocarbon.*production.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('hydrocarbon production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*oil.*production.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('oil production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*interest.*production.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('interest production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*group.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('group production'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*apr.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*may.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*crease.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*change.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0.9, letter_decay_hl = 10, letter_decay_disregard = len('total'))) - - - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(boe|barrel.*oil|(b|m)illion.*barrel|tonnes).*',case_sensitive=False)) - - - kpi.value_must_be_numeric = True - 
#kpi.value_percentage_match = VALUE_PERCENTAGE_MUST_NOT - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*production.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('production'))) - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi - - - - def prepare_kpi_3_1_oil_production(): - # KPI 3 = production - kpi = KPISpecs() - kpi.kpi_id = 3.1 - kpi.kpi_name = 'Oil Production' - - # - # TODO : Add kpi description here (similar to the procedure above!) - # - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*production.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*oil.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('oil production'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0.9, letter_decay_hl = 10, letter_decay_disregard = len('total'))) - - - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(boe|barrel.*oil|(b|m)illion.*barrel|tonnes).*',case_sensitive=False)) - - - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*production.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('production'))) - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi +def test_load_json(pdf_file, wildcard): + dir = HTMLDirectory() + dir.load_from_dir(get_html_out_dir(pdf_file), "jpage" + str(wildcard) + ".json") + # dir.render_to_png(get_html_out_dir(pdf_file), get_html_out_dir(pdf_file)) + return dir - def prepare_kpi_3_2_liquid_hydrocarbons_production(): - # KPI 3 = production - kpi = KPISpecs() - kpi.kpi_id = 3.2 - kpi.kpi_name = 'Liquid Hydrocarbons Production' - - # - # TODO : Add kpi description here (similar to the procedure above!) 
- # - - #kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*production.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*liquid.*hydrocarbon.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('liquid hydrocarbon'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0.9, letter_decay_hl = 10, letter_decay_disregard = len('total'))) - - - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(boe|barrel.*oil|(b|m)illion.*barrel|ton|mt).*',case_sensitive=False)) - - - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*production.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('production'))) - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi +def test_save_json(dir): + dir.save_to_dir(get_html_out_dir(pdf_file)) +def test_print_all_clusters(htmldir): + for p in htmldir.htmlpages: + print(p.clusters_text) - def prepare_kpi_3_3_gas_production(): - # KPI 3 = production - kpi = KPISpecs() - kpi.kpi_id = 3.3 - kpi.kpi_name = 'Gas Production' - - # - # TODO : Add kpi description here (similar to the procedure above!) 
- # - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*production.*of.*gas.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('production gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*gas.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*production.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('production'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*gas.*production.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl=20, letter_decay_disregard = len('gas production'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0.9, letter_decay_hl = 10, letter_decay_disregard = len('total'))) - - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*for.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*emission.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - - - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(boe|barrel.*oil|(b|m)illion.*barrel|ton|mt|million|cm).*',case_sensitive=False)) - - - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*production.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('production'))) - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi - def prepare_Scope1_kpi_6_Direct_total_GHG_emissions(): - # KPI 6 = Scope 1 / Direct total GHGs emissions - kpi = KPISpecs() - kpi.kpi_id = 6 - kpi.kpi_name = 'Scope 1 / Direct total GHGs emissions' - - # Match paragraphs - - ''' +# +# Only used for initial testing +# +def test_prepare_kpispecs(): + # TODO: This should be read from JSON files, but for now we can simply define it in code: + + def prepare_kpi_2_0_provable_plus_probable_reserves(): + # KPI 2.0 = proved plus probable reserves + + kpi = KPISpecs() + kpi.kpi_id = 2.0 + kpi.kpi_name = "Proven or probable reserves (Total hydrocarbons)" + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*prov.*develop.*undevelop.*reserv.*", + score=12000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("proven develop or undevelop reserv"), + ) + ) + 
kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*prov.*reserv.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("proven reserves"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*hydrocarbon.*", + score=15000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("total hydrocarbon"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=2500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0.9, + letter_decay_hl=10, + letter_decay_disregard=len("total"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*(prov|prob).*reserv", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("total proven reserves"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*reserv", + score=4000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("total reserves"), + ) + ) + + # TODO: Check if we should add P50, like for KPI 2.1! + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*2p[\s]*reserv.*", + score=4000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=len("2p reserves"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*prov.*probab.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=len("proved and proable"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw="(.*(prov.*prob|prob.*prov).*|^((?!pro(b|v)).)*$)", + score=1, + matching_mode=MATCHING_MUST_INCLUDE_EACH_NODE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=True, + ) + ) + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw=".*(boe|barrel.*oil|(b|m)illion.*barrel).*", case_sensitive=False) + ) + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*exploration.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=2000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("exploration"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*upstream.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=2000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("upstream"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + 
general_match=KPISpecs.GeneralRegExMatch(pattern_raw="^balance sheet.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=2000, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=0, + ) + ) + + # kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*(prov|prob).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 2000, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 8)) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*((^|[^a-z])prov|(^|[^a-z])2p($|[^a-z])).*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID, + score=500, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("prov"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*((^|[^a-z])prob|(^|[^a-z])2p($|[^a-z])).*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID, + score=500, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("prob"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*((^|[^a-z])pro(b|v).*(^|[^a-z])pro(b|v)|(^|[^a-z])2p($|[^a-z])).*", + case_sensitive=False, + ), + distance_mode=DISTANCE_MOD_EUCLID, + score=2500, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=10, + letter_decay_disregard=len("prov prob"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_2_1_provable_reserves(): + # KPI 2.1 = proved reserves + kpi = KPISpecs() + kpi.kpi_id = 2.1 + kpi.kpi_name = "Proven reserves (Total hydrocarbons)" + + # + # TODO : Add kpi description here (similar to the procedure above!) 
+ # + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])prov.*develop.*undevelop.*reserv.*", + score=12000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("proven develop or undevelop reserv"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])prov.*reserv.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("proven reserves"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])prov.*reserv.*oil.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])prov.*reserv.*(oil.*gas|gas.*oil).*", + score=15000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*hydrocarbon.*", + score=1500, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("total hydrocarbon"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0.9, + letter_decay_hl=10, + letter_decay_disregard=len("total"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*prov.*reserv", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("total proven reserves"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*reserv", + score=4000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("total reserves"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*1p[\s]*(reserv|.*p90).*", + score=4000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=len("1p reserves"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])prov.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=len("proved and proable"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])sec.*", + score=1000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("proven reserves sec"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw="^((?!prob).)*$", + score=1, + 
matching_mode=MATCHING_MUST_INCLUDE_EACH_NODE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=True, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*prms.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=True, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*change.*(pro|reserv).*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw=".*(boe|barrel.*oil|(b|m)illion.*barrel).*", case_sensitive=False) + ) + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*exploration.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=2000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("exploration"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*upstream.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=2000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("upstream"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw="^balance sheet.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=2000, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=0, + ) + ) + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*((^|[^a-z])prov|(^|[^a-z])1p($|[^a-z])).*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID, + score=500, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("prov"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_3_production(): + # KPI 3 = production + kpi = KPISpecs() + kpi.kpi_id = 3 + kpi.kpi_name = "Production" + + # + # TODO : Add kpi description here (similar to the procedure above!) 
+ # + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*total.*production.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("total production"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*hydrocarbon.*production.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("hydrocarbon production"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*oil.*production.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("oil production"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*interest.*production.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("interest production"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*group.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("group production"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*apr.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*may.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*crease.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*change.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0.9, + letter_decay_hl=10, + letter_decay_disregard=len("total"), + ) + ) + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch( + pattern_raw=".*(boe|barrel.*oil|(b|m)illion.*barrel|tonnes).*", case_sensitive=False + ) + ) + + kpi.value_must_be_numeric = True + # kpi.value_percentage_match = VALUE_PERCENTAGE_MUST_NOT + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*production.*", case_sensitive=False), + 
distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("production"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_3_1_oil_production(): + # KPI 3 = production + kpi = KPISpecs() + kpi.kpi_id = 3.1 + kpi.kpi_name = "Oil Production" + + # + # TODO : Add kpi description here (similar to the procedure above!) + # + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*production.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("production"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*oil.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("oil production"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0.9, + letter_decay_hl=10, + letter_decay_disregard=len("total"), + ) + ) + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch( + pattern_raw=".*(boe|barrel.*oil|(b|m)illion.*barrel|tonnes).*", case_sensitive=False + ) + ) + + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*production.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("production"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_3_2_liquid_hydrocarbons_production(): + # KPI 3 = production + kpi = KPISpecs() + kpi.kpi_id = 3.2 + kpi.kpi_name = "Liquid Hydrocarbons Production" + + # + # TODO : Add kpi description here (similar to the procedure above!) 
+ # + + # kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*production.*',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('production'))) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*liquid.*hydrocarbon.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("liquid hydrocarbon"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0.9, + letter_decay_hl=10, + letter_decay_disregard=len("total"), + ) + ) + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch( + pattern_raw=".*(boe|barrel.*oil|(b|m)illion.*barrel|ton|mt).*", case_sensitive=False + ) + ) + + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*production.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("production"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_3_3_gas_production(): + # KPI 3 = production + kpi = KPISpecs() + kpi.kpi_id = 3.3 + kpi.kpi_name = "Gas Production" + + # + # TODO : Add kpi description here (similar to the procedure above!) + # + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*production.*of.*gas.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("production gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*gas.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*production.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("production"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*gas.*production.*", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("gas production"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0.9, + letter_decay_hl=10, + letter_decay_disregard=len("total"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*for.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( 
+ KPISpecs.DescRegExMatch( + pattern_raw=".*emission.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch( + pattern_raw=".*(boe|barrel.*oil|(b|m)illion.*barrel|ton|mt|million|cm).*", case_sensitive=False + ) + ) + + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*production.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("production"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_Scope1_kpi_6_Direct_total_GHG_emissions(): + # KPI 6 = Scope 1 / Direct total GHGs emissions + kpi = KPISpecs() + kpi.kpi_id = 6 + kpi.kpi_name = "Scope 1 / Direct total GHGs emissions" + + # Match paragraphs + + """ kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(ghg)\semission[s]?.*',score=2000 ,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(GHG) emissions'))) kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(gas|(gas).*emissions?))',score=2000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(GHG) emissions'))) @@ -308,461 +972,1168 @@ def prepare_Scope1_kpi_6_Direct_total_GHG_emissions(): kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.\d+\s*(thousand|hundred|million).*(gas|ghg|(ghg)).*emissions.*?(totaled|summed)? 
',score=12000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(GHG) emissions'))) kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*\s*(thousand|hundred|million)?.*(gas|ghg|(ghg)).*emissions.*(2017)?(totaled|summed|to)?\d+',score=8000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(GHG) emissions'))) - ''' - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric|co2).*emissions?.*',score=7000 ,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('gas emissions'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])co2.*emissions?.*tCO2e.*',score=10000 ,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('co2 emissions'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*scope[^a-zA-Z0-9]?1.*',score=12000 ,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('scope 1'))) - - - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric|co2|combustion.*fuels?).*emissions?.*',score=7000 ,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('gas emissions'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])combustion.*fuels?.*',score=6000 ,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('gas emissions'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric).*(direct)(emissions?)?.*', score=9000, matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 15, letter_decay_disregard = len('gas direct'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric).*direct.*scope[^a-zA-Z0-9]?1(emissions?)?.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('direct scope 1'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*direct.*(gas|ghg|(ghg)|atmospheric).*scope[^a-zA-Z0-9]?1(emissions?)?.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('direct scope 1'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*scope\s1.*',score=6000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - - 
kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric).*direct.*scope[^a-zA-Z0-9]?1(emissions?)?.*m.*t',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*scope[^a-zA-Z0-9]?2.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - #kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*group.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(((gas|ghg|(ghg)|atmospheric)|direct).*emissions?|scope[^a-zA-Z0-9]?1).*(million\s? tonnes|co2[^a-zA-Z0-9]?eq)',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])((gas|ghg|(ghg)|atmospheric)|direct).*(million\s? tonnes|co2[^a-zA-Z0-9]?eq).*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])direct.*(gas|ghg|(ghg)).*(million\s? 
tonnes|co2[^a-zA-Z0-9]?(eq|equivalent)).*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(scope[^a-zA-Z0-9]?2|scope[^a-zA-Z0-9]?3).*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*scope[^a-zA-Z0-9]?1,?[^a-zA-Z0-9]?2.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*scope[^a-zA-Z0-9]?1.*relative.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - - - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*direct)(?=.*scope[^a-zA-Z0-9]?1).*$(emissions?)?.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*direct)(?=.*operated).*$(emissions?)?.*',score=6000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*direct).*$(emissions?)?.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*scope[^a-zA-Z0-9]?1).*$(emissions?)?.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('proven reserves of oil and gas'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(^|[^a-z])direct.*emissions?.*',score=12000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 30, letter_decay_disregard = len('direct emissions'))) - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=800, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 10, letter_decay_disregard = len('total indirect ghg scope-2'))) - - - ### .*Direct NO.*emissions.* - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*direct no.*emissions.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, 
case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(ton|mn|million|kt|m t|co 2|co.*emission).*',case_sensitive=False)) - - #kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*((b|m)illions?.*(tons?|tonnes?).*CO2\s*equivalent|MteCO2e|mil\s*t\s*eq|(b|m)illions?\/teq|emissions.*CO2\s*equivalent).*',case_sensitive=False)) - kpi.value_must_be_numeric=True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*(direct|ghg|gas).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 500, matching_mode = MATCHING_CAN_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('direct'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*scope[^a-zA-Z0-9]?1.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 500, matching_mode = MATCHING_CAN_INCLUDE, score_decay = 0.7, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('direct'))) - - - # added in particular for CDP reports: - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*gross global scope 1 emissions.*metric.*ton.*',score=20000 ,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('gross global scope 1 emissions metric ton'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*c6\.1.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 5000, matching_mode = MATCHING_CAN_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('c6.1'))) - kpi.value_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*[0-9].*[0-9].*',case_sensitive=False)) # must contain at least two digits - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi - - + """ + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric|co2).*emissions?.*", + score=7000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("gas emissions"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])co2.*emissions?.*tCO2e.*", + score=10000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("co2 emissions"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*scope[^a-zA-Z0-9]?1.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("scope 1"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric|co2|combustion.*fuels?).*emissions?.*", + score=7000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + 
letter_decay_disregard=len("gas emissions"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])combustion.*fuels?.*", + score=6000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("gas emissions"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric).*(direct)(emissions?)?.*", + score=9000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=15, + letter_decay_disregard=len("gas direct"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric).*direct.*scope[^a-zA-Z0-9]?1(emissions?)?.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("direct scope 1"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*direct.*(gas|ghg|(ghg)|atmospheric).*scope[^a-zA-Z0-9]?1(emissions?)?.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("direct scope 1"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*scope\s1.*", + score=6000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*(gas|ghg|(ghg)|atmospheric).*direct.*scope[^a-zA-Z0-9]?1(emissions?)?.*m.*t", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*scope[^a-zA-Z0-9]?2.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + # kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*group.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(((gas|ghg|(ghg)|atmospheric)|direct).*emissions?|scope[^a-zA-Z0-9]?1).*(million\s? tonnes|co2[^a-zA-Z0-9]?eq)", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])((gas|ghg|(ghg)|atmospheric)|direct).*(million\s? 
tonnes|co2[^a-zA-Z0-9]?eq).*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])direct.*(gas|ghg|(ghg)).*(million\s? tonnes|co2[^a-zA-Z0-9]?(eq|equivalent)).*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(scope[^a-zA-Z0-9]?2|scope[^a-zA-Z0-9]?3).*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*scope[^a-zA-Z0-9]?1,?[^a-zA-Z0-9]?2.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*scope[^a-zA-Z0-9]?1.*relative.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*direct)(?=.*scope[^a-zA-Z0-9]?1).*$(emissions?)?.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*direct)(?=.*operated).*$(emissions?)?.*", + score=6000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*direct).*$(emissions?)?.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])(greenhouse)?.*^(?=.*(gas|ghg|(ghg)|atmospheric))(?=.*scope[^a-zA-Z0-9]?1).*$(emissions?)?.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=30, + letter_decay_disregard=len("proven reserves of oil and gas"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(^|[^a-z])direct.*emissions?.*", + score=12000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + 
letter_decay_hl=30, + letter_decay_disregard=len("direct emissions"), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=800, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=10, + letter_decay_disregard=len("total indirect ghg scope-2"), + ) + ) + + ### .*Direct NO.*emissions.* + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*direct no.*emissions.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch( + pattern_raw=".*(ton|mn|million|kt|m t|co 2|co.*emission).*", case_sensitive=False + ) + ) + + # kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*((b|m)illions?.*(tons?|tonnes?).*CO2\s*equivalent|MteCO2e|mil\s*t\s*eq|(b|m)illions?\/teq|emissions.*CO2\s*equivalent).*',case_sensitive=False)) + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*(direct|ghg|gas).*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("direct"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*scope[^a-zA-Z0-9]?1.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.7, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("direct"), + ) + ) + + # added in particular for CDP reports: + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*gross global scope 1 emissions.*metric.*ton.*", + score=20000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("gross global scope 1 emissions metric ton"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*c6\.1.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=5000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("c6.1"), + ) + ) + kpi.value_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw=".*[0-9].*[0-9].*", case_sensitive=False) + ) # must contain at least two digits + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_7_Scope2_GHGs_emissions(): + # KPI 7 = Scope 2 Energy indirect total GHGs emissions + kpi = KPISpecs() + kpi.kpi_id = 7 + kpi.kpi_name = "Scope 2 Energy indirect total GHGs emissions" + + # + # TODO : Add kpi description here (similar to the procedure above!) 
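Note: the scoring knobs that recur in every prepare_* helper (score, matching_mode, score_decay, multi_match_decay, letter_decay_hl, letter_decay_disregard) are interpreted by the KPISpecs and analyzer classes of this package. The snippet below is only a rough, assumed illustration of how a letter-decay weighted description score could behave, not the package's actual formula:

import re

def sketch_desc_score(text, pattern, base_score, letter_decay_hl, letter_decay_disregard):
    # Assumed behaviour, for illustration only: a matching pattern contributes its
    # base score, damped exponentially once the text is longer than the "disregard"
    # length, with a half-life of letter_decay_hl characters.
    if not re.match(pattern, text, re.IGNORECASE):
        return 0.0
    excess = max(0, len(text) - letter_decay_disregard)
    return base_score * 0.5 ** (excess / letter_decay_hl)

# sketch_desc_score("Direct GHG emissions (Scope 1)", r".*scope[^a-zA-Z0-9]?1.*", 12000, 20, len("scope 1"))
# -> close to 12000 for a short heading, noticeably less for a long paragraph.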
+ # + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*s.*cope( |-)2.*market", + score=5000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("scope-2 market "), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*s.*cope( |-)2.*", + score=8000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-2 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*indirect.*ghg.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-2 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*ghg.*", + score=3000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-2 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*indirect.*", + score=3000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-2 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*co 2.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("CO2"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*market.*", + score=3000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope 2 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*electricity.*", + score=500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-2"), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*indirect.*emissions.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-2 "), + ) + ) # by Lei + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=10, + letter_decay_disregard=len("total indirect ghg scope-2"), + ) + ) + + # kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(ton|mn|million|kt|m t|co 2).*', case_sensitive=False)) #by Lei + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*sale.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + 
KPISpecs.DescRegExMatch( + pattern_raw=".*s.*cope( |-)3.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(upstream|refin).*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + + # eq does not work + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw="^(t|.*(ton|mn|million|kt|m t|co 2).*)$", case_sensitive=False) + ) + + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*(environment|emission).*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("production"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*total.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("production"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*tons.*co.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.05, + multi_match_decay=0.01, + letter_decay_hl=5, + letter_decay_disregard=len("tons co2"), + ) + ) # by Lei + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*million metric.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=200, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.05, + multi_match_decay=0.01, + letter_decay_hl=5, + letter_decay_disregard=len("million metric"), + ) + ) # by Lei + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_8_Scope3_GHGs_emissions(): + # KPI 8 = Scope 3 Upstream Energy indirect total GHGs emissions + kpi = KPISpecs() + kpi.kpi_id = 8 + kpi.kpi_name = "Scope 3 Upstream Energy indirect total GHGs emissions" + + # + # TODO : Add kpi description here (similar to the procedure above!) 
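Note: the unit_regex_match_list entries appear to act as filters on the unit text attached to a candidate value. A quick standalone way to sanity-check a pattern such as the Scope 2 unit regex above against typical unit strings (plain re, with IGNORECASE standing in for case_sensitive=False):

import re

UNIT_PATTERN = re.compile(r"^(t|.*(ton|mn|million|kt|m t|co 2).*)$", re.IGNORECASE)

for unit in ["t", "kt CO2e", "million tonnes", "MWh", "%"]:
    print(unit, "->", bool(UNIT_PATTERN.match(unit)))
# The first three are accepted; "MWh" and "%" are rejected.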
+ # + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*s.*cope( |-)3.*", + score=8000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-3 "), + ) + ) # by Lei + # kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)3.*',score=8000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*ghg.*", + score=3000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-3 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*indirect.*", + score=3000, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-3 "), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*emissions.*", + score=3000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=1, + case_sensitive=False, + multi_match_decay=1, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-3 "), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*(total|combine).*", + score=1500, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.8, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("total indirect ghg scope-3 "), + ) + ) + + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*intensity.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*s.*cope( |-)2.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) # by Lei + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*305.*", + score=1, + matching_mode=MATCHING_MUST_EXCLUDE, + score_decay=0, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=10, + letter_decay_disregard=0, + count_if_matched=False, + allow_matching_against_concat_txt=False, + ) + ) # by Lei + + kpi.unit_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw=".*(ton|mn|million|kt|m t|co 2).*", case_sensitive=False) + ) + + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*(environment|emission).*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID, + score=100, + matching_mode=MATCHING_MAY_INCLUDE, + score_decay=0.9, + multi_match_decay=0.5, + letter_decay_hl=5, + letter_decay_disregard=len("environment"), + ) + ) + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_12_Target_Year_Reduction(): + # KPI 12 = 
Target Year Reduction + # this works right now only for CDP reports + kpi = KPISpecs() + kpi.kpi_id = 12 + kpi.kpi_name = "Target Year Reduction" + + # Match paragraphs + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*targeted reduction from base year.*", + score=8000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("targeted reduction from base year"), + ) + ) + + kpi.value_must_be_numeric = True + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*target year.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=5000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("c6.1"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*scope 1.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=5000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("c6.1"), + ) + ) + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch(pattern_raw=".*c4\.1a.*", case_sensitive=False), + distance_mode=DISTANCE_MOD_EUCLID, + score=5000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len("c6.1"), + ) + ) + + kpi.value_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw="[0-9]{2,3}(\.[0-9]+)?", case_sensitive=False) + ) # must be 2-3 digits, optional fractional part + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_9991_CDP_0_4_Currency(): + # KPI 9991 = Currency of CDP Report + # this works right now only for CDP reports + kpi = KPISpecs() + kpi.kpi_id = 9991 + kpi.kpi_name = "C0.4 Currency of CDP Report" + + # Match paragraphs + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*\(c0\.4\) select the currency used.*", + score=8000, + matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len( + "(C0.4) Select the currency used for all financial information disclosed throughout your" + ), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*\(c0\.4\).*", + score=2000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len( + "(C0.4) Select the currency used for all financial information disclosed throughout your" + ), + ) + ) + + kpi.value_must_be_numeric = False + + kpi.value_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw="[A-Z]{3,3}", case_sensitive=True) + ) # must be 3 uppercase letters + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_9992_CDP_4_1_Emiss_Target(): + # KPI 9992 = Emission Target CDP Report + # this works right now only for CDP reports + kpi = KPISpecs() + kpi.kpi_id = 9991 + kpi.kpi_name = "C4.1 Currency of CDP Report" + + # Match paragraphs + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*\(c4\.1\) *did you have an emissions target.*", + score=8000, + 
matching_mode=MATCHING_CAN_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len( + "(C4.1) Did you have an emissions target that was active in the reporting year?" + ), + ) + ) + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*\(c4\.1\).*", + score=2000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len( + "(C4.1) Did you have an emissions target that was active in the reporting year?" + ), + ) + ) + + kpi.value_must_be_numeric = False + + kpi.value_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw="[A-Z ]{3,35}target[A-Z ]{0,10}", case_sensitive=False) + ) # must be 3 uppercase letters + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_9993_CDP_4_2b_Year_Target_Set(): + # KPI 9993 = CDP 4.b, year target set, for first target + # this is more for testing/demonstration for CDP reports, dont use it productively yet + kpi = KPISpecs() + kpi.kpi_id = 9993 + kpi.kpi_name = "C4.2b Year target set" + + # Match paragraphs + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*year target was set.*", + score=8000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("Year target was set"), + ) + ) + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*\(c4\.2b\) *provide details of .*targets.*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID_UP_ONLY, + score=5000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len( + "(C4.2b) Provide details of any other climate-related targets, including methane" + ), + ) + ) + + kpi.value_must_be_numeric = True + kpi.value_must_be_year = True + + kpi.value_regex_match_list.append( + KPISpecs.GeneralRegExMatch(pattern_raw="(19[8-9]|20[0-2])[0-9]", case_sensitive=False) + ) # must be year 1980-2029 + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + def prepare_kpi_9994_CDP_4_2b_Base_Year(): + # KPI 9993 = CDP 4.b, year target set, for first target + # this is more for testing/demonstration for CDP reports, dont use it productively yet + kpi = KPISpecs() + kpi.kpi_id = 9994 + kpi.kpi_name = "C4.2b Base year" + + # Match paragraphs + kpi.desc_regex_match_list.append( + KPISpecs.DescRegExMatch( + pattern_raw=".*base year.*", + score=8000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.1, + case_sensitive=False, + multi_match_decay=0, + letter_decay_hl=20, + letter_decay_disregard=len("Year target was set"), + ) + ) + + kpi.anywhere_regex_match_list.append( + KPISpecs.AnywhereRegExMatch( + general_match=KPISpecs.GeneralRegExMatch( + pattern_raw=".*\(c4\.2b\) *provide details of .*targets.*", case_sensitive=False + ), + distance_mode=DISTANCE_MOD_EUCLID_UP_ONLY, + score=5000, + matching_mode=MATCHING_MUST_INCLUDE, + score_decay=0.9, + multi_match_decay=0.2, + letter_decay_hl=8, + letter_decay_disregard=len( + "(C4.2b) Provide details of any other climate-related targets, including methane" + ), + ) + ) + + kpi.value_must_be_numeric = True + kpi.value_must_be_year = True + + kpi.value_regex_match_list.append( + 
KPISpecs.GeneralRegExMatch(pattern_raw="(19[8-9]|20[0-2])[0-9]", case_sensitive=False) + ) # must be year 1980-2029 + + kpi.minimum_score = 500 + kpi.minimum_score_desc_regex = 250 + + return kpi + + ### TODO: Add new KPI definitions here ! (similar to the proceudres above: def prepare_...) ### + + ### TODO: Append relevant Kpi defintions to "res" : ### + + # Note: Tested for CDP: kpi 6_1, 12, 9991 + + res = [] + res.append(prepare_kpi_2_0_provable_plus_probable_reserves()) + res.append(prepare_kpi_2_1_provable_reserves()) + # res.append(prepare_kpi_2_2_probable_reserves()) #Not yet implemented! DO not comment in!!! + res.append(prepare_kpi_3_production()) + res.append(prepare_kpi_3_1_oil_production()) + res.append(prepare_kpi_3_2_liquid_hydrocarbons_production()) + res.append(prepare_kpi_3_3_gas_production()) + res.append(prepare_Scope1_kpi_6_Direct_total_GHG_emissions()) + res.append(prepare_kpi_7_Scope2_GHGs_emissions()) + res.append(prepare_kpi_8_Scope3_GHGs_emissions()) + res.append(prepare_kpi_12_Target_Year_Reduction()) + # res.append(prepare_kpi_9991_CDP_0_4_Currency()) + # res.append(prepare_kpi_9992_CDP_4_1_Emiss_Target()) + # res.append(prepare_kpi_9993_CDP_4_2b_Year_Target_Set()) + # res.append(prepare_kpi_9994_CDP_4_2b_Base_Year()) + + return res +def load_test_data(test_data_file_path): + test_data = TestData() + test_data.load_from_csv(test_data_file_path) + # for testing purpose: + # test_data.filter_kpis(by_kpi_id = [2], by_data_type = ['TABLE']) + # test_data.filter_kpis(by_kpi_id = [2.1], by_data_type = ['TABLE']) + # test_data.filter_kpis(by_kpi_id = [3], by_data_type = ['TABLE']) + test_data.filter_kpis(by_kpi_id=[7], by_data_type=["TABLE"]) - def prepare_kpi_7_Scope2_GHGs_emissions(): - # KPI 7 = Scope 2 Energy indirect total GHGs emissions - kpi = KPISpecs() - kpi.kpi_id = 7 - kpi.kpi_name = 'Scope 2 Energy indirect total GHGs emissions' - - # - # TODO : Add kpi description here (similar to the procedure above!) 
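Note: the TODO above ("Add new KPI definitions here") asks for new entries modeled on the existing prepare_* helpers and appended to res in test_prepare_kpispecs. A minimal, hypothetical sketch of such a definition, reusing only the KPISpecs API and MATCHING_* constants already available in this module (the id 9999 and the water-withdrawal KPI are placeholders for illustration, not part of this code base):

def prepare_kpi_9999_water_withdrawal():
    # Hypothetical example following the same recipe as the helpers above.
    kpi = KPISpecs()
    kpi.kpi_id = 9999  # placeholder id
    kpi.kpi_name = "Total water withdrawal"

    # One mandatory description pattern, analogous to the Scope 1/2/3 definitions.
    kpi.desc_regex_match_list.append(
        KPISpecs.DescRegExMatch(
            pattern_raw=".*water.*withdraw.*",
            score=8000,
            matching_mode=MATCHING_MUST_INCLUDE,
            score_decay=0.8,
            case_sensitive=False,
            multi_match_decay=0,
            letter_decay_hl=20,
            letter_decay_disregard=len("total water withdrawal"),
        )
    )

    # Restrict the unit to volume-like strings.
    kpi.unit_regex_match_list.append(
        KPISpecs.GeneralRegExMatch(pattern_raw=".*(m3|megalit|million.*litre).*", case_sensitive=False)
    )

    kpi.value_must_be_numeric = True
    kpi.minimum_score = 500
    kpi.minimum_score_desc_regex = 250
    return kpi

# and then, inside test_prepare_kpispecs():
#     res.append(prepare_kpi_9999_water_withdrawal())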
- # - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)2.*market',score=5000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('scope-2 market '))) + # print("DATA-SET:") + # print(test_data) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)2.*',score=8000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-2 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*indirect.*ghg.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-2 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*ghg.*',score=3000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-2 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*indirect.*',score=3000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-2 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*co 2.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl=20, letter_decay_disregard = len('CO2'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*market.*',score=3000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl=20, letter_decay_disregard = len('total indirect ghg scope 2 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*electricity.*',score=500,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl=20, letter_decay_disregard = len('total indirect ghg scope-2'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*indirect.*emissions.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-2 '))) #by Lei + # print("PDF-Files:") + # print(test_data.get_pdf_list()) + # fix_list = DataImportExport.import_files(r"/home/ismail/Share/initial_data/Europe", config.global_raw_pdf_folder, test_data.get_pdf_list(), 'pdf') + fix_list = DataImportExport.import_files( + r"//Wwg00m.rootdom.net/afs-team/1200000089/FC/R-M/AZUREPOC/2020/KPIs extraction/Training data/03_Oil Gas sector reports/Europe", + config.global_raw_pdf_folder, + test_data.get_pdf_list(), + "pdf", + ) + # fix_list = DataImportExport.import_files(r"//Wwg00m.rootdom.net/afs-team/1200000089/FC/R-M/AZUREPOC/2020/KPIs extraction/Training data", config.global_raw_pdf_folder, test_data.get_pdf_list(), 'pdf') + test_data.fix_file_names(fix_list) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 10, letter_decay_disregard = len('total indirect ghg scope-2'))) + # filter out 
entries with no source file: + test_data.filter_kpis(by_has_fixed_source_file=True) + # print("DATA-SET:") + # print(test_data) - #kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(ton|mn|million|kt|m t|co 2).*', case_sensitive=False)) #by Lei - - - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*sale.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)3.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(upstream|refin).*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) + return test_data - #eq does not work - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='^(t|.*(ton|mn|million|kt|m t|co 2).*)$',case_sensitive=False)) - - - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*(environment|emission).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('production'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*total.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('production'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*tons.*co.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.05, multi_match_decay=0.01, letter_decay_hl = 5, letter_decay_disregard = len('tons co2'))) #by Lei - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*million metric.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 200, matching_mode = MATCHING_MAY_INCLUDE, score_decay = 0.05, multi_match_decay=0.01, letter_decay_hl = 5, letter_decay_disregard = len('million metric'))) #by Lei - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi - - - - def prepare_kpi_8_Scope3_GHGs_emissions(): - # KPI 8 = Scope 3 Upstream Energy indirect total GHGs emissions - kpi = KPISpecs() - kpi.kpi_id = 8 - kpi.kpi_name = 'Scope 3 Upstream Energy indirect total GHGs emissions' - - # - # TODO : Add kpi description here (similar to the procedure above!) 
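Note: load_test_data above resolves the PDFs against a hard-coded network share before dropping samples without a fixed source file. For local runs the same flow works against any folder containing the PDFs listed in the CSV; a short usage sketch with a placeholder local path:

test_data = TestData()
test_data.load_from_csv(r"test_data/aggregated_complete_samples_new.csv")

# Keep only one KPI and table-based samples, as in load_test_data above.
test_data.filter_kpis(by_kpi_id=[7], by_data_type=["TABLE"])

# Resolve file-name mismatches against a local folder instead of the network share.
fix_list = DataImportExport.import_files(
    r"/path/to/local/pdf/folder",  # placeholder
    config.global_raw_pdf_folder,
    test_data.get_pdf_list(),
    "pdf",
)
test_data.fix_file_names(fix_list)
test_data.filter_kpis(by_has_fixed_source_file=True)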
- # +def test_analyze_directory(htmldirectoy): + ana = AnalyzerDirectory(htmldirectoy, 2019) + kpis = test_prepare_kpispecs() + # print(kpis) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)3.*',score=8000,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) #by Lei -# kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)3.*',score=8000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*ghg.*',score=3000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*indirect.*',score=3000,matching_mode=MATCHING_MAY_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*emissions.*',score=3000,matching_mode=MATCHING_CAN_INCLUDE, score_decay=1, case_sensitive=False, multi_match_decay = 1, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) + kpiresults = KPIResultSet(ana.find_multiple_kpis(kpis)) + print_big("FINAL RESULT", do_wait=False) + print(kpiresults) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*(total|combine).*', score=1500, matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.8, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('total indirect ghg scope-3 '))) +def test_result(): + kpiresults = KPIResultSet() + print(kpiresults) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*intensity.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*s.*cope( |-)2.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) #by Lei - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*305.*', score=1, matching_mode=MATCHING_MUST_EXCLUDE, score_decay=0, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 10, letter_decay_disregard = 0, count_if_matched = False, allow_matching_against_concat_txt = False)) #by Lei +def demo(): + pdf_file = config.global_raw_pdf_folder + r"test_bp.pdf" + print_big("Convert PDF to HTML") + HTMLDirectory.convert_pdf_to_html(pdf_file) - kpi.unit_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='.*(ton|mn|million|kt|m t|co 2).*',case_sensitive=False)) - - - kpi.value_must_be_numeric = True - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*(environment|emission).*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 100, matching_mode = 
MATCHING_MAY_INCLUDE, score_decay = 0.9, multi_match_decay=0.5, letter_decay_hl = 5, letter_decay_disregard = len('environment'))) + print_big("Convert HTML to JSON and PNG") - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - - return kpi + dir = HTMLDirectory() + dir.parse_html_directory(get_html_out_dir(pdf_file), r"page*.html") + dir.save_to_dir(get_html_out_dir(pdf_file)) + dir.render_to_png(get_html_out_dir(pdf_file), get_html_out_dir(pdf_file)) + print_big("Load from JSON") + dir = None + dir = HTMLDirectory() + dir.load_from_dir(get_html_out_dir(pdf_file), r"jpage*.json") + print_big("Analyze Tables") + test_analyze_directory(dir) - def prepare_kpi_12_Target_Year_Reduction(): - # KPI 12 = Target Year Reduction - # this works right now only for CDP reports - kpi = KPISpecs() - kpi.kpi_id = 12 - kpi.kpi_name = 'Target Year Reduction' - - # Match paragraphs - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*targeted reduction from base year.*',score=8000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('targeted reduction from base year'))) - - kpi.value_must_be_numeric=True - - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*target year.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 5000, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('c6.1'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*scope 1.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 5000, matching_mode = MATCHING_CAN_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('c6.1'))) - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*c4\.1a.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID, score = 5000, matching_mode = MATCHING_CAN_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('c6.1'))) - - kpi.value_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='[0-9]{2,3}(\.[0-9]+)?',case_sensitive=False)) # must be 2-3 digits, optional fractional part - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 +def test_main(): + PDF_FILE = config.global_raw_pdf_folder + r"04_NOVATEK_AR_2016_ENG_11.pdf" - - return kpi + # test(PDF_FILE, "38") + dir = test_load_json(PDF_FILE, "*") - def prepare_kpi_9991_CDP_0_4_Currency(): - # KPI 9991 = Currency of CDP Report - # this works right now only for CDP reports - kpi = KPISpecs() - kpi.kpi_id = 9991 - kpi.kpi_name = 'C0.4 Currency of CDP Report' - - # Match paragraphs - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*\(c0\.4\) select the currency used.*',score=8000 ,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(C0.4) Select the currency used for all financial information disclosed throughout your'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*\(c0\.4\).*',score=2000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = 
len('(C0.4) Select the currency used for all financial information disclosed throughout your'))) - - kpi.value_must_be_numeric=False - - - kpi.value_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='[A-Z]{3,3}',case_sensitive=True)) # must be 3 uppercase letters - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 + test_analyze_directory(dir) - - return kpi - - - def prepare_kpi_9992_CDP_4_1_Emiss_Target(): - # KPI 9992 = Emission Target CDP Report - # this works right now only for CDP reports - kpi = KPISpecs() - kpi.kpi_id = 9991 - kpi.kpi_name = 'C4.1 Currency of CDP Report' - - # Match paragraphs - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*\(c4\.1\) *did you have an emissions target.*',score=8000 ,matching_mode=MATCHING_CAN_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(C4.1) Did you have an emissions target that was active in the reporting year?'))) - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*\(c4\.1\).*',score=2000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('(C4.1) Did you have an emissions target that was active in the reporting year?'))) - - kpi.value_must_be_numeric=False - - - kpi.value_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='[A-Z ]{3,35}target[A-Z ]{0,10}',case_sensitive=False)) # must be 3 uppercase letters - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 - - return kpi - +def test_evaluation(): + test_data = load_test_data(r"test_data/aggregated_complete_samples_new.csv") - - - def prepare_kpi_9993_CDP_4_2b_Year_Target_Set(): - # KPI 9993 = CDP 4.b, year target set, for first target - # this is more for testing/demonstration for CDP reports, dont use it productively yet - kpi = KPISpecs() - kpi.kpi_id = 9993 - kpi.kpi_name = 'C4.2b Year target set' - - # Match paragraphs - kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*year target was set.*',score=8000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('Year target was set'))) - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*\(c4\.2b\) *provide details of .*targets.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID_UP_ONLY, score = 5000, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('(C4.2b) Provide details of any other climate-related targets, including methane'))) - - kpi.value_must_be_numeric=True - kpi.value_must_be_year = True - - - kpi.value_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='(19[8-9]|20[0-2])[0-9]',case_sensitive=False)) # must be year 1980-2029 - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 + # test_data.filter_kpis(by_source_file = ['NYSE_TOT_2015 annual.pdf', 'LUKOIL_ANNUAL_REPORT_2018_ENG']) - - return kpi - - - - def prepare_kpi_9994_CDP_4_2b_Base_Year(): - # KPI 9993 = CDP 4.b, year target set, for first target - # this is more for testing/demonstration for CDP reports, dont use it productively yet - kpi = KPISpecs() - kpi.kpi_id = 9994 - kpi.kpi_name = 'C4.2b Base year' - - # Match paragraphs - 
kpi.desc_regex_match_list.append(KPISpecs.DescRegExMatch(pattern_raw='.*base year.*',score=8000 ,matching_mode=MATCHING_MUST_INCLUDE, score_decay=0.1, case_sensitive=False, multi_match_decay = 0, letter_decay_hl = 20, letter_decay_disregard = len('Year target was set'))) - - kpi.anywhere_regex_match_list.append(KPISpecs.AnywhereRegExMatch(general_match = KPISpecs.GeneralRegExMatch(pattern_raw='.*\(c4\.2b\) *provide details of .*targets.*',case_sensitive=False), distance_mode = DISTANCE_MOD_EUCLID_UP_ONLY, score = 5000, matching_mode = MATCHING_MUST_INCLUDE, score_decay = 0.9, multi_match_decay=0.2, letter_decay_hl = 8, letter_decay_disregard = len('(C4.2b) Provide details of any other climate-related targets, including methane'))) - - kpi.value_must_be_numeric=True - kpi.value_must_be_year = True - - - kpi.value_regex_match_list.append(KPISpecs.GeneralRegExMatch(pattern_raw='(19[8-9]|20[0-2])[0-9]',case_sensitive=False)) # must be year 1980-2029 - - kpi.minimum_score = 500 - kpi.minimum_score_desc_regex = 250 + test_data.filter_kpis( + by_source_file=[ + "Aker-BP-Sustainability-Report-2019.pdf" # KPIs are on pg: 84: 2009:665.1 ... 2013:575.7 + # , 'NYSE_TOT_2018 annual.pdf' # KPIs are on pg: 129: 2017:914, 2018:917 + # , 'Transocean_Sustain_digital_FN_4 2017_2018.pdf' # KPIs are on pg: 112: 2016:711.1, 2015: 498.2 + # , 'Wintershall-Dea_Sustainability_Report_2019.pdf' + ] + ) - - return kpi - - - - ### TODO: Add new KPI definitions here ! (similar to the proceudres above: def prepare_...) ### - - - - - - - - - - ### TODO: Append relevant Kpi defintions to "res" : ### - - # Note: Tested for CDP: kpi 6_1, 12, 9991 - - res = [] - res.append(prepare_kpi_2_0_provable_plus_probable_reserves()) - res.append(prepare_kpi_2_1_provable_reserves()) - #res.append(prepare_kpi_2_2_probable_reserves()) #Not yet implemented! DO not comment in!!! 
- res.append(prepare_kpi_3_production()) - res.append(prepare_kpi_3_1_oil_production()) - res.append(prepare_kpi_3_2_liquid_hydrocarbons_production()) - res.append(prepare_kpi_3_3_gas_production()) - res.append(prepare_Scope1_kpi_6_Direct_total_GHG_emissions()) - res.append(prepare_kpi_7_Scope2_GHGs_emissions()) - res.append(prepare_kpi_8_Scope3_GHGs_emissions()) - res.append(prepare_kpi_12_Target_Year_Reduction()) - #res.append(prepare_kpi_9991_CDP_0_4_Currency()) - #res.append(prepare_kpi_9992_CDP_4_1_Emiss_Target()) - #res.append(prepare_kpi_9993_CDP_4_2b_Year_Target_Set()) - #res.append(prepare_kpi_9994_CDP_4_2b_Base_Year()) - - return res - - - + print_big("Data-set", False) + print(test_data) -def load_test_data(test_data_file_path): - test_data = TestData() - test_data.load_from_csv(test_data_file_path) - - # for testing purpose: - #test_data.filter_kpis(by_kpi_id = [2], by_data_type = ['TABLE']) - #test_data.filter_kpis(by_kpi_id = [2.1], by_data_type = ['TABLE']) - #test_data.filter_kpis(by_kpi_id = [3], by_data_type = ['TABLE']) - test_data.filter_kpis(by_kpi_id = [7], by_data_type = ['TABLE']) - - - #print("DATA-SET:") - #print(test_data) - - #print("PDF-Files:") - #print(test_data.get_pdf_list()) - - #fix_list = DataImportExport.import_files(r"/home/ismail/Share/initial_data/Europe", config.global_raw_pdf_folder, test_data.get_pdf_list(), 'pdf') - fix_list = DataImportExport.import_files(r"//Wwg00m.rootdom.net/afs-team/1200000089/FC/R-M/AZUREPOC/2020/KPIs extraction/Training data/03_Oil Gas sector reports/Europe", config.global_raw_pdf_folder, test_data.get_pdf_list(), 'pdf') - #fix_list = DataImportExport.import_files(r"//Wwg00m.rootdom.net/afs-team/1200000089/FC/R-M/AZUREPOC/2020/KPIs extraction/Training data", config.global_raw_pdf_folder, test_data.get_pdf_list(), 'pdf') - test_data.fix_file_names(fix_list) - - - # filter out entries with no source file: - test_data.filter_kpis(by_has_fixed_source_file = True) - - #print("DATA-SET:") - #print(test_data) - - return test_data - - - - -def test_analyze_directory(htmldirectoy): - - ana = AnalyzerDirectory(htmldirectoy, 2019) - kpis = test_prepare_kpispecs() - #print(kpis) - - kpiresults = KPIResultSet(ana.find_multiple_kpis(kpis)) - - - print_big("FINAL RESULT", do_wait = False) - print(kpiresults) - - - - + kpiresults = KPIResultSet.load_from_file(r"test_data/kpiresults_test_all_files_against_kpi_2_0.json") -def test_result(): - kpiresults = KPIResultSet() - print(kpiresults) + print_big("Kpi-Results", do_wait=False) + print(kpiresults) - - - + print_big("Kpi-Evaluation", do_wait=False) -def demo(): + kpis = test_prepare_kpispecs() - pdf_file = config.global_raw_pdf_folder+r'test_bp.pdf' - - - print_big("Convert PDF to HTML") - HTMLDirectory.convert_pdf_to_html(pdf_file) - - - print_big("Convert HTML to JSON and PNG") - - dir = HTMLDirectory() - dir.parse_html_directory(get_html_out_dir(pdf_file), r'page*.html') - dir.save_to_dir(get_html_out_dir(pdf_file)) - dir.render_to_png(get_html_out_dir(pdf_file), get_html_out_dir(pdf_file)) - - - print_big("Load from JSON") - dir = None - dir = HTMLDirectory() - dir.load_from_dir(get_html_out_dir(pdf_file) , r'jpage*.json') - - print_big("Analyze Tables") - test_analyze_directory(dir) - - - -def test_main(): - PDF_FILE = config.global_raw_pdf_folder+r'04_NOVATEK_AR_2016_ENG_11.pdf' - - #test(PDF_FILE, "38") - - dir = test_load_json(PDF_FILE, "*") - - test_analyze_directory(dir) - - -def test_evaluation(): + test_eval = TestEvaluation.generate_evaluation(kpis, kpiresults, test_data) - 
test_data = load_test_data(r'test_data/aggregated_complete_samples_new.csv') - - - #test_data.filter_kpis(by_source_file = ['NYSE_TOT_2015 annual.pdf', 'LUKOIL_ANNUAL_REPORT_2018_ENG']) - - - test_data.filter_kpis(by_source_file = [ - 'Aker-BP-Sustainability-Report-2019.pdf' # KPIs are on pg: 84: 2009:665.1 ... 2013:575.7 - #, 'NYSE_TOT_2018 annual.pdf' # KPIs are on pg: 129: 2017:914, 2018:917 - #, 'Transocean_Sustain_digital_FN_4 2017_2018.pdf' # KPIs are on pg: 112: 2016:711.1, 2015: 498.2 - #, 'Wintershall-Dea_Sustainability_Report_2019.pdf' - ]) - - print_big("Data-set", False) - print(test_data) - - - - kpiresults = KPIResultSet.load_from_file(r'test_data/kpiresults_test_all_files_against_kpi_2_0.json') - - - - print_big("Kpi-Results", do_wait = False) - print(kpiresults) - - - print_big("Kpi-Evaluation", do_wait = False) - - kpis = test_prepare_kpispecs() - - test_eval = TestEvaluation.generate_evaluation(kpis, kpiresults, test_data) - - print(test_eval) - - - - - + print(test_eval) diff --git a/data_extractor/code/rule_based_pipeline/setup.py b/data_extractor/code/rule_based_pipeline/setup.py index 6b128f1..9a79cb8 100644 --- a/data_extractor/code/rule_based_pipeline/setup.py +++ b/data_extractor/code/rule_based_pipeline/setup.py @@ -4,34 +4,36 @@ from setuptools import find_packages, setup -NAME = 'rule_based_pipeline' -DESCRIPTION = 'Run rule-based solution' -AUTHOR = 'I. Demir' -REQUIRES_PYTHON = '>=3.6.0' +NAME = "rule_based_pipeline" +DESCRIPTION = "Run rule-based solution" +AUTHOR = "I. Demir" +REQUIRES_PYTHON = ">=3.6.0" -def list_reqs(fname='requirements.txt'): + +def list_reqs(fname="requirements.txt"): with open(fname) as fd: return fd.read().splitlines() + here = os.path.abspath(os.path.dirname(__file__)) # Load the package's __version__.py module as a dictionary. ROOT_DIR = Path(__file__).resolve().parent PACKAGE_DIR = ROOT_DIR / NAME about = {} -with open(PACKAGE_DIR / 'VERSION') as f: +with open(PACKAGE_DIR / "VERSION") as f: _version = f.read().strip() - about['__version__'] = _version + about["__version__"] = _version setup( name=NAME, - version=about['__version__'], + version=about["__version__"], description=DESCRIPTION, author=AUTHOR, python_requires=REQUIRES_PYTHON, packages=find_packages(), - package_data={'rule_based_pipeline': ['VERSION']}, + package_data={"rule_based_pipeline": ["VERSION"]}, install_requires=list_reqs(), extras_require={}, - include_package_data=True + include_package_data=True, ) diff --git a/data_extractor/code/s3_communication.py b/data_extractor/code/s3_communication.py index 82df3d3..470b008 100644 --- a/data_extractor/code/s3_communication.py +++ b/data_extractor/code/s3_communication.py @@ -23,9 +23,7 @@ class S3Communication(object): It connects with the bucket and provides methods to read and write data in parquet, csv, and json formats. 
""" - def __init__( - self, s3_endpoint_url, aws_access_key_id, aws_secret_access_key, s3_bucket - ): + def __init__(self, s3_endpoint_url, aws_access_key_id, aws_secret_access_key, s3_bucket): """Initialize communicator.""" self.s3_endpoint_url = s3_endpoint_url self.aws_access_key_id = aws_access_key_id @@ -63,9 +61,7 @@ def download_file_from_s3(self, filepath, s3_prefix, s3_key): with open(filepath, "wb") as f: f.write(buffer_bytes) - def upload_df_to_s3( - self, df, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_to_ftype_args - ): + def upload_df_to_s3(self, df, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_to_ftype_args): """ Take as input the data frame to be uploaded, and the output s3_key. @@ -79,16 +75,12 @@ def upload_df_to_s3( elif filetype == S3FileType.PARQUET: df.to_parquet(buffer, **pd_to_ftype_args) else: - raise ValueError( - f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})" - ) + raise ValueError(f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})") status = self._upload_bytes(buffer.getvalue(), s3_prefix, s3_key) return status - def download_df_from_s3( - self, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_read_ftype_args - ): + def download_df_from_s3(self, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_read_ftype_args): """Read from s3 and see if the saved data is correct.""" buffer_bytes = self._download_bytes(s3_prefix, s3_key) buffer = BytesIO(buffer_bytes) @@ -100,9 +92,7 @@ def download_df_from_s3( elif filetype == S3FileType.PARQUET: df = pd.read_parquet(buffer, **pd_read_ftype_args) else: - raise ValueError( - f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})" - ) + raise ValueError(f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})") return df def upload_files_in_dir_to_prefix(self, source_dir, s3_prefix): @@ -126,9 +116,7 @@ def download_files_in_prefix_to_dir(self, s3_prefix, destination_dir): Modified from original code here: https://stackoverflow.com/a/33350380 """ paginator = self.s3_resource.meta.client.get_paginator("list_objects") - for result in paginator.paginate( - Bucket=self.bucket, Delimiter="/", Prefix=s3_prefix - ): + for result in paginator.paginate(Bucket=self.bucket, Delimiter="/", Prefix=s3_prefix): # download all files in the sub "directory", if any if result.get("CommonPrefixes") is not None: for subdir in result.get("CommonPrefixes"): diff --git a/data_extractor/code/setup_project.py b/data_extractor/code/setup_project.py index ba9e553..3ce5d92 100644 --- a/data_extractor/code/setup_project.py +++ b/data_extractor/code/setup_project.py @@ -5,32 +5,31 @@ def main(): - parser = argparse.ArgumentParser(description='Setup new NLP project') + parser = argparse.ArgumentParser(description="Setup new NLP project") # Add the arguments - parser.add_argument('--project_name', - type=str, - default=None, - help='Name of the Project') + parser.add_argument("--project_name", type=str, default=None, help="Name of the Project") args = parser.parse_args() project_name = args.project_name if project_name is None: project_name = input("What is the project name? 
") - if(project_name is None or project_name==""): + if project_name is None or project_name == "": print("project name must not be empty") return - - os.makedirs(config_path.DATA_DIR + r'/' + project_name, exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/input', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/input/pdfs', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/input/kpi_mapping', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/input/annotations', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/input/pdfs/training', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/input/pdfs/inference', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/interim', exist_ok=True) - os.makedirs(config_path.DATA_DIR + r'/' + project_name + r'/output', exist_ok=True) - shutil.copy(config_path.DATA_DIR + r'/settings_default.yaml', config_path.DATA_DIR + r'/' + project_name + r'/settings.yaml') + + os.makedirs(config_path.DATA_DIR + r"/" + project_name, exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/input", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/input/pdfs", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/input/kpi_mapping", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/input/annotations", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/input/pdfs/training", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/input/pdfs/inference", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/interim", exist_ok=True) + os.makedirs(config_path.DATA_DIR + r"/" + project_name + r"/output", exist_ok=True) + shutil.copy( + config_path.DATA_DIR + r"/settings_default.yaml", config_path.DATA_DIR + r"/" + project_name + r"/settings.yaml" + ) if __name__ == "__main__": diff --git a/data_extractor/code/tests/conftest.py b/data_extractor/code/tests/conftest.py index 59b31f8..d843ae6 100644 --- a/data_extractor/code/tests/conftest.py +++ b/data_extractor/code/tests/conftest.py @@ -6,11 +6,12 @@ import pandas as pd import sys from tests.utils_test import project_tests_root + # add test_on_pdf.py to the PATH sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) - -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def path_folder_temporary() -> Path: """Fixture for defining path for running check @@ -19,18 +20,17 @@ def path_folder_temporary() -> Path: :yield: Path to the temporary folder :rtype: Iterator[Path] """ - path_folder_temporary_ = project_tests_root() / 'temporary_folder' + path_folder_temporary_ = project_tests_root() / "temporary_folder" # delete the temporary folder and recreate it - shutil.rmtree(path_folder_temporary_, ignore_errors = True) + shutil.rmtree(path_folder_temporary_, ignore_errors=True) path_folder_temporary_.mkdir() yield path_folder_temporary_ - + # cleanup - shutil.rmtree(path_folder_temporary_, ignore_errors = True) - - -@pytest.fixture(scope='session') + shutil.rmtree(path_folder_temporary_, ignore_errors=True) + + +@pytest.fixture(scope="session") def path_folder_root_testing() -> Path: - path_folder_data_sample_ = project_tests_root() / 'root_testing' + path_folder_data_sample_ = project_tests_root() / "root_testing" yield path_folder_data_sample_ - 
diff --git a/data_extractor/code/tests/test_train_on_pdf.py b/data_extractor/code/tests/test_train_on_pdf.py index e21590b..09e34df 100644 --- a/data_extractor/code/tests/test_train_on_pdf.py +++ b/data_extractor/code/tests/test_train_on_pdf.py @@ -20,11 +20,8 @@ @pytest.fixture(params=[()], autouse=True) def prerequisite_train_on_pdf_try_run( - request: FixtureRequest, - path_folder_root_testing: Path, - path_folder_temporary: Path, - prerequisite_running - ) -> None: + request: FixtureRequest, path_folder_root_testing: Path, path_folder_temporary: Path, prerequisite_running +) -> None: """Defines a fixture for the train_on_pdf script :param request: Request for parametrization @@ -36,61 +33,48 @@ def prerequisite_train_on_pdf_try_run( :rtype prerequisite_train_on_pdf_try_run: None """ mocked_project_settings = { - 's3_usage': False, - 's3_settings': {}, - 'general': - { - 'ext_port': 0, - 'infer_port': 0, - 'ext_ip': '0.0.0.0', - 'infer_ip': '0.0.0.0', - 'delete_interim_files': False - }, - 'train_relevance': - { - 'output_model_name': 'test', - 'train': False - }, - 'train_kpi': - { - 'output_model_name': 'test', - 'train': False - }, - 'extraction': - { - 'use_extractions': False, - 'store_extractions': False - } + "s3_usage": False, + "s3_settings": {}, + "general": { + "ext_port": 0, + "infer_port": 0, + "ext_ip": "0.0.0.0", + "infer_ip": "0.0.0.0", + "delete_interim_files": False, + }, + "train_relevance": {"output_model_name": "test", "train": False}, + "train_kpi": {"output_model_name": "test", "train": False}, + "extraction": {"use_extractions": False, "store_extractions": False}, } - + mocked_s3_settings = { - 'prefix': 'test_prefix', - 'main_bucket': { - 's3_endpoint': 'S3_END_MAIN', - 's3_access_key': 'S3_ACCESS_MAIN', - 's3_secret_key': 'S3_SECRET_MAIN', - 's3_bucket_name': 'S3_NAME_MAIN' + "prefix": "test_prefix", + "main_bucket": { + "s3_endpoint": "S3_END_MAIN", + "s3_access_key": "S3_ACCESS_MAIN", + "s3_secret_key": "S3_SECRET_MAIN", + "s3_bucket_name": "S3_NAME_MAIN", + }, + "interim_bucket": { + "s3_endpoint": "S3_END_INTERIM", + "s3_access_key": "S3_ACCESS_INTERIM", + "s3_secret_key": "S3_SECRET_INTERIM", + "s3_bucket_name": "S3_NAME_INTERIM", }, - 'interim_bucket': { - 's3_endpoint': 'S3_END_INTERIM', - 's3_access_key': 'S3_ACCESS_INTERIM', - 's3_secret_key': 'S3_SECRET_INTERIM', - 's3_bucket_name': 'S3_NAME_INTERIM' - } } - project_name = 'TEST' - path_folder_data = path_folder_temporary / 'data' - path_folder_models = path_folder_temporary / 'models' + project_name = "TEST" + path_folder_data = path_folder_temporary / "data" + path_folder_models = path_folder_temporary / "models" Path(path_folder_data / project_name).mkdir(parents=True, exist_ok=True) path_folder_models.mkdir(parents=True, exist_ok=True) - + # copy settings files to temporary folder - path_file_settings_root_testing = path_folder_root_testing / 'data' / project_name / 'settings.yaml' - path_file_settings_temporary = path_folder_temporary / 'data' / project_name / 'settings.yaml' - - path_file_settings_s3_root_testing = path_folder_root_testing / 'data' / 's3_settings.yaml' - path_file_settings_s3_temporary = path_folder_temporary / 'data' / 's3_settings.yaml' - + path_file_settings_root_testing = path_folder_root_testing / "data" / project_name / "settings.yaml" + path_file_settings_temporary = path_folder_temporary / "data" / project_name / "settings.yaml" + + path_file_settings_s3_root_testing = path_folder_root_testing / "data" / "s3_settings.yaml" + path_file_settings_s3_temporary = 
path_folder_temporary / "data" / "s3_settings.yaml" + shutil.copy(path_file_settings_root_testing, path_file_settings_temporary) shutil.copy(path_file_settings_s3_root_testing, path_file_settings_s3_temporary) @@ -100,52 +84,51 @@ def return_project_settings(*args: typing.List[Mock]): :return: Project or S3 Settings file :rtype: typing.Dict[str] """ - if 's3' in args[0].name: + if "s3" in args[0].name: return mocked_s3_settings else: return mocked_project_settings - + # modifying the project settings file via parametrization mocked_project_settings = modify_project_settings(mocked_project_settings, request.param) with ( - patch('train_on_pdf.argparse.ArgumentParser.parse_args', Mock()) as mocked_argpase, - patch('train_on_pdf.config_path', Mock()) as mocked_config_path, - patch('train_on_pdf.yaml', Mock()) as mocked_yaml, - patch('train_on_pdf.project_settings', mocked_project_settings) + patch("train_on_pdf.argparse.ArgumentParser.parse_args", Mock()) as mocked_argpase, + patch("train_on_pdf.config_path", Mock()) as mocked_config_path, + patch("train_on_pdf.yaml", Mock()) as mocked_yaml, + patch("train_on_pdf.project_settings", mocked_project_settings), ): mocked_argpase.return_value.project_name = project_name - mocked_argpase.return_value.s3_usage = 'N' + mocked_argpase.return_value.s3_usage = "N" mocked_config_path.DATA_DIR = str(path_folder_data) mocked_config_path.MODEL_DIR = str(path_folder_models) mocked_yaml.safe_load.side_effect = return_project_settings yield - + # cleanup shutil.rmtree(path_folder_temporary) - + def test_train_on_pdf_check_running(capsys: typing.Generator[CaptureFixture[str], None, None]): """Tests if everything is printed when another training is running :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None]) """ - with patch('train_on_pdf.check_running'): + with patch("train_on_pdf.check_running"): return_value = train_on_pdf.main() - + output_cmd, _ = capsys.readouterr() - string_expected = 'Another training or inference process is currently running.' + string_expected = "Another training or inference process is currently running." 
train_on_pdf.check_running.assert_called_once() assert return_value is None assert string_expected in output_cmd -@pytest.mark.parametrize('project_name', - [None, - '']) -def test_train_on_pdf_wrong_input_project_name(project_name: typing.Union[str, None], - capsys: typing.Generator[CaptureFixture[str], None, None]): +@pytest.mark.parametrize("project_name", [None, ""]) +def test_train_on_pdf_wrong_input_project_name( + project_name: typing.Union[str, None], capsys: typing.Generator[CaptureFixture[str], None, None] +): """Tests the correct behaviour of wrong given project names :param project_name: Project name @@ -153,41 +136,43 @@ def test_train_on_pdf_wrong_input_project_name(project_name: typing.Union[str, N :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None]) """ - with (patch('train_on_pdf.argparse.ArgumentParser.parse_args', Mock()) as mocked_argpase, - patch('train_on_pdf.input', Mock()) as mocked_input): + with ( + patch("train_on_pdf.argparse.ArgumentParser.parse_args", Mock()) as mocked_argpase, + patch("train_on_pdf.input", Mock()) as mocked_input, + ): mocked_argpase.return_value.project_name = project_name mocked_input.return_value = project_name - + return_value = train_on_pdf.main() - + output_cmd, _ = capsys.readouterr() - string_expected = 'project name must not be empty' + string_expected = "project name must not be empty" if project_name is None: - string_call_expected = 'What is the project name? ' + string_call_expected = "What is the project name? " mocked_input.assert_called_once() - mocked_input.assert_called_with(string_call_expected) + mocked_input.assert_called_with(string_call_expected) assert string_expected in output_cmd assert return_value is None - + def test_train_on_pdf_correct_input_project_name(): - """Tests that a correct project name is accepted - """ - with (patch('train_on_pdf.argparse.ArgumentParser.parse_args', Mock()) as mocked_argpase, - patch('train_on_pdf.input', Mock()) as mocked_input): + """Tests that a correct project name is accepted""" + with ( + patch("train_on_pdf.argparse.ArgumentParser.parse_args", Mock()) as mocked_argpase, + patch("train_on_pdf.input", Mock()) as mocked_input, + ): mocked_argpase.return_value.s3_usage = True - mocked_input.side_effect = lambda: 'TEST' - + mocked_input.side_effect = lambda: "TEST" + train_on_pdf.main() - - assert mocked_input() == 'TEST' + + assert mocked_input() == "TEST" -@pytest.mark.parametrize('s3_usage', - [None, - 'X']) -def test_train_on_pdf_wrong_input_s3(s3_usage: typing.Union[str, None], - capsys: typing.Generator[CaptureFixture[str], None, None]): +@pytest.mark.parametrize("s3_usage", [None, "X"]) +def test_train_on_pdf_wrong_input_s3( + s3_usage: typing.Union[str, None], capsys: typing.Generator[CaptureFixture[str], None, None] +): """Tests the correct behaviour of wrong s3 input is given :param s3_usage: S3 usage (yes or no) @@ -195,80 +180,82 @@ def test_train_on_pdf_wrong_input_s3(s3_usage: typing.Union[str, None], :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None]) """ - with (patch('train_on_pdf.argparse.ArgumentParser.parse_args', Mock()) as mocked_argpase, - patch('train_on_pdf.input', Mock()) as mocked_input): - mocked_argpase.return_value.project_name = 'TEST' + with ( + patch("train_on_pdf.argparse.ArgumentParser.parse_args", Mock()) as mocked_argpase, + patch("train_on_pdf.input", Mock()) as 
mocked_input, + ): + mocked_argpase.return_value.project_name = "TEST" mocked_argpase.return_value.s3_usage = s3_usage - + return_value = train_on_pdf.main() - + output_cmd, _ = capsys.readouterr() - string_expected = 'Answer to S3 usage must by Y or N. Stop program. Please restart.' + string_expected = "Answer to S3 usage must by Y or N. Stop program. Please restart." if s3_usage is None: - string_call_expected = 'Do you want to use S3? Type either Y or N.' + string_call_expected = "Do you want to use S3? Type either Y or N." mocked_input.assert_called_once() mocked_input.assert_called_with(string_call_expected) assert string_expected in output_cmd assert return_value is None - -@pytest.mark.parametrize('s3_usage', - ['Y', - 'N']) + +@pytest.mark.parametrize("s3_usage", ["Y", "N"]) def test_train_on_pdf_correct_input_s3_usage(s3_usage: typing.Union[str, None]): """Tests that the correct s3 usage is accepted :param s3_usage: S3 usage (yes or no) :type s3_usage: typing.Union[str, None] """ - with (patch('train_on_pdf.argparse.ArgumentParser.parse_args', Mock()) as mocked_argpase, - patch('train_on_pdf.input', Mock()) as mocked_input, - patch('train_on_pdf.create_directory', - side_effect=lambda *args: Path(args[0]).mkdir(parents=True, exist_ok=True)), - patch('train_on_pdf.S3Communication', Mock()) as mocked_s3_communication): - mocked_argpase.return_value.project_name = 'TEST' + with ( + patch("train_on_pdf.argparse.ArgumentParser.parse_args", Mock()) as mocked_argpase, + patch("train_on_pdf.input", Mock()) as mocked_input, + patch( + "train_on_pdf.create_directory", side_effect=lambda *args: Path(args[0]).mkdir(parents=True, exist_ok=True) + ), + patch("train_on_pdf.S3Communication", Mock()) as mocked_s3_communication, + ): + mocked_argpase.return_value.project_name = "TEST" mocked_argpase.return_value.s3_usage = None mocked_input.side_effect = lambda *args: s3_usage - + train_on_pdf.main() - + assert mocked_input() == s3_usage - if s3_usage == 'Y': + if s3_usage == "Y": assert mocked_s3_communication.call_count == 2 - + mocked_s3_communication.return_value.download_file_from_s3.assert_called_once() def test_train_on_pdf_s3_usage(): - """Tests if the s3 usage is correctly performed + """Tests if the s3 usage is correctly performed""" + project_name = "TEST" - """ - project_name = 'TEST' - - with (patch('train_on_pdf.os.getenv', Mock(side_effect=lambda *args: args[0])), - patch('train_on_pdf.argparse.ArgumentParser.parse_args', Mock()) as mocked_argpase, - patch('train_on_pdf.S3Communication', Mock()) as mocked_s3_communication, - patch('train_on_pdf.create_directory', Mock())): - + with ( + patch("train_on_pdf.os.getenv", Mock(side_effect=lambda *args: args[0])), + patch("train_on_pdf.argparse.ArgumentParser.parse_args", Mock()) as mocked_argpase, + patch("train_on_pdf.S3Communication", Mock()) as mocked_s3_communication, + patch("train_on_pdf.create_directory", Mock()), + ): mocked_argpase.return_value.project_name = project_name - mocked_argpase.return_value.s3_usage = 'Y' - + mocked_argpase.return_value.s3_usage = "Y" + train_on_pdf.main() - + mocked_s3_communication.assert_any_call( - s3_endpoint_url='S3_END_MAIN', - aws_access_key_id='S3_ACCESS_MAIN', - aws_secret_access_key='S3_SECRET_MAIN', - s3_bucket='S3_NAME_MAIN' + s3_endpoint_url="S3_END_MAIN", + aws_access_key_id="S3_ACCESS_MAIN", + aws_secret_access_key="S3_SECRET_MAIN", + s3_bucket="S3_NAME_MAIN", ) - + mocked_s3_communication.assert_any_call( - s3_endpoint_url='S3_END_INTERIM', - aws_access_key_id='S3_ACCESS_INTERIM', - 
aws_secret_access_key='S3_SECRET_INTERIM', - s3_bucket='S3_NAME_INTERIM' + s3_endpoint_url="S3_END_INTERIM", + aws_access_key_id="S3_ACCESS_INTERIM", + aws_secret_access_key="S3_SECRET_INTERIM", + s3_bucket="S3_NAME_INTERIM", ) - + mocked_s3_communication.return_value.download_file_from_s3.assert_called_once() @@ -278,81 +265,82 @@ def test_train_on_pdf_folders_default_created(path_folder_temporary: Path): :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - + paths_folders_expected = [ - r'/interim/ml', - r'/interim/pdfs/', - r'/interim/ml/annotations/', - r'/interim/kpi_mapping/', - r'/interim/ml/extraction/', - r'/interim/ml/training/', - r'/interim/ml/curation/', - r'/output/RELEVANCE/Text'] - - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: False), - patch('train_on_pdf.create_directory', Mock()) as mocked_create_directory): - + r"/interim/ml", + r"/interim/pdfs/", + r"/interim/ml/annotations/", + r"/interim/kpi_mapping/", + r"/interim/ml/extraction/", + r"/interim/ml/training/", + r"/interim/ml/curation/", + r"/output/RELEVANCE/Text", + ] + + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: False), + patch("train_on_pdf.create_directory", Mock()) as mocked_create_directory, + ): train_on_pdf.main() - + # we have to combine pathlib object with str path... - path_folder_temporary = path_folder_temporary / 'data' - path_folder_temporary = str(path_folder_temporary) + '/TEST' + path_folder_temporary = path_folder_temporary / "data" + path_folder_temporary = str(path_folder_temporary) + "/TEST" for path_current in paths_folders_expected: path_folder_current = path_folder_temporary + path_current mocked_create_directory.assert_any_call(str(path_folder_current)) - -@pytest.mark.parametrize('prerequisite_train_on_pdf_try_run', - [('train_relevance', 'train', True)], - indirect=True) + +@pytest.mark.parametrize("prerequisite_train_on_pdf_try_run", [("train_relevance", "train", True)], indirect=True) def test_train_on_pdf_folders_relevance_created(path_folder_temporary: Path): """Tests of the relevance folder is created :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: False), - patch('train_on_pdf.create_directory', Mock()) as mocked_create_directory): - + + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: False), + patch("train_on_pdf.create_directory", Mock()) as mocked_create_directory, + ): train_on_pdf.main() - + # we have to combine pathlib object with str path... 
- path_folder_temporary = path_folder_temporary / 'models' - path_folder_temporary = str(path_folder_temporary) + '/TEST' - path_folder_expected = path_folder_temporary + '/RELEVANCE/Text/test' + path_folder_temporary = path_folder_temporary / "models" + path_folder_temporary = str(path_folder_temporary) + "/TEST" + path_folder_expected = path_folder_temporary + "/RELEVANCE/Text/test" mocked_create_directory.assert_any_call(str(path_folder_expected)) - -@pytest.mark.parametrize('prerequisite_train_on_pdf_try_run', - [('train_kpi', 'train', True)], - indirect=True) + +@pytest.mark.parametrize("prerequisite_train_on_pdf_try_run", [("train_kpi", "train", True)], indirect=True) def test_train_on_pdf_folders_kpi_extraction_created(path_folder_temporary: Path): """Tests of the kpi extraction folder is created :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: False), - patch('train_on_pdf.create_directory', Mock()) as mocked_create_directory): - + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: False), + patch("train_on_pdf.create_directory", Mock()) as mocked_create_directory, + ): train_on_pdf.main() - + # we have to combine pathlib object with str path... - path_folder_temporary = path_folder_temporary / 'models' - path_folder_temporary = str(path_folder_temporary) + '/TEST' - path_folder_expected = path_folder_temporary + '/KPI_EXTRACTION/Text/test' + path_folder_temporary = path_folder_temporary / "models" + path_folder_temporary = str(path_folder_temporary) + "/TEST" + path_folder_expected = path_folder_temporary + "/KPI_EXTRACTION/Text/test" mocked_create_directory.assert_any_call(str(path_folder_expected)) - -@pytest.mark.parametrize('prerequisite_train_on_pdf_try_run', - [('extraction', 'store_extractions', True)], - indirect=True) -def test_train_on_pdf_e2e_store_extractions(path_folder_temporary: Path, - capsys: typing.Generator[CaptureFixture[str], None, None]): + +@pytest.mark.parametrize( + "prerequisite_train_on_pdf_try_run", [("extraction", "store_extractions", True)], indirect=True +) +def test_train_on_pdf_e2e_store_extractions( + path_folder_temporary: Path, capsys: typing.Generator[CaptureFixture[str], None, None] +): """Tests of the extraction works properly :param path_folder_temporary: Requesting the path_folder_temporary fixture @@ -360,55 +348,58 @@ def test_train_on_pdf_e2e_store_extractions(path_folder_temporary: Path, :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None]) """ - - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: True), - patch('train_on_pdf.save_train_info', Mock()) as mocked_save_train_info, - patch('train_on_pdf.copy_file_without_overwrite', Mock()) as mocked_copy_files, - patch('train_on_pdf.create_directory', Mock())): + + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: True), + patch("train_on_pdf.save_train_info", Mock()) as mocked_save_train_info, + patch("train_on_pdf.copy_file_without_overwrite", Mock()) as mocked_copy_files, + patch("train_on_pdf.create_directory", Mock()), + ): mocked_copy_files.return_value = False - + train_on_pdf.main() - + # we have to combine pathlib object with str path... 
- path_folder_root = path_folder_temporary / 'data' - path_folder_root_source = str(path_folder_root) + '/TEST/interim/ml/extraction/' - path_folder_root_destination = str(path_folder_root) + '/TEST/output/TEXT_EXTRACTION' + path_folder_root = path_folder_temporary / "data" + path_folder_root_source = str(path_folder_root) + "/TEST/interim/ml/extraction/" + path_folder_root_destination = str(path_folder_root) + "/TEST/output/TEXT_EXTRACTION" output_cmd, _ = capsys.readouterr() - - assert 'Finally we transfer the text extraction to the output folder\n' in output_cmd + + assert "Finally we transfer the text extraction to the output folder\n" in output_cmd mocked_copy_files.assert_called_with(path_folder_root_source, path_folder_root_destination) -@pytest.mark.parametrize('prerequisite_train_on_pdf_try_run', - [('general', 'delete_interim_files', True)], - indirect=True) +@pytest.mark.parametrize( + "prerequisite_train_on_pdf_try_run", [("general", "delete_interim_files", True)], indirect=True +) def test_train_on_pdf_e2e_delete_interim_files(path_folder_root_testing: Path): """Tests if interim files are getting deleted :param path_folder_root_testing: Requesting the path_folder_root_testing fixture :type path_folder_root_testing: Path """ - + # define the folders for getting checked paths_folders_expected = [ - r'interim/pdfs/', - r'interim/kpi_mapping/', - r'interim/ml/annotations/', - r'interim/ml/extraction/', - r'interim/ml/training/', - r'interim/ml/curation/', - ] - - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: True), - patch('train_on_pdf.save_train_info', Mock()) as mocked_save_train_info, - patch('train_on_pdf.create_directory', Mock())): - + r"interim/pdfs/", + r"interim/kpi_mapping/", + r"interim/ml/annotations/", + r"interim/ml/extraction/", + r"interim/ml/training/", + r"interim/ml/curation/", + ] + + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: True), + patch("train_on_pdf.save_train_info", Mock()) as mocked_save_train_info, + patch("train_on_pdf.create_directory", Mock()), + ): train_on_pdf.main() - + # we have to combine pathlib object with str path... 
- path_folder_root_testing = path_folder_root_testing / 'data' / 'TEST' + path_folder_root_testing = path_folder_root_testing / "data" / "TEST" for path_current in paths_folders_expected: path_folder_current = path_folder_root_testing / path_current assert not any(path_folder_current.iterdir()) @@ -420,29 +411,32 @@ def test_train_on_pdf_e2e_save_train_info(capsys: typing.Generator[CaptureFixtur :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None] """ - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: True), - patch('train_on_pdf.save_train_info', Mock()) as mocked_save_train_info, - patch('train_on_pdf.create_directory', Mock())): + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: True), + patch("train_on_pdf.save_train_info", Mock()) as mocked_save_train_info, + patch("train_on_pdf.create_directory", Mock()), + ): train_on_pdf.main() - + mocked_save_train_info.assert_called_once() output_cmd, _ = capsys.readouterr() assert output_cmd == "End-to-end inference complete\n" - + def test_train_on_pdf_process_failed(capsys: typing.Generator[CaptureFixture[str], None, None]): """Tests for cmd output if exception is raised :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None] """ - with (patch('train_on_pdf.link_files', Mock()), - patch('train_on_pdf.run_router', side_effect=lambda *args: False), - patch('train_on_pdf.link_files', side_effect=ValueError()), - patch('train_on_pdf.create_directory', lambda *args: Path(args[0]).mkdir(exist_ok=True))): - + with ( + patch("train_on_pdf.link_files", Mock()), + patch("train_on_pdf.run_router", side_effect=lambda *args: False), + patch("train_on_pdf.link_files", side_effect=ValueError()), + patch("train_on_pdf.create_directory", lambda *args: Path(args[0]).mkdir(exist_ok=True)), + ): train_on_pdf.main() - + output_cmd, _ = capsys.readouterr() - assert "Process failed to run. Reason: " in output_cmd \ No newline at end of file + assert "Process failed to run. 
Reason: " in output_cmd diff --git a/data_extractor/code/tests/test_utils/test_convert_xls_to_csv.py b/data_extractor/code/tests/test_utils/test_convert_xls_to_csv.py index ae65cee..fe9e840 100644 --- a/data_extractor/code/tests/test_utils/test_convert_xls_to_csv.py +++ b/data_extractor/code/tests/test_utils/test_convert_xls_to_csv.py @@ -17,17 +17,19 @@ def prerequisites_convert_xls_to_csv(path_folder_temporary: Path) -> None: :type path_folder_temporary: Path :rtype: None """ - path_source_annotation = path_folder_temporary / 'input' / 'pdfs' / 'training' - path_destination_annotation = path_folder_temporary / 'interim' / 'ml' / 'annotations' - path_source_annotation.mkdir(parents = True, exist_ok = True) - path_destination_annotation.mkdir(parents = True, exist_ok = True) + path_source_annotation = path_folder_temporary / "input" / "pdfs" / "training" + path_destination_annotation = path_folder_temporary / "interim" / "ml" / "annotations" + path_source_annotation.mkdir(parents=True, exist_ok=True) + path_destination_annotation.mkdir(parents=True, exist_ok=True) project_prefix = str(path_folder_temporary) - - with (patch('train_on_pdf.source_annotation', str(path_source_annotation)), - patch('train_on_pdf.destination_annotation', str(path_destination_annotation)), - patch('train_on_pdf.project_prefix', project_prefix)): + + with ( + patch("train_on_pdf.source_annotation", str(path_source_annotation)), + patch("train_on_pdf.destination_annotation", str(path_destination_annotation)), + patch("train_on_pdf.project_prefix", project_prefix), + ): yield - + # cleanup for path in path_folder_temporary.glob("*"): shutil.rmtree(path) @@ -35,22 +37,22 @@ def prerequisites_convert_xls_to_csv(path_folder_temporary: Path) -> None: def test_convert_xls_to_csv_download_s3(): """Tests the function convert_xls_to_csv for successfully downloading - files from a S3 bucket. All required variables/functions/methods are mocked by the + files from a S3 bucket. 
All required variables/functions/methods are mocked by the prerequisites_convert_xls_to_csv fixture Requesting prerequisites_convert_xls_to_csv automatically (autouse) """ - + s3_usage = True - mocked_s3c_main = Mock(spec = s3_communication.S3Communication) + mocked_s3c_main = Mock(spec=s3_communication.S3Communication) mocked_s3c_main.download_files_in_prefix_to_dir.side_effect = lambda *args: create_single_xlsx_file(Path(args[1])) - mocked_s3c_interim = Mock(spec = s3_communication.S3Communication) - + mocked_s3c_interim = Mock(spec=s3_communication.S3Communication) + convert_xls_to_csv(s3_usage, mocked_s3c_main, mocked_s3c_interim) - - mocked_s3c_main.download_files_in_prefix_to_dir.assert_called_once() - content_folder_source_annotation = list(Path(train_on_pdf.source_annotation).glob('*.xlsx')) - assert len(content_folder_source_annotation) == 1 - + + mocked_s3c_main.download_files_in_prefix_to_dir.assert_called_once() + content_folder_source_annotation = list(Path(train_on_pdf.source_annotation).glob("*.xlsx")) + assert len(content_folder_source_annotation) == 1 + def test_convert_xls_to_csv_upload_s3(): """Tests the function convert_xls_to_csv for successfully uploading @@ -58,15 +60,17 @@ def test_convert_xls_to_csv_upload_s3(): Requesting prerequisites_convert_xls_to_csv automatically (autouse) """ s3_usage = True - mocked_s3c_main = Mock(spec = s3_communication.S3Communication) + mocked_s3c_main = Mock(spec=s3_communication.S3Communication) mocked_s3c_main.download_files_in_prefix_to_dir.side_effect = lambda *args: create_single_xlsx_file(Path(args[1])) - mocked_s3c_interim = Mock(spec = s3_communication.S3Communication) - mocked_s3c_interim.upload_files_in_dir_to_prefix.side_effect = lambda *args: create_multiple_xlsx_files(Path(args[1])) - + mocked_s3c_interim = Mock(spec=s3_communication.S3Communication) + mocked_s3c_interim.upload_files_in_dir_to_prefix.side_effect = lambda *args: create_multiple_xlsx_files( + Path(args[1]) + ) + convert_xls_to_csv(s3_usage, mocked_s3c_main, mocked_s3c_interim) - + mocked_s3c_interim.upload_files_in_dir_to_prefix.assert_called_once() - + def test_convert_xls_to_csv_value_error_multiple_xls(): """Test the function convert_xls_to_csv for raising ValueError if more than one @@ -74,12 +78,14 @@ def test_convert_xls_to_csv_value_error_multiple_xls(): Requesting prerequisites_convert_xls_to_csv automatically (autouse) """ s3_usage = True - mocked_s3c_main = Mock(spec = s3_communication.S3Communication) + mocked_s3c_main = Mock(spec=s3_communication.S3Communication) # create more than one file executing mocked_s3c_main - mocked_s3c_main.download_files_in_prefix_to_dir.side_effect = lambda *args: create_multiple_xlsx_files(Path(args[1])) - mocked_s3c_interim = Mock(spec = s3_communication.S3Communication) - - with pytest.raises(ValueError, match = 'More than one excel sheet found'): + mocked_s3c_main.download_files_in_prefix_to_dir.side_effect = lambda *args: create_multiple_xlsx_files( + Path(args[1]) + ) + mocked_s3c_interim = Mock(spec=s3_communication.S3Communication) + + with pytest.raises(ValueError, match="More than one excel sheet found"): convert_xls_to_csv(s3_usage, mocked_s3c_main, mocked_s3c_interim) mocked_s3c_main.download_files_in_prefix_to_dir.assert_called_once() @@ -91,14 +97,14 @@ def test_convert_xls_to_csv_value_error_no_annotation_xls(): Requesting prerequisites_convert_xls_to_csv automatically (autouse) """ s3_usage = True - mocked_s3c_main = Mock(spec = s3_communication.S3Communication) + mocked_s3c_main = 
Mock(spec=s3_communication.S3Communication) # do not create any file mocked_s3c_main.download_files_in_prefix_to_dir.side_effect = lambda *args: None - mocked_s3c_interim = Mock(spec = s3_communication.S3Communication) - - with pytest.raises(ValueError, match = 'No annotation excel sheet found'): + mocked_s3c_interim = Mock(spec=s3_communication.S3Communication) + + with pytest.raises(ValueError, match="No annotation excel sheet found"): convert_xls_to_csv(s3_usage, mocked_s3c_main, mocked_s3c_interim) - + mocked_s3c_main.download_files_in_prefix_to_dir.assert_called_once() @@ -107,25 +113,27 @@ def test_convert_xls_to_csv_s3_usage(): Requesting prerequisites_convert_xls_to_csv automatically (autouse) """ s3_usage = True - mocked_s3c_main = Mock(spec = s3_communication.S3Communication) + mocked_s3c_main = Mock(spec=s3_communication.S3Communication) mocked_s3c_main.download_files_in_prefix_to_dir.side_effect = lambda *args: create_single_xlsx_file(Path(args[1])) - mocked_s3c_interim = Mock(spec = s3_communication.S3Communication) - mocked_s3c_interim.upload_files_in_dir_to_prefix.side_effect = lambda *args: create_multiple_xlsx_files(Path(args[1])) - + mocked_s3c_interim = Mock(spec=s3_communication.S3Communication) + mocked_s3c_interim.upload_files_in_dir_to_prefix.side_effect = lambda *args: create_multiple_xlsx_files( + Path(args[1]) + ) + convert_xls_to_csv(s3_usage, mocked_s3c_main, mocked_s3c_interim) - + mocked_s3c_interim.upload_files_in_dir_to_prefix.assert_called_once() - + def test_convert_xls_to_csv_no_s3_usage(): """Tests the function convert_xls_to_csv for not using an S3 bucket Requesting prerequisites_convert_xls_to_csv automatically (autouse) """ s3_usage = False - mocked_s3c_main = Mock(spec = s3_communication.S3Communication) - mocked_s3c_interim = Mock(spec = s3_communication.S3Communication) - - with pytest.raises(ValueError, match = 'No annotation excel sheet found'): + mocked_s3c_main = Mock(spec=s3_communication.S3Communication) + mocked_s3c_interim = Mock(spec=s3_communication.S3Communication) + + with pytest.raises(ValueError, match="No annotation excel sheet found"): convert_xls_to_csv(s3_usage, mocked_s3c_main, mocked_s3c_interim) - + mocked_s3c_interim.upload_files_in_dir_to_prefix.assert_not_called() diff --git a/data_extractor/code/tests/test_utils/test_copy_file_without_overwrite.py b/data_extractor/code/tests/test_utils/test_copy_file_without_overwrite.py index 8e14c13..0525421 100644 --- a/data_extractor/code/tests/test_utils/test_copy_file_without_overwrite.py +++ b/data_extractor/code/tests/test_utils/test_copy_file_without_overwrite.py @@ -3,7 +3,7 @@ import shutil import pytest - + @pytest.fixture(autouse=True) def prerequisites_copy_file_without_overwrite(path_folder_temporary: Path) -> None: """Defines a fixture for creating the source and destination folder @@ -12,12 +12,12 @@ def prerequisites_copy_file_without_overwrite(path_folder_temporary: Path) -> No :type path_folder_temporary: Path :rtype: None """ - path_folder_source = path_folder_temporary / 'source' - path_folder_destination = path_folder_temporary / 'destination' - path_folder_source.mkdir(parents = True) - path_folder_destination.mkdir(parents = True) + path_folder_source = path_folder_temporary / "source" + path_folder_destination = path_folder_temporary / "destination" + path_folder_source.mkdir(parents=True) + path_folder_destination.mkdir(parents=True) yield - + # cleanup for path in path_folder_temporary.glob("*"): shutil.rmtree(path) @@ -26,34 +26,34 @@ def 
prerequisites_copy_file_without_overwrite(path_folder_temporary: Path) -> No def test_copy_file_without_overwrite_result(path_folder_temporary: Path): """Tests if copy_file_without_overwrite returns True if executed Requesting prerequisites_copy_file_without_overwrite automatically (autouse) - + :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - path_folder_source = path_folder_temporary / 'source' - path_folder_destination = path_folder_temporary / 'destination' - path_folder_source_file = path_folder_source / 'test.txt' + path_folder_source = path_folder_temporary / "source" + path_folder_destination = path_folder_temporary / "destination" + path_folder_source_file = path_folder_source / "test.txt" path_folder_source_file.touch() - + result = copy_file_without_overwrite(str(path_folder_source), str(path_folder_destination)) assert result == True - - + + def test_copy_file_without_overwrite_file_not_exists(path_folder_temporary: Path): - """Tests that copy_file_without_overwrite copies the files from the source to the + """Tests that copy_file_without_overwrite copies the files from the source to the destination folder if they do no exist in the destination folder Requesting prerequisites_copy_file_without_overwrite automatically (autouse) - + :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - path_folder_source = path_folder_temporary / 'source' - path_folder_destination = path_folder_temporary / 'destination' - path_folder_source_file = path_folder_source / 'test.txt' + path_folder_source = path_folder_temporary / "source" + path_folder_destination = path_folder_temporary / "destination" + path_folder_source_file = path_folder_source / "test.txt" path_folder_source_file.touch() - - path_folder_destination_file = path_folder_destination / 'test.txt' + + path_folder_destination_file = path_folder_destination / "test.txt" assert not path_folder_destination_file.exists() - + copy_file_without_overwrite(str(path_folder_source), str(path_folder_destination)) assert path_folder_destination_file.exists() diff --git a/data_extractor/code/tests/test_utils/test_create_directory.py b/data_extractor/code/tests/test_utils/test_create_directory.py index e2e53c9..9539ca1 100644 --- a/data_extractor/code/tests/test_utils/test_create_directory.py +++ b/data_extractor/code/tests/test_utils/test_create_directory.py @@ -10,7 +10,7 @@ def test_create_directory(path_folder_temporary: Path): :type path_folder_temporary: Path """ create_directory(str(path_folder_temporary)) - + assert path_folder_temporary.exists() @@ -20,10 +20,10 @@ def test_create_directory_cleanup(path_folder_temporary: Path): :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - path_folder_temporary.mkdir(exist_ok = True) + path_folder_temporary.mkdir(exist_ok=True) for i in range(10): - path_current_test_file = path_folder_temporary / f'test_{i}.txt' + path_current_test_file = path_folder_temporary / f"test_{i}.txt" path_current_test_file.touch() - + create_directory(str(path_folder_temporary)) - assert not any(path_folder_temporary.iterdir()) \ No newline at end of file + assert not any(path_folder_temporary.iterdir()) diff --git a/data_extractor/code/tests/test_utils/test_generate_text.py b/data_extractor/code/tests/test_utils/test_generate_text.py index c998645..eebee7f 100644 --- a/data_extractor/code/tests/test_utils/test_generate_text.py +++ 
b/data_extractor/code/tests/test_utils/test_generate_text.py @@ -20,22 +20,24 @@ def prerequisites_generate_text(path_folder_temporary: Path) -> None: :type path_folder_temporary: Path :rtype: None """ - path_folder_relevance = path_folder_temporary / 'relevance' - path_folder_text_3434 = path_folder_temporary / 'folder_test_3434' - path_folder_relevance.mkdir(parents = True) - path_folder_text_3434.mkdir(parents = True) - + path_folder_relevance = path_folder_temporary / "relevance" + path_folder_text_3434 = path_folder_temporary / "folder_test_3434" + path_folder_relevance.mkdir(parents=True) + path_folder_text_3434.mkdir(parents=True) + # create multiple files in the folder_relevance with the same header for i in range(5): - path_current_file = path_folder_relevance / f'{i}_test.csv' + path_current_file = path_folder_relevance / f"{i}_test.csv" path_current_file.touch() - write_to_file(path_current_file, f'That is a test {i}', 'HEADER') - - with (patch('train_on_pdf.folder_relevance', str(path_folder_relevance)), - patch('train_on_pdf.folder_text_3434', str(path_folder_text_3434)), - patch('train_on_pdf.os.getenv', lambda *args: args[0])): + write_to_file(path_current_file, f"That is a test {i}", "HEADER") + + with ( + patch("train_on_pdf.folder_relevance", str(path_folder_relevance)), + patch("train_on_pdf.folder_text_3434", str(path_folder_text_3434)), + patch("train_on_pdf.os.getenv", lambda *args: args[0]), + ): yield - + # cleanup for path in path_folder_temporary.glob("*"): shutil.rmtree(path) @@ -49,41 +51,45 @@ def test_generate_text_with_s3(path_folder_temporary: Path): :type path_folder_temporary: Path """ # get the path to the temporary folder - path_folder_text_3434 = path_folder_temporary / 'folder_test_3434' - project_name = 'test' - + path_folder_text_3434 = path_folder_temporary / "folder_test_3434" + project_name = "test" + mocked_s3_settings = { - 'prefix': 'test_prefix', - 'main_bucket': { - 's3_endpoint': 'S3_END_MAIN', - 's3_access_key': 'S3_ACCESS_MAIN', - 's3_secret_key': 'S3_SECRET_MAIN', - 's3_bucket_name': 'S3_NAME_MAIN' + "prefix": "test_prefix", + "main_bucket": { + "s3_endpoint": "S3_END_MAIN", + "s3_access_key": "S3_ACCESS_MAIN", + "s3_secret_key": "S3_SECRET_MAIN", + "s3_bucket_name": "S3_NAME_MAIN", + }, + "interim_bucket": { + "s3_endpoint": "S3_END_INTERIM", + "s3_access_key": "S3_ACCESS_INTERIM", + "s3_secret_key": "S3_SECRET_INTERIM", + "s3_bucket_name": "S3_NAME_INTERIM", }, - 'interim_bucket': { - 's3_endpoint': 'S3_END_INTERIM', - 's3_access_key': 'S3_ACCESS_INTERIM', - 's3_secret_key': 'S3_SECRET_INTERIM', - 's3_bucket_name': 'S3_NAME_INTERIM' - } } - - with (patch('train_on_pdf.S3Communication', Mock(spec=s3_communication.S3Communication)) as mocked_s3): + + with patch("train_on_pdf.S3Communication", Mock(spec=s3_communication.S3Communication)) as mocked_s3: generate_text_3434(project_name, True, mocked_s3_settings) - + # check for calls - mocked_s3.assert_any_call(s3_endpoint_url='S3_END_MAIN', - aws_access_key_id='S3_ACCESS_MAIN', - aws_secret_access_key='S3_SECRET_MAIN', - s3_bucket='S3_NAME_MAIN') - mocked_s3.assert_any_call(s3_endpoint_url='S3_END_INTERIM', - aws_access_key_id='S3_ACCESS_INTERIM', - aws_secret_access_key='S3_SECRET_INTERIM', - s3_bucket='S3_NAME_INTERIM') - + mocked_s3.assert_any_call( + s3_endpoint_url="S3_END_MAIN", + aws_access_key_id="S3_ACCESS_MAIN", + aws_secret_access_key="S3_SECRET_MAIN", + s3_bucket="S3_NAME_MAIN", + ) + mocked_s3.assert_any_call( + s3_endpoint_url="S3_END_INTERIM", + 
aws_access_key_id="S3_ACCESS_INTERIM", + aws_secret_access_key="S3_SECRET_INTERIM", + s3_bucket="S3_NAME_INTERIM", + ) + call_list = [call[0] for call in mocked_s3.mock_calls] - assert any([call for call in call_list if 'download_files_in_prefix_to_dir' in call]) - assert any([call for call in call_list if 'upload_file_to_s3' in call]) + assert any([call for call in call_list if "download_files_in_prefix_to_dir" in call]) + assert any([call for call in call_list if "upload_file_to_s3" in call]) def test_generate_text_no_s3(path_folder_temporary: Path): @@ -96,30 +102,28 @@ def test_generate_text_no_s3(path_folder_temporary: Path): :type path_folder_temporary: Path """ # get the path to the temporary folder - path_folder_text_3434 = path_folder_temporary / 'folder_test_3434' - project_name = 'test' + path_folder_text_3434 = path_folder_temporary / "folder_test_3434" + project_name = "test" s3_usage = False project_settings = None - + generate_text_3434(project_name, s3_usage, project_settings) - - # ensure that the header and the content form the first file is written to + + # ensure that the header and the content form the first file is written to # the file text_3434.csv in folder relevance and the the content of the other # files in folder relevance is appended without the header # check if file_3434 exists - path_file_text_3434_csv = path_folder_text_3434 / 'text_3434.csv' + path_file_text_3434_csv = path_folder_text_3434 / "text_3434.csv" assert path_file_text_3434_csv.exists() - + # check if header and content of files exist - strings_expected = [ - f'That is a test {line_number}' for line_number in range(5) - ] + strings_expected = [f"That is a test {line_number}" for line_number in range(5)] - with open(str(path_file_text_3434_csv), 'r') as file_text_3434: - for line_number, line_content in enumerate(file_text_3434, start = -1): + with open(str(path_file_text_3434_csv), "r") as file_text_3434: + for line_number, line_content in enumerate(file_text_3434, start=-1): if line_number == -1: - assert line_content.rstrip() == 'HEADER' + assert line_content.rstrip() == "HEADER" else: assert line_content.rstrip() in strings_expected @@ -131,16 +135,17 @@ def test_generate_text_successful(path_folder_temporary: Path): :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - project_name = 'test' + project_name = "test" s3_usage = False project_settings = None - + return_value = generate_text_3434(project_name, s3_usage, project_settings) assert return_value == True - -def test_generate_text_not_successful_empty_folder(path_folder_temporary: Path, - capsys: typing.Generator[CaptureFixture[str], None, None]): + +def test_generate_text_not_successful_empty_folder( + path_folder_temporary: Path, capsys: typing.Generator[CaptureFixture[str], None, None] +): """Tests if the function returns false Requesting prerequisites_generate_text automatically (autouse) @@ -149,22 +154,22 @@ def test_generate_text_not_successful_empty_folder(path_folder_temporary: Path, :param capsys: Requesting default fixture for capturing cmd output :type capsys: typing.Generator[CaptureFixture[str], None, None]) """ - project_name = 'test' + project_name = "test" s3_usage = False project_settings = None - + # clear the relevance folder - path_folder_relevance = path_folder_temporary / 'relevance' - [file.unlink() for file in path_folder_relevance.glob("*") if file.is_file()] - + path_folder_relevance = path_folder_temporary / "relevance" + [file.unlink() for file in 
path_folder_relevance.glob("*") if file.is_file()] + # call the function return_value = generate_text_3434(project_name, s3_usage, project_settings) - + output_cmd, _ = capsys.readouterr() - assert 'No relevance inference results found.' in output_cmd + assert "No relevance inference results found." in output_cmd assert return_value == False - + def test_generate_text_not_successful_exception(path_folder_temporary: Path): """Tests if the function returns false Requesting prerequisites_generate_text automatically (autouse) @@ -172,16 +177,16 @@ def test_generate_text_not_successful_exception(path_folder_temporary: Path): :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - project_name = 'test' + project_name = "test" s3_usage = False project_settings = None - + # clear the relevance folder - path_folder_relevance = path_folder_temporary / 'relevance' + path_folder_relevance = path_folder_temporary / "relevance" [file.unlink() for file in path_folder_relevance.glob("*") if file.is_file()] - + # patch glob.iglob to force an exception... - with patch('train_on_pdf.glob.iglob', side_effect=lambda *args: [None]): + with patch("train_on_pdf.glob.iglob", side_effect=lambda *args: [None]): return_value = generate_text_3434(project_name, s3_usage, project_settings) - + assert return_value == False diff --git a/data_extractor/code/tests/test_utils/test_link_files.py b/data_extractor/code/tests/test_utils/test_link_files.py index 72b4944..075b305 100644 --- a/data_extractor/code/tests/test_utils/test_link_files.py +++ b/data_extractor/code/tests/test_utils/test_link_files.py @@ -12,12 +12,12 @@ def path_folders_required_linking(path_folder_temporary: Path) -> None: :type path_folder_temporary: Path :return: None """ - path_folder_source = path_folder_temporary / 'source' - path_folder_source_pdf = path_folder_temporary / 'source_pdf' - path_folder_destination = path_folder_temporary / 'destination' - path_folder_source.mkdir(parents = True) - path_folder_source_pdf.mkdir(parents = True) - path_folder_destination.mkdir(parents = True) + path_folder_source = path_folder_temporary / "source" + path_folder_source_pdf = path_folder_temporary / "source_pdf" + path_folder_destination = path_folder_temporary / "destination" + path_folder_source.mkdir(parents=True) + path_folder_source_pdf.mkdir(parents=True) + path_folder_destination.mkdir(parents=True) yield # cleanup @@ -28,72 +28,70 @@ def path_folders_required_linking(path_folder_temporary: Path) -> None: def test_link_files(path_folder_temporary: Path): """Tests if link_files creates proper hard links Requesting path_folders_required_linking automatically (autouse) - + :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - path_folder_source = path_folder_temporary / 'source' - path_folder_source_pdf = path_folder_temporary / 'source_pdf' - path_folder_destination = path_folder_temporary / 'destination' - + path_folder_source = path_folder_temporary / "source" + path_folder_source_pdf = path_folder_temporary / "source_pdf" + path_folder_destination = path_folder_temporary / "destination" + for i in range(10): - path_current_file = path_folder_source / f'test_{i}.txt' + path_current_file = path_folder_source / f"test_{i}.txt" path_current_file.touch() - + link_files(str(path_folder_source), str(path_folder_destination)) - + for i in range(10): - path_current_file = path_folder_source / f'test_{i}.txt' + path_current_file = 
path_folder_source / f"test_{i}.txt" assert path_current_file.stat().st_nlink == 2 def test_link_extracted_files_result(path_folder_temporary: Path): """Tests if link_extracted_files returns True if executed Requesting path_folders_required_linking automatically (autouse) - + :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - path_folder_source = path_folder_temporary / 'source' - path_folder_source_pdf = path_folder_temporary / 'source_pdf' - path_folder_destination = path_folder_temporary / 'destination' - - path_folder_source_file_pdf = path_folder_source / f'test.pdf' - path_folder_source_file_json = path_folder_source / f'test.json' - path_source_file_pdf = path_folder_source_pdf / f'test.pdf' - - result = link_extracted_files(str(path_folder_source), str(path_folder_source_pdf), - str(path_folder_destination)) + path_folder_source = path_folder_temporary / "source" + path_folder_source_pdf = path_folder_temporary / "source_pdf" + path_folder_destination = path_folder_temporary / "destination" + + path_folder_source_file_pdf = path_folder_source / f"test.pdf" + path_folder_source_file_json = path_folder_source / f"test.json" + path_source_file_pdf = path_folder_source_pdf / f"test.pdf" + + result = link_extracted_files(str(path_folder_source), str(path_folder_source_pdf), str(path_folder_destination)) assert result == True - - + + def test_link_extracted_files_copy(path_folder_temporary: Path): - """Tests if the extracted json files in folder_source has a regarding pdf in the folder_source_pdf + """Tests if the extracted json files in folder_source has a regarding pdf in the folder_source_pdf and if so, copy the json file to the folder_destination Requesting path_folders_required_linking automatically (autouse) - + :param path_folder_temporary: Requesting the path_folder_temporary fixture :type path_folder_temporary: Path """ - path_folder_source = path_folder_temporary / 'source' - path_folder_source_pdf = path_folder_temporary / 'source_pdf' - path_folder_destination = path_folder_temporary / 'destination' - + path_folder_source = path_folder_temporary / "source" + path_folder_source_pdf = path_folder_temporary / "source_pdf" + path_folder_destination = path_folder_temporary / "destination" + for i in range(10): - path_current_file = path_folder_source / f'test_{i}.pdf' + path_current_file = path_folder_source / f"test_{i}.pdf" path_current_file.touch() - path_current_file = path_folder_source / f'test_{i}.json' + path_current_file = path_folder_source / f"test_{i}.json" path_current_file.touch() - path_current_file = path_folder_source_pdf / f'test_{i}.pdf' + path_current_file = path_folder_source_pdf / f"test_{i}.pdf" path_current_file.touch() - + for i in range(10): - path_current_file = path_folder_destination / f'test_{i}.json' + path_current_file = path_folder_destination / f"test_{i}.json" assert not path_current_file.exists() == True - - link_extracted_files(str(path_folder_source), str(path_folder_source_pdf), - str(path_folder_destination)) - + + link_extracted_files(str(path_folder_source), str(path_folder_source_pdf), str(path_folder_destination)) + for i in range(10): - path_current_file = path_folder_destination / f'test_{i}.json' + path_current_file = path_folder_destination / f"test_{i}.json" assert path_current_file.exists() == True diff --git a/data_extractor/code/tests/test_utils/test_run_router.py b/data_extractor/code/tests/test_utils/test_run_router.py index 5bbf85e..29b3a42 100644 --- 
a/data_extractor/code/tests/test_utils/test_run_router.py +++ b/data_extractor/code/tests/test_utils/test_run_router.py @@ -15,51 +15,57 @@ @pytest.fixture -def prerequisites_run_router(prerequisites_convert_xls_to_csv, - prerequisites_generate_text - ) -> requests_mock.mocker.Mocker: +def prerequisites_run_router( + prerequisites_convert_xls_to_csv, prerequisites_generate_text +) -> requests_mock.mocker.Mocker: """Prerequisites for running the function run_router - :param prerequisites_convert_xls_to_csv: Requesting fixture for running function convert_xls_to_csv (required in + :param prerequisites_convert_xls_to_csv: Requesting fixture for running function convert_xls_to_csv (required in run_router) - :param prerequisites_generate_text: Requesting fixture for running function generate_text (required in + :param prerequisites_generate_text: Requesting fixture for running function generate_text (required in run_router) :rtype: requests_mock.mocker.Mocker """ mocked_project_settings = { - 'train_relevance': {'train': False}, - 'train_kpi': {'train': False}, - 's3_usage': None, - 's3_settings': None + "train_relevance": {"train": False}, + "train_kpi": {"train": False}, + "s3_usage": None, + "s3_settings": None, } - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_ip = '0.0.0.1' - inference_port = '8000' - - with (requests_mock.Mocker() as mocked_server, - patch('train_on_pdf.convert_xls_to_csv', Mock()), - patch('train_on_pdf.project_settings', mocked_project_settings)): - mocked_server.get(f'http://{extraction_ip}:{extraction_port}/liveness', status_code=200) - mocked_server.get(f'http://{extraction_ip}:{extraction_port}/extract', status_code=200) - mocked_server.get(f'http://{extraction_ip}:{extraction_port}/curate', status_code=200) - mocked_server.get(f'http://{inference_ip}:{inference_port}/liveness', status_code=200) - mocked_server.get(f'http://{inference_ip}:{inference_port}/train_relevance', status_code=200) - mocked_server.get(f'http://{inference_ip}:{inference_port}/infer_relevance', status_code=200) - mocked_server.get(f'http://{inference_ip}:{inference_port}/train_kpi', status_code=200) + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_ip = "0.0.0.1" + inference_port = "8000" + + with ( + requests_mock.Mocker() as mocked_server, + patch("train_on_pdf.convert_xls_to_csv", Mock()), + patch("train_on_pdf.project_settings", mocked_project_settings), + ): + mocked_server.get(f"http://{extraction_ip}:{extraction_port}/liveness", status_code=200) + mocked_server.get(f"http://{extraction_ip}:{extraction_port}/extract", status_code=200) + mocked_server.get(f"http://{extraction_ip}:{extraction_port}/curate", status_code=200) + mocked_server.get(f"http://{inference_ip}:{inference_port}/liveness", status_code=200) + mocked_server.get(f"http://{inference_ip}:{inference_port}/train_relevance", status_code=200) + mocked_server.get(f"http://{inference_ip}:{inference_port}/infer_relevance", status_code=200) + mocked_server.get(f"http://{inference_ip}:{inference_port}/train_kpi", status_code=200) yield mocked_server - - -@pytest.mark.parametrize('status_code, cmd_output_expected, return_value_expected', - [ - (200, 'Extraction server is up. 
Proceeding to extraction.', True), - (-1, 'Extraction server is not responding.', False) - ]) -def test_run_router_extraction_liveness_up(prerequisites_run_router: requests_mock.mocker.Mocker, - status_code: int, - cmd_output_expected: str, - return_value_expected: bool, - capsys: typing.Generator[CaptureFixture[str], None, None]): + + +@pytest.mark.parametrize( + "status_code, cmd_output_expected, return_value_expected", + [ + (200, "Extraction server is up. Proceeding to extraction.", True), + (-1, "Extraction server is not responding.", False), + ], +) +def test_run_router_extraction_liveness_up( + prerequisites_run_router: requests_mock.mocker.Mocker, + status_code: int, + cmd_output_expected: str, + return_value_expected: bool, + capsys: typing.Generator[CaptureFixture[str], None, None], +): """Tests the liveness of the extraction server :param prerequisites_run_router: Requesting the prerequisites_run_router fixture @@ -72,14 +78,14 @@ def test_run_router_extraction_liveness_up(prerequisites_run_router: requests_mo :type return_value_expected: bool :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None]) - """ - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_port = '8000' - project_name = 'TEST' + """ + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_port = "8000" + project_name = "TEST" mocked_server = prerequisites_run_router - - mocked_server.get(f'http://{extraction_ip}:{extraction_port}/liveness', status_code=status_code) + + mocked_server.get(f"http://{extraction_ip}:{extraction_port}/liveness", status_code=status_code) return_value = run_router(extraction_port, inference_port, project_name) cmd_output, _ = capsys.readouterr() @@ -93,13 +99,13 @@ def test_run_router_extraction_server_down(prerequisites_run_router: requests_mo :param prerequisites_run_router: Requesting the prerequisites_run_router fixture :type prerequisites_run_router: requests_mock.mocker.Mocker """ - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_port = '8000' - project_name = 'TEST' + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_port = "8000" + project_name = "TEST" mocked_server = prerequisites_run_router - - mocked_server.get(f'http://{extraction_ip}:{extraction_port}/extract', status_code=-1) + + mocked_server.get(f"http://{extraction_ip}:{extraction_port}/extract", status_code=-1) return_value = run_router(extraction_port, inference_port, project_name) assert return_value is False @@ -111,28 +117,32 @@ def test_run_router_extraction_curation_server_down(prerequisites_run_router: re :param prerequisites_run_router: Requesting the prerequisites_run_router fixture :type prerequisites_run_router: requests_mock.mocker.Mocker """ - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_port = '8000' - project_name = 'TEST' + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_port = "8000" + project_name = "TEST" mocked_server = prerequisites_run_router - - mocked_server.get(f'http://{extraction_ip}:{extraction_port}/curate', status_code=-1) + + mocked_server.get(f"http://{extraction_ip}:{extraction_port}/curate", status_code=-1) return_value = run_router(extraction_port, inference_port, project_name) assert return_value is False -@pytest.mark.parametrize('status_code, cmd_output_expected, return_value_expected', - [ - (200, 'Inference server is up. 
Proceeding to Inference.', True), - (-1, 'Inference server is not responding.', False) - ]) -def test_run_router_inference_liveness(prerequisites_run_router: requests_mock.mocker.Mocker, - status_code: int, - cmd_output_expected: str, - return_value_expected: bool, - capsys: typing.Generator[CaptureFixture[str], None, None]): +@pytest.mark.parametrize( + "status_code, cmd_output_expected, return_value_expected", + [ + (200, "Inference server is up. Proceeding to Inference.", True), + (-1, "Inference server is not responding.", False), + ], +) +def test_run_router_inference_liveness( + prerequisites_run_router: requests_mock.mocker.Mocker, + status_code: int, + cmd_output_expected: str, + return_value_expected: bool, + capsys: typing.Generator[CaptureFixture[str], None, None], +): """Tests the liveness of the inference server, up as well as down :param prerequisites_run_router: Requesting the prerequisites_run_router fixture @@ -146,34 +156,45 @@ def test_run_router_inference_liveness(prerequisites_run_router: requests_mock.m :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None] """ - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_ip = '0.0.0.1' - inference_port = '8000' - project_name = 'TEST' + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_ip = "0.0.0.1" + inference_port = "8000" + project_name = "TEST" mocked_server = prerequisites_run_router - - mocked_server.get(f'http://{inference_ip}:{inference_port}/liveness', status_code=status_code) + + mocked_server.get(f"http://{inference_ip}:{inference_port}/liveness", status_code=status_code) return_value = run_router(extraction_port, inference_port, project_name, infer_ip=inference_ip) - + cmd_output, _ = capsys.readouterr() assert cmd_output_expected in cmd_output assert return_value == return_value_expected -@pytest.mark.parametrize('train_relevance, status_code, cmd_output_expected, return_value_expected', - [ - (True, -1, "Relevance training will be started.", False), - (True, 200, "Relevance training will be started.", True), - (False, -1, ("No relevance training done. If you want to have a relevance training please " - "set variable train under train_relevance to true."), True) - ]) -def test_run_router_relevance_training(prerequisites_run_router: requests_mock.mocker.Mocker, - train_relevance: bool, - status_code: int, - cmd_output_expected: str, - return_value_expected: bool, - capsys: typing.Generator[CaptureFixture[str], None, None]): +@pytest.mark.parametrize( + "train_relevance, status_code, cmd_output_expected, return_value_expected", + [ + (True, -1, "Relevance training will be started.", False), + (True, 200, "Relevance training will be started.", True), + ( + False, + -1, + ( + "No relevance training done. If you want to have a relevance training please " + "set variable train under train_relevance to true." 
+ ), + True, + ), + ], +) +def test_run_router_relevance_training( + prerequisites_run_router: requests_mock.mocker.Mocker, + train_relevance: bool, + status_code: int, + cmd_output_expected: str, + return_value_expected: bool, + capsys: typing.Generator[CaptureFixture[str], None, None], +): """Tests if the relevance training fails and successfully starts :param prerequisites_run_router: Requesting the prerequisites_run_router fixture @@ -189,14 +210,14 @@ def test_run_router_relevance_training(prerequisites_run_router: requests_mock.m :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None] """ - extraction_port = '8000' - inference_ip = '0.0.0.1' - inference_port = '8000' - project_name = 'TEST' + extraction_port = "8000" + inference_ip = "0.0.0.1" + inference_port = "8000" + project_name = "TEST" mocked_server = prerequisites_run_router - train_on_pdf.project_settings['train_relevance']['train'] = train_relevance - - mocked_server.get(f'http://{inference_ip}:{inference_port}/train_relevance', status_code=status_code) + train_on_pdf.project_settings["train_relevance"]["train"] = train_relevance + + mocked_server.get(f"http://{inference_ip}:{inference_port}/train_relevance", status_code=status_code) return_value = run_router(extraction_port, inference_port, project_name, infer_ip=inference_ip) cmd_output, _ = capsys.readouterr() @@ -204,24 +225,37 @@ def test_run_router_relevance_training(prerequisites_run_router: requests_mock.m assert return_value == return_value_expected -@pytest.mark.parametrize('train_kpi, status_code_infer_relevance, project_name, status_code_train_kpi, cmd_output_expected, return_value_expected', - [ - (True, -1, "TEST", -1, "", False), - (True, 200, "TEST", -1, "text_3434 was generated without error", False), - (True, 200, "TEST", 200, "text_3434 was not generated without error", True), - (True, 200, None, -1, "Error while generating text_3434.", False), - (True, 200, None, 200, "Error while generating text_3434.", True), - (False, -1, None, -1, ("No kpi training done. If you want to have a kpi " - "training please set variable train under train_kpi to true."), True) - ]) -def test_run_router_kpi_training(prerequisites_run_router: requests_mock.mocker.Mocker, - train_kpi: bool, - status_code_infer_relevance: int, - project_name: typing.Union[str, None], - status_code_train_kpi: int, - cmd_output_expected: str, - return_value_expected: bool, - capsys: typing.Generator[CaptureFixture[str], None, None]): +@pytest.mark.parametrize( + "train_kpi, status_code_infer_relevance, project_name, status_code_train_kpi, cmd_output_expected, return_value_expected", + [ + (True, -1, "TEST", -1, "", False), + (True, 200, "TEST", -1, "text_3434 was generated without error", False), + (True, 200, "TEST", 200, "text_3434 was not generated without error", True), + (True, 200, None, -1, "Error while generating text_3434.", False), + (True, 200, None, 200, "Error while generating text_3434.", True), + ( + False, + -1, + None, + -1, + ( + "No kpi training done. If you want to have a kpi " + "training please set variable train under train_kpi to true." 
+ ), + True, + ), + ], +) +def test_run_router_kpi_training( + prerequisites_run_router: requests_mock.mocker.Mocker, + train_kpi: bool, + status_code_infer_relevance: int, + project_name: typing.Union[str, None], + status_code_train_kpi: int, + cmd_output_expected: str, + return_value_expected: bool, + capsys: typing.Generator[CaptureFixture[str], None, None], +): """Tests if kpi training fails and successfully starts :param prerequisites_run_router: Requesting the prerequisites_run_router fixture @@ -241,18 +275,18 @@ def test_run_router_kpi_training(prerequisites_run_router: requests_mock.mocker. :param capsys: Requesting the default fixture capsys for capturing cmd outputs :type capsys: typing.Generator[CaptureFixture[str], None, None] """ - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_ip = '0.0.0.1' - inference_port = '8000' + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_ip = "0.0.0.1" + inference_port = "8000" mocked_server = prerequisites_run_router - train_on_pdf.project_settings['train_kpi']['train'] = train_kpi - + train_on_pdf.project_settings["train_kpi"]["train"] = train_kpi + # force an exception of generate_text_3434 by removing the folder_text_3434 if not project_name: train_on_pdf.folder_text_3434 = None - - mocked_generate_text = Mock() + + mocked_generate_text = Mock() if project_name: if status_code_train_kpi < 0: mocked_generate_text.side_effect = lambda *args: True @@ -260,10 +294,12 @@ def test_run_router_kpi_training(prerequisites_run_router: requests_mock.mocker. mocked_generate_text.side_effect = lambda *args: False else: mocked_generate_text.side_effect = Exception() - - with patch('train_on_pdf.generate_text_3434', mocked_generate_text): - mocked_server.get(f'http://{inference_ip}:{inference_port}/infer_relevance', status_code=status_code_infer_relevance) - mocked_server.get(f'http://{inference_ip}:{inference_port}/train_kpi', status_code=status_code_train_kpi) + + with patch("train_on_pdf.generate_text_3434", mocked_generate_text): + mocked_server.get( + f"http://{inference_ip}:{inference_port}/infer_relevance", status_code=status_code_infer_relevance + ) + mocked_server.get(f"http://{inference_ip}:{inference_port}/train_kpi", status_code=status_code_train_kpi) return_value = run_router(extraction_port, inference_port, project_name, infer_ip=inference_ip) cmd_output, _ = capsys.readouterr() @@ -271,16 +307,18 @@ def test_run_router_kpi_training(prerequisites_run_router: requests_mock.mocker. 
assert return_value == return_value_expected -@pytest.mark.parametrize('infer_relevance, train_kpi', - [ - (True, True), - (True, False), - (True, True), - (True, False), - ]) -def test_run_router_successful_run(prerequisites_run_router: requests_mock.mocker.Mocker, - infer_relevance: bool, - train_kpi: bool): +@pytest.mark.parametrize( + "infer_relevance, train_kpi", + [ + (True, True), + (True, False), + (True, True), + (True, False), + ], +) +def test_run_router_successful_run( + prerequisites_run_router: requests_mock.mocker.Mocker, infer_relevance: bool, train_kpi: bool +): """Tests a successful run of run_router :param prerequisites_run_router: Requesting the prerequisites_run_router fixture @@ -290,14 +328,14 @@ def test_run_router_successful_run(prerequisites_run_router: requests_mock.mocke :type train_kpi: Flag for train kpi :type train_kpi: bool """ - extraction_ip = '0.0.0.0' - extraction_port = '8000' - inference_ip = '0.0.0.1' - inference_port = '8000' - project_name = 'TEST' + extraction_ip = "0.0.0.0" + extraction_port = "8000" + inference_ip = "0.0.0.1" + inference_port = "8000" + project_name = "TEST" mocked_server = prerequisites_run_router - - with patch('train_on_pdf.generate_text_3434', Mock()): + + with patch("train_on_pdf.generate_text_3434", Mock()): return_value = run_router(extraction_port, inference_port, project_name, infer_ip=inference_ip) assert return_value == True diff --git a/data_extractor/code/tests/test_utils/test_running.py b/data_extractor/code/tests/test_utils/test_running.py index 24b36a5..bcbe2cb 100644 --- a/data_extractor/code/tests/test_utils/test_running.py +++ b/data_extractor/code/tests/test_utils/test_running.py @@ -12,15 +12,14 @@ def prerequisite_running(path_folder_root_testing: Path): :param path_folder_root_testing: Path for the testing folder :type path_folder_root_testing: Path """ - path_file_running = path_folder_root_testing / 'data' / 'running' + path_file_running = path_folder_root_testing / "data" / "running" # mock the path to the running file - with patch('train_on_pdf.path_file_running', - str(path_file_running)): + with patch("train_on_pdf.path_file_running", str(path_file_running)): yield # cleanup path_file_running.unlink(missing_ok=True) - + def test_set_running(prerequisite_running, path_folder_root_testing: Path): """Tests the set_running function creating a running file @@ -31,13 +30,13 @@ def test_set_running(prerequisite_running, path_folder_root_testing: Path): :type path_folder_root_testing: Path """ # set path to running file and do a cleanup - path_file_running = path_folder_root_testing / 'data' / 'running' + path_file_running = path_folder_root_testing / "data" / "running" path_file_running.unlink(missing_ok=True) - + # perform set_running and assert that running file exists set_running() assert path_file_running.exists() - + # cleanup path_file_running.unlink() @@ -50,7 +49,7 @@ def test_checking_onging_run(prerequisite_running, path_folder_root_testing: Pat :param path_folder_root_testing: Path for the testing folder :type path_folder_root_testing: Path """ - path_file_running = path_folder_root_testing / 'data' / 'running' + path_file_running = path_folder_root_testing / "data" / "running" path_file_running.touch() assert check_running() == True @@ -63,10 +62,10 @@ def test_checking_finished_run(prerequisite_running, path_folder_root_testing: P :param path_folder_root_testing: Path for the testing folder :type path_folder_root_testing: Path """ - path_file_running = path_folder_root_testing / 'data' / 
'running' - path_file_running.unlink(missing_ok = True) + path_file_running = path_folder_root_testing / "data" / "running" + path_file_running.unlink(missing_ok=True) assert check_running() == False - + def test_clear_running(prerequisite_running, path_folder_root_testing: Path): """Tests for clearing running file @@ -76,7 +75,7 @@ def test_clear_running(prerequisite_running, path_folder_root_testing: Path): :param path_folder_root_testing: Path for the testing folder :type path_folder_root_testing: Path """ - path_file_running = path_folder_root_testing / 'data' / 'running' + path_file_running = path_folder_root_testing / "data" / "running" path_file_running.touch() clear_running() assert not path_file_running.exists() diff --git a/data_extractor/code/tests/test_utils/test_save_train_info.py b/data_extractor/code/tests/test_utils/test_save_train_info.py index a3f5ec6..05593f0 100644 --- a/data_extractor/code/tests/test_utils/test_save_train_info.py +++ b/data_extractor/code/tests/test_utils/test_save_train_info.py @@ -8,8 +8,7 @@ @pytest.fixture(autouse=True) -def prerequisites_save_train_info(path_folder_root_testing: Path, - path_folder_temporary: Path) -> Path: +def prerequisites_save_train_info(path_folder_root_testing: Path, path_folder_temporary: Path) -> Path: """Defines a fixture for creating all prerequisites for save_train_info :param path_folder_root_testing: Requesting the root testing folder fixture @@ -22,41 +21,37 @@ def prerequisites_save_train_info(path_folder_root_testing: Path, :rtype: Iterator[Path] """ mocked_project_settings = { - 'train_relevance': { - 'output_model_name': 'TEST' - }, - 'train_kpi':{ - 'output_model_name': 'TEST' - }, - 's3_settings': { - 'prefix' : 'corporate_data_extraction_projects' - } + "train_relevance": {"output_model_name": "TEST"}, + "train_kpi": {"output_model_name": "TEST"}, + "s3_settings": {"prefix": "corporate_data_extraction_projects"}, } - path_source_pdf = path_folder_root_testing / 'input' / 'pdf' / 'training' - path_source_annotation = path_folder_root_testing / 'input' / 'pdfs' / 'training' - path_source_mapping = path_folder_root_testing / 'data' / 'TEST' / 'input' / 'kpi_mapping' - path_project_model_dir = path_folder_temporary / 'models' + path_source_pdf = path_folder_root_testing / "input" / "pdf" / "training" + path_source_annotation = path_folder_root_testing / "input" / "pdfs" / "training" + path_source_mapping = path_folder_root_testing / "data" / "TEST" / "input" / "kpi_mapping" + path_project_model_dir = path_folder_temporary / "models" path_project_model_dir.mkdir(parents=True, exist_ok=True) - relevance_model = mocked_project_settings['train_relevance']['output_model_name'] - kpi_model = mocked_project_settings['train_kpi']['output_model_name'] - file_train_info = f'SUMMARY_REL_{relevance_model}_KPI_{kpi_model}.pickle' + relevance_model = mocked_project_settings["train_relevance"]["output_model_name"] + kpi_model = mocked_project_settings["train_kpi"]["output_model_name"] + file_train_info = f"SUMMARY_REL_{relevance_model}_KPI_{kpi_model}.pickle" path_train_info = path_project_model_dir / file_train_info - with (patch('train_on_pdf.project_settings', mocked_project_settings), - patch('train_on_pdf.source_annotation', str(path_source_annotation)), - patch('train_on_pdf.source_mapping', str(path_source_mapping)), - patch('train_on_pdf.os.listdir', side_effect=lambda *args: 'test.pdf'), - patch('train_on_pdf.source_mapping', str(path_folder_temporary / 'source_mapping')), - patch('train_on_pdf.source_annotation', 
str(path_folder_temporary / 'source_annotation')), - patch('train_on_pdf.source_pdf', str(path_folder_temporary / 'source_pdf')), - patch('train_on_pdf.pd', Mock()) as mocked_pandas): + with ( + patch("train_on_pdf.project_settings", mocked_project_settings), + patch("train_on_pdf.source_annotation", str(path_source_annotation)), + patch("train_on_pdf.source_mapping", str(path_source_mapping)), + patch("train_on_pdf.os.listdir", side_effect=lambda *args: "test.pdf"), + patch("train_on_pdf.source_mapping", str(path_folder_temporary / "source_mapping")), + patch("train_on_pdf.source_annotation", str(path_folder_temporary / "source_annotation")), + patch("train_on_pdf.source_pdf", str(path_folder_temporary / "source_pdf")), + patch("train_on_pdf.pd", Mock()) as mocked_pandas, + ): train_on_pdf.project_model_dir = str(path_project_model_dir) mocked_pandas.read_csv.return_value = {None} mocked_pandas.read_excel.return_value = {None} yield path_train_info - + # cleanup shutil.rmtree(path_folder_temporary) del train_on_pdf.project_model_dir @@ -68,60 +63,51 @@ def test_save_train_info_pickle(prerequisites_save_train_info: Path): :param prerequisites_save_train_info: Requesting the prerequisites_save_train_info fixture :type prerequisites_save_train_info: Path """ - project_name = 'TEST' + project_name = "TEST" path_train_info = prerequisites_save_train_info - + save_train_info(project_name) - + # we have to combine a pathlib and a string object... path_parent_train_info = path_train_info.parent path_file_pickle = path_train_info.name - path_train_info = Path(str(path_parent_train_info) + f'/{path_file_pickle}') - + path_train_info = Path(str(path_parent_train_info) + f"/{path_file_pickle}") + assert path_train_info.exists() - - + + def test_save_train_info_entries(prerequisites_save_train_info: Path): """Tests if all the train infos exists in the pickled train info file :param prerequisites_save_train_info: Requesting the prerequisites_save_train_info fixture :type prerequisites_save_train_info: Path """ - project_name = 'TEST' + project_name = "TEST" path_train_info = prerequisites_save_train_info - + save_train_info(project_name) - - with open(str(path_train_info), 'rb') as file: + + with open(str(path_train_info), "rb") as file: train_info = pickle.load(file) - - expected_keys = [ - 'project_name', - 'train_settings', - 'pdfs_used', - 'annotations', - 'kpis' - ] + + expected_keys = ["project_name", "train_settings", "pdfs_used", "annotations", "kpis"] # check that all keys exist in dict assert all(key in expected_keys for key in train_info.keys()) - + def test_save_tain_info_return_value(): - project_name = 'TEST' - + project_name = "TEST" + assert save_train_info(project_name) is None - -def test_save_train_info_s3_usage(): - """Tests if the s3_usage flag correctly works - """ - project_name = 'TEST' +def test_save_train_info_s3_usage(): + """Tests if the s3_usage flag correctly works""" + project_name = "TEST" s3_usage = True mocked_s3 = Mock() - + save_train_info(project_name, s3_usage, mocked_s3) - + assert mocked_s3.download_files_in_prefix_to_dir.call_count == 3 assert mocked_s3.upload_file_to_s3.called_once() - diff --git a/data_extractor/code/tests/utils_test.py b/data_extractor/code/tests/utils_test.py index 4da7b9c..68ad04e 100644 --- a/data_extractor/code/tests/utils_test.py +++ b/data_extractor/code/tests/utils_test.py @@ -12,7 +12,7 @@ def project_tests_root() -> Path: return Path(__file__).parent.resolve() -def write_to_file(path_csv_file: Path, content: str, header: str = ''): 
+def write_to_file(path_csv_file: Path, content: str, header: str = ""): """Write to a file for a given path with an optional header string :param path_csv_file: Path to csv file @@ -22,13 +22,13 @@ def write_to_file(path_csv_file: Path, content: str, header: str = ''): :param header: Header of the csv file, defaults to '' :type header: str, optional """ - with open(str(path_csv_file), 'w') as file: + with open(str(path_csv_file), "w") as file: if len(header) > 0: - file.write(f'{header}\n') - file.write(f'{content}\n') - - -def create_single_xlsx_file(path_folder: Path, file_name = 'xlsx_file.xlsx') -> None: + file.write(f"{header}\n") + file.write(f"{content}\n") + + +def create_single_xlsx_file(path_folder: Path, file_name="xlsx_file.xlsx") -> None: """Writes a single xlsx file to path_folder and creates the folder if it does not exist @@ -37,14 +37,14 @@ def create_single_xlsx_file(path_folder: Path, file_name = 'xlsx_file.xlsx') -> :param file_name: Filename of the xlsx file, defaults to 'xlsx_file.xlsx' :type file_name: str, optional """ - path_folder.mkdir(parents = True, exist_ok = True) - + path_folder.mkdir(parents=True, exist_ok=True) + # write single xlsx file - df_current = pd.DataFrame({'Data': [10, 20, 30, 40, 50, 60]}) + df_current = pd.DataFrame({"Data": [10, 20, 30, 40, 50, 60]}) path_current_file = path_folder / file_name - df_current.to_excel(str(path_current_file), engine='openpyxl') - - + df_current.to_excel(str(path_current_file), engine="openpyxl") + + def create_multiple_xlsx_files(path_folder: Path) -> None: """Writes multiple xlsx file to path_folder and creates the folder if it does not exist @@ -53,10 +53,10 @@ def create_multiple_xlsx_files(path_folder: Path) -> None: :type path_folder: Path """ for i in range(5): - create_single_xlsx_file(path_folder, file_name = f'xlsx_file_{i}.xlsx') + create_single_xlsx_file(path_folder, file_name=f"xlsx_file_{i}.xlsx") + -def modify_project_settings(project_settings: typing.Dict, - *args: typing.Tuple[str, str, bool]) -> typing.Dict: +def modify_project_settings(project_settings: typing.Dict, *args: typing.Tuple[str, str, bool]) -> typing.Dict: """Returns are modified project settings dict based on the input args :param project_settings: Project settings diff --git a/data_extractor/code/train_on_pdf.py b/data_extractor/code/train_on_pdf.py index 59b1d47..b4b2c38 100644 --- a/data_extractor/code/train_on_pdf.py +++ b/data_extractor/code/train_on_pdf.py @@ -13,7 +13,7 @@ from s3_communication import S3Communication from pathlib import Path -path_file_running = config_path.NLP_DIR+r'/data/running' +path_file_running = config_path.NLP_DIR + r"/data/running" project_settings = None project_model_dir = None @@ -40,19 +40,19 @@ def set_running(): - with open(path_file_running, 'w'): - pass + with open(path_file_running, "w"): + pass def clear_running(): - try: - os.unlink(path_file_running) - except Exception as e: - pass + try: + os.unlink(path_file_running) + except Exception as e: + pass def check_running(): - return os.path.exists(path_file_running) + return os.path.exists(path_file_running) def create_directory(directory_name): @@ -63,7 +63,7 @@ def create_directory(directory_name): if os.path.isfile(file_path): os.unlink(file_path) except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) + print("Failed to delete %s. 
Reason: %s" % (file_path, e)) def link_files(source_dir, destination_dir): @@ -74,9 +74,9 @@ def link_files(source_dir, destination_dir): def generate_text_3434(project_name, s3_usage, s3_settings): """ - This function merges all infer relevance outputs into one large file, which is then + This function merges all infer relevance outputs into one large file, which is then used to train the kpi extraction model. - + :param project_name: str, representing the project we currently work on :param s3_usage: boolean, if we use s3 as we then have to upload the new csv file to s3 :param s3_settings: dictionary, containing information in case of s3 usage @@ -84,45 +84,47 @@ def generate_text_3434(project_name, s3_usage, s3_settings): """ if s3_usage: s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) # Download infer relevance files - prefix_rel_infer = str(Path(s3_settings['prefix']) / project_name / 'data' / 'output' / 'RELEVANCE' / 'Text') + prefix_rel_infer = str(Path(s3_settings["prefix"]) / project_name / "data" / "output" / "RELEVANCE" / "Text") s3c_main.download_files_in_prefix_to_dir(prefix_rel_infer, str(folder_relevance)) - + with open(folder_text_3434 + r"/text_3434.csv", "w") as file_out: very_first = True - rel_inf_list = list(glob.iglob(folder_relevance + r'/*.csv')) + rel_inf_list = list(glob.iglob(folder_relevance + r"/*.csv")) if len(rel_inf_list) == 0: print("No relevance inference results found.") return False else: try: - for filepath in rel_inf_list: + for filepath in rel_inf_list: print(filepath) with open(filepath) as file_in: first = True for l in file_in: - if(very_first or not first): + if very_first or not first: file_out.write(l) first = False very_first = False except Exception: return False - + if s3_usage: s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) - project_prefix_text3434 = str(Path(s3_settings['prefix']) / project_name / 'data' / 'interim' / 'ml') - s3c_interim.upload_file_to_s3(filepath=folder_text_3434 + r"/text_3434.csv", s3_prefix=project_prefix_text3434, s3_key='text_3434.csv') - + project_prefix_text3434 = str(Path(s3_settings["prefix"]) / project_name / "data" / "interim" / "ml") + s3c_interim.upload_file_to_s3( + filepath=folder_text_3434 + r"/text_3434.csv", s3_prefix=project_prefix_text3434, s3_key="text_3434.csv" + ) + return True @@ -138,23 +140,21 @@ def convert_xls_to_csv(s3_usage, s3c_main, 
s3c_interim): source_dir = source_annotation dest_dir = destination_annotation if s3_usage: - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/annotations', - source_dir) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/annotations", source_dir) first = True for filename in os.listdir(source_dir): - if filename[-5:] == '.xlsx': + if filename[-5:] == ".xlsx": if not first: - raise ValueError('More than one excel sheet found') - print('Converting ' + filename + ' to csv-format') + raise ValueError("More than one excel sheet found") + print("Converting " + filename + " to csv-format") # only reads first sheet in excel file - read_file = pd.read_excel(source_dir + r'/' + filename, engine='openpyxl') - read_file.to_csv(dest_dir + r'/aggregated_annotation.csv', index=None, header=True) + read_file = pd.read_excel(source_dir + r"/" + filename, engine="openpyxl") + read_file.to_csv(dest_dir + r"/aggregated_annotation.csv", index=None, header=True) if s3_usage: - s3c_interim.upload_files_in_dir_to_prefix(dest_dir, - project_prefix + '/interim/ml/annotations') - first = False + s3c_interim.upload_files_in_dir_to_prefix(dest_dir, project_prefix + "/interim/ml/annotations") + first = False if first: - raise ValueError('No annotation excel sheet found') + raise ValueError("No annotation excel sheet found") def save_train_info(project_name, s3_usage=False, s3c_main=None, s3_settings=None): @@ -172,41 +172,51 @@ def save_train_info(project_name, s3_usage=False, s3c_main=None, s3_settings=Non """ if s3_usage: s3_settings = project_settings["s3_settings"] - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/kpi_mapping', source_mapping) - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/annotations', source_annotation) - s3c_main.download_files_in_prefix_to_dir(project_prefix + '/input/pdfs/training', source_pdf) - + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/kpi_mapping", source_mapping) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/annotations", source_annotation) + s3c_main.download_files_in_prefix_to_dir(project_prefix + "/input/pdfs/training", source_pdf) + dir_train = {} - dir_train.update({'project_name': project_name}) - dir_train.update({'train_settings': project_settings}) - dir_train.update({'pdfs_used': os.listdir(source_pdf)}) + dir_train.update({"project_name": project_name}) + dir_train.update({"train_settings": project_settings}) + dir_train.update({"pdfs_used": os.listdir(source_pdf)}) first = True for filename in os.listdir(source_annotation): - if(filename[-5:]=='.xlsx'): + if filename[-5:] == ".xlsx": if first: - dir_train.update({'annotations': pd.read_excel(source_annotation + r'/' + filename, engine='openpyxl') }) + dir_train.update({"annotations": pd.read_excel(source_annotation + r"/" + filename, engine="openpyxl")}) first = False - dir_train.update({'kpis': pd.read_csv(source_mapping + '/kpi_mapping.csv')}) - - relevance_model = project_settings['train_relevance']['output_model_name'] - kpi_model = project_settings['train_kpi']['output_model_name'] + dir_train.update({"kpis": pd.read_csv(source_mapping + "/kpi_mapping.csv")}) + + relevance_model = project_settings["train_relevance"]["output_model_name"] + kpi_model = project_settings["train_kpi"]["output_model_name"] name_out = project_model_dir - name_out = name_out + 
'/SUMMARY_REL_' + relevance_model + '_KPI_' + kpi_model + '.pickle' - - with open(name_out, 'wb') as handle: + name_out = name_out + "/SUMMARY_REL_" + relevance_model + "_KPI_" + kpi_model + ".pickle" + + with open(name_out, "wb") as handle: pickle.dump(dir_train, handle, protocol=pickle.HIGHEST_PROTOCOL) if s3_usage: - response_2 = s3c_main.upload_file_to_s3(filepath=name_out, - s3_prefix=str(Path(s3_settings['prefix']) / project_name / 'models'), - s3_key='SUMMARY_REL_' + relevance_model + '_KPI_' + kpi_model + '.pickle') - + response_2 = s3c_main.upload_file_to_s3( + filepath=name_out, + s3_prefix=str(Path(s3_settings["prefix"]) / project_name / "models"), + s3_key="SUMMARY_REL_" + relevance_model + "_KPI_" + kpi_model + ".pickle", + ) + return None -def run_router(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip='0.0.0.0', - s3_usage=False, s3c_main=None, s3c_interim=None): +def run_router( + ext_port, + infer_port, + project_name, + ext_ip="0.0.0.0", + infer_ip="0.0.0.0", + s3_usage=False, + s3c_main=None, + s3c_interim=None, +): """ Router function It fist sends a command to the extraction server to begin extraction. @@ -230,23 +240,23 @@ def run_router(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip='0 else: print("Extraction server is not responding.") return False - - payload = {'project_name': project_name, 'mode': 'train'} + + payload = {"project_name": project_name, "mode": "train"} payload.update(project_settings) - payload = {'payload': json.dumps(payload)} - + payload = {"payload": json.dumps(payload)} + # Sending an execution request to the extraction server for extraction ext_resp = requests.get(f"http://{ext_ip}:{ext_port}/extract", params=payload) print(ext_resp.text) if ext_resp.status_code != 200: return False - + # Sending an execution request to the extraction server for curation ext_resp = requests.get(f"http://{ext_ip}:{ext_port}/curate", params=payload) print(ext_resp.text) if ext_resp.status_code != 200: return False - + # Check if the inference server is live infer_live = requests.get(f"http://{infer_ip}:{infer_port}/liveness") if infer_live.status_code == 200: @@ -254,8 +264,8 @@ def run_router(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip='0 else: print("Inference server is not responding.") return False - - if project_settings['train_relevance']['train']: + + if project_settings["train_relevance"]["train"]: print("Relevance training will be started.") # Requesting the inference server to start the relevance stage train_resp = requests.get(f"http://{infer_ip}:{infer_port}/train_relevance", params=payload) @@ -263,35 +273,39 @@ def run_router(ext_port, infer_port, project_name, ext_ip='0.0.0.0', infer_ip='0 if train_resp.status_code != 200: return False else: - print("No relevance training done. If you want to have a relevance training please set variable " - "train under train_relevance to true.") - - if project_settings['train_kpi']['train']: + print( + "No relevance training done. If you want to have a relevance training please set variable " + "train under train_relevance to true." 
+ ) + + if project_settings["train_kpi"]["train"]: # Requesting the inference server to start the relevance stage infer_resp = requests.get(f"http://{infer_ip}:{infer_port}/infer_relevance", params=payload) print(infer_resp.text) if infer_resp.status_code != 200: return False try: - temp = generate_text_3434(project_name, project_settings['s3_usage'], project_settings['s3_settings']) + temp = generate_text_3434(project_name, project_settings["s3_usage"], project_settings["s3_settings"]) if temp: - print('text_3434 was generated without error.') + print("text_3434 was generated without error.") else: - print('text_3434 was not generated without error.') + print("text_3434 was not generated without error.") except Exception as e: - print('Error while generating text_3434.') + print("Error while generating text_3434.") print(repr(e)) print(traceback.format_exc()) - print('Next we start the training of the inference model. This may take some time.') + print("Next we start the training of the inference model. This may take some time.") # Requesting the inference server to start the kpi extraction stage infer_resp_kpi = requests.get(f"http://{infer_ip}:{infer_port}/train_kpi", params=payload) print(infer_resp_kpi.text) if infer_resp_kpi.status_code != 200: return False else: - print("No kpi training done. If you want to have a kpi training please set variable" - " train under train_kpi to true.") + print( + "No kpi training done. If you want to have a kpi training please set variable" + " train under train_kpi to true." + ) return True @@ -308,7 +322,7 @@ def copy_file_without_overwrite(src_path, dest_path): def link_extracted_files(src_ext, src_pdf, dest_ext): - extracted_pdfs = [name[:-5] + ".pdf" for name in os.listdir(src_ext)] + extracted_pdfs = [name[:-5] + ".pdf" for name in os.listdir(src_ext)] for pdf in os.listdir(src_pdf): if pdf in extracted_pdfs: json_name = pdf[:-4] + ".json" @@ -343,98 +357,94 @@ def main(): if check_running(): print("Another training or inference process is currently running.") return - - parser = argparse.ArgumentParser(description='End-to-end inference') - + + parser = argparse.ArgumentParser(description="End-to-end inference") + # Add the arguments - parser.add_argument('--project_name', - type=str, - default=None, - help='Name of the Project') - - parser.add_argument('--s3_usage', - type=str, - default=None, - help='Do you want to use S3? Type either Y or N.') - + parser.add_argument("--project_name", type=str, default=None, help="Name of the Project") + + parser.add_argument("--s3_usage", type=str, default=None, help="Do you want to use S3? Type either Y or N.") + args = parser.parse_args() project_name = args.project_name if project_name is None: project_name = input("What is the project name? ") - if project_name is None or project_name=="": + if project_name is None or project_name == "": print("project name must not be empty") return s3_usage = args.s3_usage if s3_usage is None: - s3_usage = input('Do you want to use S3? Type either Y or N.') - if s3_usage is None or str(s3_usage) not in ['Y', 'N']: + s3_usage = input("Do you want to use S3? Type either Y or N.") + if s3_usage is None or str(s3_usage) not in ["Y", "N"]: print("Answer to S3 usage must by Y or N. Stop program. 
Please restart.") return None else: - s3_usage = s3_usage == 'Y' + s3_usage = s3_usage == "Y" - project_data_dir = config_path.DATA_DIR + r'/' + project_name + project_data_dir = config_path.DATA_DIR + r"/" + project_name create_directory(project_data_dir) s3c_main = None s3c_interim = None if s3_usage: # Opening s3 settings file - s3_settings_path = config_path.DATA_DIR + r'/' + 's3_settings.yaml' - f = open(s3_settings_path, 'r') + s3_settings_path = config_path.DATA_DIR + r"/" + "s3_settings.yaml" + f = open(s3_settings_path, "r") s3_settings = yaml.safe_load(f) f.close() - project_prefix = s3_settings['prefix'] + "/" + project_name + '/data' + project_prefix = s3_settings["prefix"] + "/" + project_name + "/data" # init s3 connector s3c_main = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['main_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['main_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['main_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['main_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["main_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["main_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["main_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["main_bucket"]["s3_bucket_name"]), ) s3c_interim = S3Communication( - s3_endpoint_url=os.getenv(s3_settings['interim_bucket']['s3_endpoint']), - aws_access_key_id=os.getenv(s3_settings['interim_bucket']['s3_access_key']), - aws_secret_access_key=os.getenv(s3_settings['interim_bucket']['s3_secret_key']), - s3_bucket=os.getenv(s3_settings['interim_bucket']['s3_bucket_name']), + s3_endpoint_url=os.getenv(s3_settings["interim_bucket"]["s3_endpoint"]), + aws_access_key_id=os.getenv(s3_settings["interim_bucket"]["s3_access_key"]), + aws_secret_access_key=os.getenv(s3_settings["interim_bucket"]["s3_secret_key"]), + s3_bucket=os.getenv(s3_settings["interim_bucket"]["s3_bucket_name"]), ) settings_path = project_data_dir + "/settings.yaml" - s3c_main.download_file_from_s3(filepath=settings_path, - s3_prefix=project_prefix, - s3_key='settings.yaml') + s3c_main.download_file_from_s3(filepath=settings_path, s3_prefix=project_prefix, s3_key="settings.yaml") # Opening YAML file - f = open(project_data_dir + r'/settings.yaml', 'r') + f = open(project_data_dir + r"/settings.yaml", "r") project_settings = yaml.safe_load(f) f.close() - project_settings.update({'s3_usage': s3_usage}) + project_settings.update({"s3_usage": s3_usage}) if s3_usage: - project_settings.update({'s3_settings': s3_settings}) - - project_model_dir = config_path.MODEL_DIR + r'/' + project_name - ext_port = project_settings['general']['ext_port'] - infer_port = project_settings['general']['infer_port'] - ext_ip = project_settings['general']['ext_ip'] - infer_ip = project_settings['general']['infer_ip'] - relevance_training_output_model_name = project_settings['train_relevance']['output_model_name'] - kpi_inference_training_output_model_name = project_settings['train_kpi']['output_model_name'] - + project_settings.update({"s3_settings": s3_settings}) + + project_model_dir = config_path.MODEL_DIR + r"/" + project_name + ext_port = project_settings["general"]["ext_port"] + infer_port = project_settings["general"]["infer_port"] + ext_ip = project_settings["general"]["ext_ip"] + infer_ip = project_settings["general"]["infer_ip"] + relevance_training_output_model_name = project_settings["train_relevance"]["output_model_name"] + 
kpi_inference_training_output_model_name = project_settings["train_kpi"]["output_model_name"] + set_running() try: - source_pdf = project_data_dir + r'/input/pdfs/training' - source_annotation = project_data_dir + r'/input/annotations' - source_mapping = project_data_dir + r'/input/kpi_mapping' - destination_pdf = project_data_dir + r'/interim/pdfs/' - destination_annotation = project_data_dir + r'/interim/ml/annotations/' - destination_mapping = project_data_dir + r'/interim/kpi_mapping/' - destination_extraction = project_data_dir + r'/interim/ml/extraction/' - destination_curation = project_data_dir + r'/interim/ml/curation/' - destination_training = project_data_dir + r'/interim/ml/training/' - destination_saved_models_relevance = project_model_dir + r'/RELEVANCE/Text' + r'/' + relevance_training_output_model_name - destination_saved_models_inference = project_model_dir + r'/KPI_EXTRACTION/Text' + r'/' + kpi_inference_training_output_model_name - folder_text_3434 = project_data_dir + r'/interim/ml' - folder_relevance = project_data_dir + r'/output/RELEVANCE/Text' + source_pdf = project_data_dir + r"/input/pdfs/training" + source_annotation = project_data_dir + r"/input/annotations" + source_mapping = project_data_dir + r"/input/kpi_mapping" + destination_pdf = project_data_dir + r"/interim/pdfs/" + destination_annotation = project_data_dir + r"/interim/ml/annotations/" + destination_mapping = project_data_dir + r"/interim/kpi_mapping/" + destination_extraction = project_data_dir + r"/interim/ml/extraction/" + destination_curation = project_data_dir + r"/interim/ml/curation/" + destination_training = project_data_dir + r"/interim/ml/training/" + destination_saved_models_relevance = ( + project_model_dir + r"/RELEVANCE/Text" + r"/" + relevance_training_output_model_name + ) + destination_saved_models_inference = ( + project_model_dir + r"/KPI_EXTRACTION/Text" + r"/" + kpi_inference_training_output_model_name + ) + folder_text_3434 = project_data_dir + r"/interim/ml" + folder_relevance = project_data_dir + r"/output/RELEVANCE/Text" create_directory(source_pdf) create_directory(source_annotation) @@ -446,39 +456,43 @@ def main(): create_directory(destination_extraction) create_directory(destination_training) create_directory(destination_curation) - if project_settings['train_relevance']['train']: + if project_settings["train_relevance"]["train"]: create_directory(destination_saved_models_relevance) - if project_settings['train_kpi']['train']: + if project_settings["train_kpi"]["train"]: create_directory(destination_saved_models_inference) create_directory(folder_relevance) link_files(source_pdf, destination_pdf) link_files(source_annotation, destination_annotation) link_files(source_mapping, destination_mapping) - if project_settings['extraction']['use_extractions']: - source_extraction = project_data_dir + r'/output/TEXT_EXTRACTION' + if project_settings["extraction"]["use_extractions"]: + source_extraction = project_data_dir + r"/output/TEXT_EXTRACTION" if os.path.exists(source_extraction): link_extracted_files(source_extraction, source_pdf, destination_extraction) - - end_to_end_response = run_router(ext_port, infer_port, project_name, ext_ip, infer_ip, - s3_usage, s3c_main, s3c_interim) - + + end_to_end_response = run_router( + ext_port, infer_port, project_name, ext_ip, infer_ip, s3_usage, s3c_main, s3c_interim + ) + if end_to_end_response: - if project_settings['extraction']['store_extractions']: + if project_settings["extraction"]["store_extractions"]: print("Finally we transfer the 
text extraction to the output folder") source_extraction_data = destination_extraction - destination_extraction_data = project_data_dir + r'/output/TEXT_EXTRACTION' + destination_extraction_data = project_data_dir + r"/output/TEXT_EXTRACTION" if s3_usage: - s3c_interim.download_files_in_prefix_to_dir(project_prefix + '/interim/ml/extraction', - source_extraction_data) - s3c_main.upload_files_in_dir_to_prefix(source_extraction_data, - project_prefix + '/output/TEXT_EXTRACTION') + s3c_interim.download_files_in_prefix_to_dir( + project_prefix + "/interim/ml/extraction", source_extraction_data + ) + s3c_main.upload_files_in_dir_to_prefix( + source_extraction_data, project_prefix + "/output/TEXT_EXTRACTION" + ) else: os.makedirs(destination_extraction_data, exist_ok=True) - end_to_end_response = copy_file_without_overwrite(source_extraction_data, - destination_extraction_data) - - if project_settings['general']['delete_interim_files']: + end_to_end_response = copy_file_without_overwrite( + source_extraction_data, destination_extraction_data + ) + + if project_settings["general"]["delete_interim_files"]: create_directory(destination_pdf) create_directory(destination_mapping) create_directory(destination_annotation) @@ -489,15 +503,15 @@ def main(): if s3_usage: # Show only objects which satisfy our prefix my_bucket = s3c_interim.s3_resource.Bucket(name=s3c_interim.bucket) - for objects in my_bucket.objects.filter(Prefix=project_prefix+'/interim'): + for objects in my_bucket.objects.filter(Prefix=project_prefix + "/interim"): _ = objects.delete() - + if end_to_end_response: save_train_info(project_name, s3_usage, s3c_main) print("End-to-end inference complete") except Exception as e: - print('Process failed to run. Reason: ' + str(repr(e)) + traceback.format_exc()) + print("Process failed to run. Reason: " + str(repr(e)) + traceback.format_exc()) clear_running() diff --git a/data_extractor/code/utils/config_path.py b/data_extractor/code/utils/config_path.py index ac24ad8..e42d7e1 100644 --- a/data_extractor/code/utils/config_path.py +++ b/data_extractor/code/utils/config_path.py @@ -1,14 +1,14 @@ import os try: - path = globals()['_dh'][0] + path = globals()["_dh"][0] except KeyError: path = os.path.dirname(os.path.realpath(__file__)) - + root_dir = os.path.dirname(path) -MODEL_DIR = root_dir + r'/models' -DATA_DIR = root_dir + r'/data' +MODEL_DIR = root_dir + r"/models" +DATA_DIR = root_dir + r"/data" NLP_DIR = root_dir -PYTHON_EXECUTABLE = 'python' +PYTHON_EXECUTABLE = "python" diff --git a/data_extractor/code/utils/s3_communication.py b/data_extractor/code/utils/s3_communication.py index 82df3d3..470b008 100644 --- a/data_extractor/code/utils/s3_communication.py +++ b/data_extractor/code/utils/s3_communication.py @@ -23,9 +23,7 @@ class S3Communication(object): It connects with the bucket and provides methods to read and write data in parquet, csv, and json formats. 
""" - def __init__( - self, s3_endpoint_url, aws_access_key_id, aws_secret_access_key, s3_bucket - ): + def __init__(self, s3_endpoint_url, aws_access_key_id, aws_secret_access_key, s3_bucket): """Initialize communicator.""" self.s3_endpoint_url = s3_endpoint_url self.aws_access_key_id = aws_access_key_id @@ -63,9 +61,7 @@ def download_file_from_s3(self, filepath, s3_prefix, s3_key): with open(filepath, "wb") as f: f.write(buffer_bytes) - def upload_df_to_s3( - self, df, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_to_ftype_args - ): + def upload_df_to_s3(self, df, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_to_ftype_args): """ Take as input the data frame to be uploaded, and the output s3_key. @@ -79,16 +75,12 @@ def upload_df_to_s3( elif filetype == S3FileType.PARQUET: df.to_parquet(buffer, **pd_to_ftype_args) else: - raise ValueError( - f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})" - ) + raise ValueError(f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})") status = self._upload_bytes(buffer.getvalue(), s3_prefix, s3_key) return status - def download_df_from_s3( - self, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_read_ftype_args - ): + def download_df_from_s3(self, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_read_ftype_args): """Read from s3 and see if the saved data is correct.""" buffer_bytes = self._download_bytes(s3_prefix, s3_key) buffer = BytesIO(buffer_bytes) @@ -100,9 +92,7 @@ def download_df_from_s3( elif filetype == S3FileType.PARQUET: df = pd.read_parquet(buffer, **pd_read_ftype_args) else: - raise ValueError( - f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})" - ) + raise ValueError(f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})") return df def upload_files_in_dir_to_prefix(self, source_dir, s3_prefix): @@ -126,9 +116,7 @@ def download_files_in_prefix_to_dir(self, s3_prefix, destination_dir): Modified from original code here: https://stackoverflow.com/a/33350380 """ paginator = self.s3_resource.meta.client.get_paginator("list_objects") - for result in paginator.paginate( - Bucket=self.bucket, Delimiter="/", Prefix=s3_prefix - ): + for result in paginator.paginate(Bucket=self.bucket, Delimiter="/", Prefix=s3_prefix): # download all files in the sub "directory", if any if result.get("CommonPrefixes") is not None: for subdir in result.get("CommonPrefixes"): diff --git a/data_extractor/code/visitor_container/visitor_main.py b/data_extractor/code/visitor_container/visitor_main.py index 1c34ff5..567d65a 100644 --- a/data_extractor/code/visitor_container/visitor_main.py +++ b/data_extractor/code/visitor_container/visitor_main.py @@ -13,24 +13,23 @@ def main(app_type, project_name, s3_usage, mode): :param mode: string: RB, ML, both, or none - for just doing postprocessing :return: """ - if app_type not in ['training', 'inference']: + if app_type not in ["training", "inference"]: print("app_type should be training or inference. Please restart with valid input.") return False - if s3_usage not in ['Y', 'N']: + if s3_usage not in ["Y", "N"]: print("s3_usage should be Y or N. Please restart with valid input.") return False - if mode not in ['RB', 'ML', 'both', 'none']: + if mode not in ["RB", "ML", "both", "none"]: print("mode should be RB, ML, both or none. 
Please restart with valid input.") return False - coordinator_ip = os.getenv('coordinator_ip') - coordinator_port = os.getenv('coordinator_port') + coordinator_ip = os.getenv("coordinator_ip") + coordinator_port = os.getenv("coordinator_port") # Example string http://172.40.103.147:2000/liveness - liveness_string = f"http://{coordinator_ip}:" \ - f"{coordinator_port}/liveness" + liveness_string = f"http://{coordinator_ip}:" f"{coordinator_port}/liveness" coordinator_server_live = requests.get(liveness_string) if coordinator_server_live.status_code == 200: print(f"Coordinator server is up. Proceeding to the task {app_type} with project {project_name}.") @@ -38,13 +37,15 @@ def main(app_type, project_name, s3_usage, mode): print("Coordinator server is not responding.") return False - if app_type == 'training': + if app_type == "training": print(f"We will contact the server to start training for project {project_name}.") # Example string http://172.40.103.147:2000/train?project_name=ABC&s3_usage=Y - train_string = f"http://{coordinator_ip}:" \ - f"{coordinator_port}/train" \ - f'?project_name={project_name}'\ - f'&s3_usage={s3_usage}' + train_string = ( + f"http://{coordinator_ip}:" + f"{coordinator_port}/train" + f"?project_name={project_name}" + f"&s3_usage={s3_usage}" + ) coordinator_start_train = requests.get(train_string) print(coordinator_start_train.text) if coordinator_start_train.status_code == 200: @@ -54,11 +55,13 @@ def main(app_type, project_name, s3_usage, mode): else: print(f"We will contact the server to start inference for project {project_name}.") # Example string http://172.40.103.147:2000/infer?project_name=ABC&s3_usage=Y&mode=both - infer_string = f"http://{coordinator_ip}:" \ - f"{coordinator_port}/infer" \ - f'?project_name={project_name}'\ - f'&s3_usage={s3_usage}' \ - f'&mode={mode}' + infer_string = ( + f"http://{coordinator_ip}:" + f"{coordinator_port}/infer" + f"?project_name={project_name}" + f"&s3_usage={s3_usage}" + f"&mode={mode}" + ) coordinator_start_infer = requests.get(infer_string) print(coordinator_start_infer.text) if coordinator_start_infer.status_code == 200: @@ -67,6 +70,6 @@ def main(app_type, project_name, s3_usage, mode): return False -if __name__ == '__main__': - main('training', os.getenv('test_project'), 'Y', 'both') - main('inference', os.getenv('test_project'), 'Y', 'both') +if __name__ == "__main__": + main("training", os.getenv("test_project"), "Y", "both") + main("inference", os.getenv("test_project"), "Y", "both") diff --git a/data_extractor/docs/s3_communication.py b/data_extractor/docs/s3_communication.py index 82df3d3..470b008 100644 --- a/data_extractor/docs/s3_communication.py +++ b/data_extractor/docs/s3_communication.py @@ -23,9 +23,7 @@ class S3Communication(object): It connects with the bucket and provides methods to read and write data in parquet, csv, and json formats. 
""" - def __init__( - self, s3_endpoint_url, aws_access_key_id, aws_secret_access_key, s3_bucket - ): + def __init__(self, s3_endpoint_url, aws_access_key_id, aws_secret_access_key, s3_bucket): """Initialize communicator.""" self.s3_endpoint_url = s3_endpoint_url self.aws_access_key_id = aws_access_key_id @@ -63,9 +61,7 @@ def download_file_from_s3(self, filepath, s3_prefix, s3_key): with open(filepath, "wb") as f: f.write(buffer_bytes) - def upload_df_to_s3( - self, df, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_to_ftype_args - ): + def upload_df_to_s3(self, df, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_to_ftype_args): """ Take as input the data frame to be uploaded, and the output s3_key. @@ -79,16 +75,12 @@ def upload_df_to_s3( elif filetype == S3FileType.PARQUET: df.to_parquet(buffer, **pd_to_ftype_args) else: - raise ValueError( - f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})" - ) + raise ValueError(f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})") status = self._upload_bytes(buffer.getvalue(), s3_prefix, s3_key) return status - def download_df_from_s3( - self, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_read_ftype_args - ): + def download_df_from_s3(self, s3_prefix, s3_key, filetype=S3FileType.PARQUET, **pd_read_ftype_args): """Read from s3 and see if the saved data is correct.""" buffer_bytes = self._download_bytes(s3_prefix, s3_key) buffer = BytesIO(buffer_bytes) @@ -100,9 +92,7 @@ def download_df_from_s3( elif filetype == S3FileType.PARQUET: df = pd.read_parquet(buffer, **pd_read_ftype_args) else: - raise ValueError( - f"Received unexpected file type arg {filetype}. Can only be one of: {list(S3FileType)})" - ) + raise ValueError(f"Received unexpected file type arg {filetype}. 
Can only be one of: {list(S3FileType)})") return df def upload_files_in_dir_to_prefix(self, source_dir, s3_prefix): @@ -126,9 +116,7 @@ def download_files_in_prefix_to_dir(self, s3_prefix, destination_dir): Modified from original code here: https://stackoverflow.com/a/33350380 """ paginator = self.s3_resource.meta.client.get_paginator("list_objects") - for result in paginator.paginate( - Bucket=self.bucket, Delimiter="/", Prefix=s3_prefix - ): + for result in paginator.paginate(Bucket=self.bucket, Delimiter="/", Prefix=s3_prefix): # download all files in the sub "directory", if any if result.get("CommonPrefixes") is not None: for subdir in result.get("CommonPrefixes"): diff --git a/data_extractor/notebooks/annotation_tool/annotation_tool.py b/data_extractor/notebooks/annotation_tool/annotation_tool.py index c568db4..319d48a 100644 --- a/data_extractor/notebooks/annotation_tool/annotation_tool.py +++ b/data_extractor/notebooks/annotation_tool/annotation_tool.py @@ -4,109 +4,109 @@ import glob import numpy as np -pd.set_option('display.max_rows', 100) -pd.set_option('display.max_colwidth', None) +pd.set_option("display.max_rows", 100) +pd.set_option("display.max_colwidth", None) -company = 'Imperial Oil Ltd' +company = "Imperial Oil Ltd" year = 2018 sector = "OG" -annotator = 'Max' -annotation_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL' -output_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/output' -input_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/input' -kpi_mapping_fpath = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/kpi_mapping.csv' +annotator = "Max" +annotation_path = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL" +output_path = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/output" +input_path = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/input" +kpi_mapping_fpath = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/kpi_mapping.csv" df_result = pd.read_excel(annotation_path + "/annotations.xlsx") kpi_of_interest_textarea = widgets.Textarea( - placeholder='Insert KPIs as a comma seperated list, e.g. 0, 1, 2, 3.1, ...', - layout = widgets.Layout(width='50%'), - style = {'description_width': 'initial'}, - description='KPI of interest:', - disabled=False - ) - -report_to_analyze_select = widgets.Select( - # options=None, #sorted(select_options, reverse=True), - description='Available results:', - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='100%', height='150px'), - disabled=False - ) + placeholder="Insert KPIs as a comma seperated list, e.g. 
0, 1, 2, 3.1, ...", + layout=widgets.Layout(width="50%"), + style={"description_width": "initial"}, + description="KPI of interest:", + disabled=False, +) + +report_to_analyze_select = widgets.Select( + # options=None, #sorted(select_options, reverse=True), + description="Available results:", + style={"description_width": "initial"}, + layout=widgets.Layout(width="100%", height="150px"), + disabled=False, +) kpi_to_analyze_dropdown = widgets.Dropdown( - # options= None, #sorted(options_kpi), - # rows=10, - value = None, - description='Current KPI:', - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='100%') - ) - -answer_select = widgets.SelectMultiple( - # options=None, #sorted(select_options, reverse=True), - description='Correct answer', - style = {'description_width': 'initial'}, - #layout = widgets.Layout(width='50%', height='150px'), - disabled=False - ) - -paragraph_select = widgets.SelectMultiple( - # options=None, #sorted(select_options, reverse=True), - description='Correct paragraph', - style = {'description_width': 'initial'}, - #layout = widgets.Layout(width='50%', height='150px'), - disabled=False - ) - + # options= None, #sorted(options_kpi), + # rows=10, + value=None, + description="Current KPI:", + style={"description_width": "initial"}, + layout=widgets.Layout(width="100%"), +) + +answer_select = widgets.SelectMultiple( + # options=None, #sorted(select_options, reverse=True), + description="Correct answer", + style={"description_width": "initial"}, + # layout = widgets.Layout(width='50%', height='150px'), + disabled=False, +) + +paragraph_select = widgets.SelectMultiple( + # options=None, #sorted(select_options, reverse=True), + description="Correct paragraph", + style={"description_width": "initial"}, + # layout = widgets.Layout(width='50%', height='150px'), + disabled=False, +) + answer_dropdown = widgets.Dropdown( - description="Choose rank of correct kpi:", - # options=None, - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='50%'), - value = None - ) + description="Choose rank of correct kpi:", + # options=None, + style={"description_width": "initial"}, + layout=widgets.Layout(width="50%"), + value=None, +) paragraph_dropdown = widgets.Dropdown( - description="Choose rank of correct pagraph:", - # options=None, - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='50%'), - value = None - ) + description="Choose rank of correct pagraph:", + # options=None, + style={"description_width": "initial"}, + layout=widgets.Layout(width="50%"), + value=None, +) correct_answer_textarea = widgets.Textarea( - value='', - placeholder='Enter your correction.', - layout = widgets.Layout(width='60%'), - description='Correction:', - disabled=False - ) + value="", + placeholder="Enter your correction.", + layout=widgets.Layout(width="60%"), + description="Correction:", + disabled=False, +) correct_paragraph_textarea = widgets.Textarea( - value='', - placeholder='Enter your correction.', - layout = widgets.Layout(width='60%'), - description='Correction:', - disabled=False - ) + value="", + placeholder="Enter your correction.", + layout=widgets.Layout(width="60%"), + description="Correction:", + disabled=False, +) use_button = widgets.Button( - value=False, - description='Use', - disabled=False, - button_style='', # 'success', 'info', 'warning', 'danger' or '' - tooltip='Use the previous defined kpis of interest', - icon='check' # (FontAwesome names without the `fa-` prefix) - ) + value=False, + description="Use", + 
disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Use the previous defined kpis of interest", + icon="check", # (FontAwesome names without the `fa-` prefix) +) save_button = widgets.Button( - value=False, - description='Save', - disabled=False, - button_style='', # 'success', 'info', 'warning', 'danger' or '' - tooltip='Save current annotation and switch to next KPI', - icon='check' # (FontAwesome names without the `fa-` prefix) - ) + value=False, + description="Save", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Save current annotation and switch to next KPI", + icon="check", # (FontAwesome names without the `fa-` prefix) +) output_kpi_of_interest = widgets.Output() @@ -114,50 +114,65 @@ output_annotation = widgets.Output() output_save = widgets.Output() + def get_kpi_overview(): kpi_mapping_df = pd.read_csv(kpi_mapping_fpath) - kpi_mapping_df['kpi_id'] = kpi_mapping_df['kpi_id'].map('{:g}'.format) - filtered_columns = kpi_mapping_df.columns[~kpi_mapping_df.columns.str.contains('Unnamed')] + kpi_mapping_df["kpi_id"] = kpi_mapping_df["kpi_id"].map("{:g}".format) + filtered_columns = kpi_mapping_df.columns[~kpi_mapping_df.columns.str.contains("Unnamed")] kpi_mapping_df = kpi_mapping_df[filtered_columns] return kpi_mapping_df - + def get_kpi_of_interest(kpi_str=None): if kpi_str == None: kpis = "".join(kpi_of_interest_textarea.value.split()) else: - kpis= "".join(kpi_str.value.split()) - kpis = kpis.split(',') - # kpis = list(map(float, kpis)) + kpis = "".join(kpi_str.value.split()) + kpis = kpis.split(",") + # kpis = list(map(float, kpis)) return kpis + def update_pdf_selection(input_path, annotation_path, kpi_of_interest): global df_result selection = [] outputs = glob.glob(input_path + "/*") - outputs = [x.rsplit('/', 1)[1] for x in outputs] + outputs = [x.rsplit("/", 1)[1] for x in outputs] df_annotations = df_result kpis = set(map(float, set(kpi_of_interest))) for output in outputs: df_output = pd.read_csv(input_path + "/" + output) - pdf_name = df_output['pdf_name'].values[0] + pdf_name = df_output["pdf_name"].values[0] df_annotations_temp = df_annotations[df_annotations.source_file == pdf_name] - kpis_contained = [float(x) for x in df_annotations_temp['kpi_id'].values if x in kpis] + kpis_contained = [float(x) for x in df_annotations_temp["kpi_id"].values if x in kpis] if set(kpis_contained) == kpis: - selection = selection + [(f'DONE - ' + output, output)] + selection = selection + [(f"DONE - " + output, output)] else: - selection = selection + [(f'TODO ({len(set(kpis_contained))}/{len(kpi_of_interest)}) - ' + output, output)] - return sorted(selection, reverse=True) + selection = selection + [(f"TODO ({len(set(kpis_contained))}/{len(kpi_of_interest)}) - " + output, output)] + return sorted(selection, reverse=True) def get_df_for_selected_report(): - df_infer = pd.read_csv(input_path + "/" + report_to_analyze_select.value) - df_kpi_mapping = pd.read_csv(kpi_mapping_fpath)[['kpi_id', 'question']] - df = df_infer.merge(df_kpi_mapping, how='left', left_on='kpi', right_on = 'question') - df = df[['kpi', 'kpi_id_y', 'answer','page','paragraph', 'source', 'score', 'no_ans_score', 'no_answer_score_plus_boost']] - df = df.rename(columns={'kpi_id_y':'kpi_id'}) + df_infer = pd.read_csv(input_path + "/" + report_to_analyze_select.value) + df_kpi_mapping = pd.read_csv(kpi_mapping_fpath)[["kpi_id", "question"]] + df = df_infer.merge(df_kpi_mapping, how="left", left_on="kpi", right_on="question") + df = 
df[ + [ + "kpi", + "kpi_id_y", + "answer", + "page", + "paragraph", + "source", + "score", + "no_ans_score", + "no_answer_score_plus_boost", + ] + ] + df = df.rename(columns={"kpi_id_y": "kpi_id"}) return df + def get_answers_for_selected_kpi(): if kpi_to_analyze_dropdown.value == None: kpis_of_interest = get_kpi_of_interest() @@ -165,35 +180,36 @@ def get_answers_for_selected_kpi(): else: kpi_id = kpi_to_analyze_dropdown.value df = get_df_for_selected_report() - df_filtered = df[df['kpi_id']==float(kpi_id)] + df_filtered = df[df["kpi_id"] == float(kpi_id)] return df_filtered def update_rank_selection(): - df_answers = get_answers_for_selected_kpi() - index_list = df_answers.index.values + df_answers = get_answers_for_selected_kpi() + index_list = df_answers.index.values selection = [] for i, x in enumerate(index_list): - selection = selection + [(i+1,[i+1,x])] - selection = selection + [('no correct answer',[-1,-1])] + selection = selection + [(i + 1, [i + 1, x])] + selection = selection + [("no correct answer", [-1, -1])] return selection + def use_button_on_click(b): kpis_of_interest = get_kpi_of_interest() try: kpi_of_interest = map(float, kpis_of_interest) with report_to_analyze_select.hold_trait_notifications(): - new_options = update_pdf_selection(input_path, annotation_path, kpis_of_interest) + new_options = update_pdf_selection(input_path, annotation_path, kpis_of_interest) report_to_analyze_select.options = new_options with kpi_to_analyze_dropdown.hold_trait_notifications(): - new_options = [(x,x) for x in kpis_of_interest] + new_options = [(x, x) for x in kpis_of_interest] kpi_to_analyze_dropdown.values = kpis_of_interest kpi_to_analyze_dropdown.options = new_options output_kpi_of_interest.clear_output() with output_kpi_of_interest: display(kpi_of_interest_textarea) display(use_button) - print('Your selected KPIs are: ' + " ".join(get_kpi_of_interest())) + print("Your selected KPIs are: " + " ".join(get_kpi_of_interest())) display(get_kpi_overview()) output_pdf_selection.clear_output() with output_pdf_selection: @@ -203,21 +219,20 @@ def use_button_on_click(b): with output_kpi_of_interest: display(kpi_of_interest_textarea) display(use_button) - print('Kpi string has not the right format.') + print("Kpi string has not the right format.") display(get_kpi_overview()) - - - #print(new_options) - - + + # print(new_options) + + def kpi_of_interest_handler(change): kpis_of_interest = get_kpi_of_interest() - #print(new_options) + # print(new_options) with report_to_analyze_select.hold_trait_notifications(): - new_options = update_pdf_selection(input_path, annotation_path, kpis_of_interest) + new_options = update_pdf_selection(input_path, annotation_path, kpis_of_interest) report_to_analyze_select.options = new_options with kpi_to_analyze_dropdown.hold_trait_notifications(): - new_options = [(x,x) for x in kpi_of_interest] + new_options = [(x, x) for x in kpi_of_interest] kpi_to_analyze_dropdown.values = kpis_of_interest kpi_to_analyze_dropdown.options = new_options output_pdf_selection.clear_output() @@ -228,55 +243,57 @@ def kpi_of_interest_handler(change): def report_to_analyze_handler(change): kpis_of_interest = get_kpi_of_interest() with kpi_to_analyze_dropdown.hold_trait_notifications(): - new_options = [(x,x) for x in kpis_of_interest] + new_options = [(x, x) for x in kpis_of_interest] kpi_to_analyze_dropdown.value = min(kpis_of_interest) kpi_to_analyze_dropdown.options = new_options - output_annotation.clear_output() + output_annotation.clear_output() with output_annotation: - 
print('\n') + print("\n") display(kpi_to_analyze_dropdown) display(get_answers_for_selected_kpi()) - + def kpi_to_analyze_handler(change): df_answers = get_answers_for_selected_kpi() with answer_dropdown.hold_trait_notifications(): - #answer_select.value = None + # answer_select.value = None new_rank_options = update_rank_selection() answer_dropdown.options = new_rank_options with paragraph_dropdown.hold_trait_notifications(): - #paragraph_select.value = None + # paragraph_select.value = None new_rank_options = update_rank_selection() - paragraph_dropdown.options = new_rank_options - + paragraph_dropdown.options = new_rank_options + output_annotation.clear_output() with output_annotation: - print('\n') + print("\n") display(kpi_to_analyze_dropdown) display(get_answers_for_selected_kpi()) - display(answer_dropdown) + display(answer_dropdown) display(paragraph_dropdown) - #display(widgets.HBox([widgets.VBox([answer_dropdown, correct_answer_textarea]),widgets.VBox([paragraph_dropdown, correct_paragraph_textarea])])) - + # display(widgets.HBox([widgets.VBox([answer_dropdown, correct_answer_textarea]),widgets.VBox([paragraph_dropdown, correct_paragraph_textarea])])) + + def answer_paragraph_handler(change): output_annotation.clear_output() with output_annotation: - print('\n') + print("\n") display(kpi_to_analyze_dropdown) display(get_answers_for_selected_kpi()) - display(answer_dropdown) + display(answer_dropdown) if answer_dropdown.value[0] == -1: display(correct_answer_textarea) display(paragraph_dropdown) if paragraph_dropdown.value[0] == -1: display(correct_paragraph_textarea) + def build_annotation_entry(): global df_result df2 = pd.DataFrame(data=None, columns=df_result.columns) id_correct_paragraph = paragraph_dropdown.value[1] rank_correct_paragraph = paragraph_dropdown.value[0] - id_correct_answer =answer_dropdown.value[1] + id_correct_answer = answer_dropdown.value[1] rank_correct_answer = answer_dropdown.value[0] correct_answer = correct_answer_textarea.value correct_paragraph = correct_paragraph_textarea.value @@ -284,70 +301,73 @@ def build_annotation_entry(): df_output = get_df_for_selected_report() if paragraph_dropdown.value[1] == -1: paragraph = "[" + str(correct_paragraph) + "]" - source_page = "[]" #+ str(correct_paragraph_page) + - source = 'Text' + source_page = "[]" # + str(correct_paragraph_page) + + source = "Text" paragraph_pred_rank = -1 paragraph_pred_score = -100 else: - paragraph = "[" + str(df_output.loc[id_correct_paragraph, 'paragraph']) + "]" - source_page = "[" + str(df_output.loc[id_correct_paragraph, 'page']) + "]" - source = df_output.loc[id_correct_paragraph, 'source'] + paragraph = "[" + str(df_output.loc[id_correct_paragraph, "paragraph"]) + "]" + source_page = "[" + str(df_output.loc[id_correct_paragraph, "page"]) + "]" + source = df_output.loc[id_correct_paragraph, "source"] paragraph_pred_rank = rank_correct_paragraph paragraph_pred_score = 100 - + if id_correct_answer == -1: answer = correct_answer kpi_pred_rank = -1 kpi_pred_score = -100 - + else: - answer = df_output.loc[id_correct_paragraph, 'answer'] + answer = df_output.loc[id_correct_paragraph, "answer"] kpi_pred_rank = rank_correct_answer - kpi_pred_score = df_output.loc[df_output.index == id_correct_answer, 'score'].values[0] - + kpi_pred_score = df_output.loc[df_output.index == id_correct_answer, "score"].values[0] + max_num = len(df_result) - - new_data = [max_num+1, - company, - report_to_analyze_select.value, - source_page, - kpi_to_investigate, - year, - answer, - source, - paragraph, - 
annotator, - sector, - "", - paragraph_pred_rank, - paragraph_pred_score, - kpi_pred_rank, - kpi_pred_score] + + new_data = [ + max_num + 1, + company, + report_to_analyze_select.value, + source_page, + kpi_to_investigate, + year, + answer, + source, + paragraph, + annotator, + sector, + "", + paragraph_pred_rank, + paragraph_pred_score, + kpi_pred_rank, + kpi_pred_score, + ] return new_data + def export_results(): df_result.to_excel(annotation_path + "/annotations.xlsx", index=False) - display('Success!') - - + display("Success!") + + def save_on_click(b): global df_result - + new_entry = build_annotation_entry() df_result.loc[len(df_result)] = new_entry - + kpis_of_interest = get_kpi_of_interest() val = kpi_to_analyze_dropdown.value index = kpis_of_interest.index(val) kpi_to_analyze_dropdown.value = kpis_of_interest[index + 1] - + output_save.clear_output() with output_save: display(save_button) - print('Success!') + print("Success!") display(df_result.tail(1)) - + def select_your_kpi_of_interest(): kpi_overview = get_kpi_overview() output_kpi_of_interest.clear_output() @@ -355,7 +375,7 @@ def select_your_kpi_of_interest(): display(kpi_of_interest_textarea) display(use_button) display(kpi_overview) - + use_button.on_click(use_button_on_click) display(output_kpi_of_interest) @@ -365,29 +385,25 @@ def select_the_report_to_analyze(): output_pdf_selection.clear_output() with output_pdf_selection: display(report_to_analyze_select) - - report_to_analyze_select.observe(report_to_analyze_handler, names='value') + + report_to_analyze_select.observe(report_to_analyze_handler, names="value") display(output_pdf_selection) - + + def lets_go(): - - - output_annotation.clear_output() + output_annotation.clear_output() with output_annotation: - print('\n') + print("\n") display(kpi_to_analyze_dropdown) display(get_answers_for_selected_kpi()) - output_save.clear_output() + output_save.clear_output() with output_save: display(save_button) - - - - kpi_to_analyze_dropdown.observe(kpi_to_analyze_handler, names='value') - kpi_to_analyze_handler('value') - answer_dropdown.observe(answer_paragraph_handler, names='value') - paragraph_dropdown.observe(answer_paragraph_handler, names='value') + + kpi_to_analyze_dropdown.observe(kpi_to_analyze_handler, names="value") + kpi_to_analyze_handler("value") + answer_dropdown.observe(answer_paragraph_handler, names="value") + paragraph_dropdown.observe(answer_paragraph_handler, names="value") save_button.on_click(save_on_click) display(output_annotation) display(output_save) - diff --git a/data_extractor/notebooks/annotation_tool/old_versions/tool_widgets.py b/data_extractor/notebooks/annotation_tool/old_versions/tool_widgets.py index 5d8ba24..2d03450 100644 --- a/data_extractor/notebooks/annotation_tool/old_versions/tool_widgets.py +++ b/data_extractor/notebooks/annotation_tool/old_versions/tool_widgets.py @@ -3,31 +3,32 @@ import ipywidgets as widgets -''' +""" annotations_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL' output_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/data/output' input_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/data/input' kpi_mapping_fpath = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/data/input/kpi_mapping.csv' -''' -annotations_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL' -output_path = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/output' -input_path = 
'/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/input' -kpi_mapping_fpath = '/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/kpi_mapping.csv' +""" +annotations_path = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL" +output_path = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/output" +input_path = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/input" +kpi_mapping_fpath = "/opt/app-root/src/corporate_data_pipeline/NLP_ANNOTATION_TOOL/kpi_mapping.csv" kpi_output = widgets.Output() button_output = widgets.Output() -output=widgets.Output() -output_paragraph=widgets.Output() -layout_hidden = widgets.Layout(visibility = 'hidden') -layout_visible = widgets.Layout(visibility = 'visible', width='60%') -options_kpi =[] +output = widgets.Output() +output_paragraph = widgets.Output() +layout_hidden = widgets.Layout(visibility="hidden") +layout_visible = widgets.Layout(visibility="visible", width="60%") +options_kpi = [] kpi_of_interest = [] + def select_pdf_handler(change): global df_output df_output = pd.read_csv(output_path + "/" + str(select_pdf.value)) - pdf_name = df_output['pdf_name'].values[0] - df_annotations_temp = df_annotations[df_annotations["source_file"]==pdf_name] - kpis_contained = [x for x in df_annotations_temp['kpi_id'].values if x in kpi_of_interest] + pdf_name = df_output["pdf_name"].values[0] + df_annotations_temp = df_annotations[df_annotations["source_file"] == pdf_name] + kpis_contained = [x for x in df_annotations_temp["kpi_id"].values if x in kpi_of_interest] open_kpis = [x for x in kpi_of_interest if x not in kpis_contained] select_out.clear_output() with select_out: @@ -37,22 +38,23 @@ def select_pdf_handler(change): print("The open kpi is " + ", ".join([str(x) for x in open_kpis]) + ".") else: print("There are no open kpi's.") - + + def select_kpi_handler(change): global df_output global df_output_check kpi_to_investigate = select_kpi.value - kpi = kpi_mapping_df.loc[kpi_mapping_df['kpi_id']==kpi_to_investigate, 'question'].values[0] - df_output_check = df_output[df_output['kpi']==kpi] + kpi = kpi_mapping_df.loc[kpi_mapping_df["kpi_id"] == kpi_to_investigate, "question"].values[0] + df_output_check = df_output[df_output["kpi"] == kpi] global index_list global select_list - index_list = df_output_check.index.values + index_list = df_output_check.index.values select_list = [] for i, x in enumerate(index_list): - select_list = select_list + [(i+1,[i+1,x])] - select_list = select_list + [('no correct answer',[-1,-1])] + select_list = select_list + [(i + 1, [i + 1, x])] + select_list = select_list + [("no correct answer", [-1, -1])] dropdown.value = None - w_dropdown_paragraph.value = None + w_dropdown_paragraph.value = None with dropdown.hold_trait_notifications(): dropdown.options = select_list with w_dropdown_paragraph.hold_trait_notifications(): @@ -60,13 +62,14 @@ def select_kpi_handler(change): kpi_output.clear_output() with kpi_output: display(df_output_check) - + + def button_next_kpi_handler(b): global df_out global kpi_of_interest - id_correct_paragraph =w_dropdown_paragraph.value[1] + id_correct_paragraph = w_dropdown_paragraph.value[1] rank_correct_paragraph = w_dropdown_paragraph.value[0] - id_correct_answer =dropdown.value[1] + id_correct_answer = dropdown.value[1] rank_correct_answer = dropdown.value[0] correct_answer = correct_answer_input.value correct_paragraph = correct_paragraph_input.value @@ -74,14 +77,14 @@ def button_next_kpi_handler(b): df_temp = df_annotations.head(0) if 
w_dropdown_paragraph.value[1] == -1: paragraph = "[" + str(correct_paragraph) + "]" - source_page = "[]" #+ str(correct_paragraph_page) + - source = 'Text' + source_page = "[]" # + str(correct_paragraph_page) + + source = "Text" paragraph_pred_rank = -1 paragraph_pred_score = -100 else: - paragraph = "[" + str(df_output.loc[id_correct_paragraph, 'paragraph']) + "]" - source_page = "[" + str(df_output.loc[id_correct_paragraph, 'page']) + "]" - source = df_output.loc[id_correct_paragraph, 'source'] + paragraph = "[" + str(df_output.loc[id_correct_paragraph, "paragraph"]) + "]" + source_page = "[" + str(df_output.loc[id_correct_paragraph, "page"]) + "]" + source = df_output.loc[id_correct_paragraph, "source"] paragraph_pred_rank = rank_correct_paragraph paragraph_pred_score = 100 if id_correct_answer == -1: @@ -89,57 +92,60 @@ def button_next_kpi_handler(b): kpi_pred_rank = -1 kpi_pred_score = -100 else: - answer = df_output.loc[id_correct_paragraph, 'answer'] + answer = df_output.loc[id_correct_paragraph, "answer"] kpi_pred_rank = rank_correct_answer - kpi_pred_score = df_output.loc[df_output.index == id_correct_answer, 'score'].values[0] - + kpi_pred_score = df_output.loc[df_output.index == id_correct_answer, "score"].values[0] + try: - max_num = np.max(df_out['number'].values) + max_num = np.max(df_out["number"].values) except ValueError: max_num = 1 - new_data = [max_num+1, - company, - df_output_check['pdf_name'].values[0], - source_page, - kpi_to_investigate, - year, - answer, - source, - paragraph, - annotator, - sector, - "", - paragraph_pred_rank, - paragraph_pred_score, - kpi_pred_rank, - kpi_pred_score] + new_data = [ + max_num + 1, + company, + df_output_check["pdf_name"].values[0], + source_page, + kpi_to_investigate, + year, + answer, + source, + paragraph, + annotator, + sector, + "", + paragraph_pred_rank, + paragraph_pred_score, + kpi_pred_rank, + kpi_pred_score, + ] insert_annotation_into_df(df_out, new_data) - # df_out.tail(1) + # df_out.tail(1) kpi_output.clear_output() val = select_kpi.value - index =kpi_of_interest.index(val) + index = kpi_of_interest.index(val) select_kpi.value = kpi_of_interest[index + 1] - select_kpi_handler('value') + select_kpi_handler("value") button_output.clear_output() with button_output: display(df_out.tail(1)) - print('Success!') - - + print("Success!") + + def dropdown_kpi_eventhandler(change): output.clear_output() if dropdown.value == None: pass - else: + else: if dropdown.value[1] == -1: correct_answer_input.layout = layout_visible else: correct_answer_input.layout = layout_hidden with output: - print('id_correct_answer: ' + str(dropdown.value[1])) - print('correct paragraph: ' + str(w_dropdown_paragraph.value[1])) - + print("id_correct_answer: " + str(dropdown.value[1])) + print("correct paragraph: " + str(w_dropdown_paragraph.value[1])) + + def dropdown_paragraph_eventhandler(change): output.clear_output() if w_dropdown_paragraph.value == None: @@ -150,44 +156,63 @@ def dropdown_paragraph_eventhandler(change): with output_paragraph: display(w_dropdown_paragraph) if w_dropdown_paragraph.value[1] == -1: - display(correct_paragraph_input) + display(correct_paragraph_input) with output: - print('id_correct_answer: ' + str(dropdown.value[1])) - print('id_correct_paragraph: ' + str(w_dropdown_paragraph.value[1])) + print("id_correct_answer: " + str(dropdown.value[1])) + print("id_correct_paragraph: " + str(w_dropdown_paragraph.value[1])) + + # Refactor and utility + - - #Refactor and utility def df_from_kpi_infer(infer_fpath, 
kpi_mapping_fpath): df_infer = pd.read_csv(infer_fpath) - df_kpi_mapping = pd.read_csv(kpi_mapping_fpath)[['kpi_id', 'question']] - df = df_infer.merge(df_kpi_mapping, how='left', left_on='kpi', right_on = 'question') - df = df[['pdf_name', 'kpi', 'kpi_id_y', 'answer','page','paragraph', 'source', 'score' ,'no_ans_score' ,'no_answer_score_plus_boost']] - df = df.rename(columns={'kpi_id_y':'kpi_id'}) + df_kpi_mapping = pd.read_csv(kpi_mapping_fpath)[["kpi_id", "question"]] + df = df_infer.merge(df_kpi_mapping, how="left", left_on="kpi", right_on="question") + df = df[ + [ + "pdf_name", + "kpi", + "kpi_id_y", + "answer", + "page", + "paragraph", + "source", + "score", + "no_ans_score", + "no_answer_score_plus_boost", + ] + ] + df = df.rename(columns={"kpi_id_y": "kpi_id"}) return df + def get_answers_for_kpi_id(df, kpi_id): - df_answers = df[df['kpi_id']==kpi_id] + df_answers = df[df["kpi_id"] == kpi_id] return df_answers - + + def get_kpi_for_id(df, id): - kpi = df['kpi_id']==id + kpi = df["kpi_id"] == id return kpi + def get_kpi_selection_options(df): - index_list = df.index.values + index_list = df.index.values select_list = [] for i, x in enumerate(index_list): - select_list = select_list + [(i+1,[i+1,x])] - select_list = select_list + [('no correct answer',[-1,-1])] + select_list = select_list + [(i + 1, [i + 1, x])] + select_list = select_list + [("no correct answer", [-1, -1])] return select_list + def insert_annotation_into_df(df, new_entry): df.loc[len(df)] = new_entry -def build_annotation_entry(kpi_index, par_index, df, set_answer = None, set_pargraph = None, set_page = None ): + +def build_annotation_entry(kpi_index, par_index, df, set_answer=None, set_pargraph=None, set_page=None): id_correct_paragraph = kpi_index - rank_correct_paragraph = par_index, - id_correct_answer =dropdown.value[1] + rank_correct_paragraph = (par_index,) + id_correct_answer = dropdown.value[1] rank_correct_answer = dropdown.value[0] correct_answer = correct_answer_input.value correct_paragraph = correct_paragraph_input.value @@ -195,99 +220,104 @@ def build_annotation_entry(kpi_index, par_index, df, set_answer = None, set_parg df_temp = df_annotations.head(0) if w_dropdown_paragraph.value[1] == -1: paragraph = "[" + str(correct_paragraph) + "]" - source_page = "[]" #+ str(correct_paragraph_page) + - source = 'Text' + source_page = "[]" # + str(correct_paragraph_page) + + source = "Text" paragraph_pred_rank = -1 paragraph_pred_score = -100 else: - paragraph = "[" + str(df_output.loc[id_correct_paragraph, 'paragraph']) + "]" - source_page = "[" + str(df_output.loc[id_correct_paragraph, 'page']) + "]" - source = df_output.loc[id_correct_paragraph, 'source'] + paragraph = "[" + str(df_output.loc[id_correct_paragraph, "paragraph"]) + "]" + source_page = "[" + str(df_output.loc[id_correct_paragraph, "page"]) + "]" + source = df_output.loc[id_correct_paragraph, "source"] paragraph_pred_rank = rank_correct_paragraph paragraph_pred_score = 100 - + if id_correct_answer == -1: answer = correct_answer kpi_pred_rank = -1 kpi_pred_score = -100 - + else: - answer = df_output.loc[id_correct_paragraph, 'answer'] + answer = df_output.loc[id_correct_paragraph, "answer"] kpi_pred_rank = rank_correct_answer - kpi_pred_score = df_output.loc[df_output.index == id_correct_answer, 'score'].values[0] - + kpi_pred_score = df_output.loc[df_output.index == id_correct_answer, "score"].values[0] + try: - max_num = np.max(df_out['number'].values) + max_num = np.max(df_out["number"].values) except ValueError: max_num = 1 - new_data = 
[max_num+1, - company, - df_output_check['pdf_name'].values[0], - source_page, - kpi_to_investigate, - year, - answer, - source, - paragraph, - annotator, - sector, - "", - paragraph_pred_rank, - paragraph_pred_score, - kpi_pred_rank, - kpi_pred_score] - df_series = pd.Series(new_data, index = df_temp.columns) - df_temp = df_temp.append(df_series, ignore_index = True) - df_temp = df_temp.set_index([pd.Index([np.max(df_out.index)+1])]) + new_data = [ + max_num + 1, + company, + df_output_check["pdf_name"].values[0], + source_page, + kpi_to_investigate, + year, + answer, + source, + paragraph, + annotator, + sector, + "", + paragraph_pred_rank, + paragraph_pred_score, + kpi_pred_rank, + kpi_pred_score, + ] + df_series = pd.Series(new_data, index=df_temp.columns) + df_temp = df_temp.append(df_series, ignore_index=True) + df_temp = df_temp.set_index([pd.Index([np.max(df_out.index) + 1])]) df_out = df_out.append(df_temp) df_out.tail(1) kpi_output.clear_output() val = select_kpi.value - index =kpi_of_interest.index(val) + index = kpi_of_interest.index(val) select_kpi.value = kpi_of_interest[index + 1] def update_pdf_selection(pdf_path, df_annotations, kpi_of_interest): pdf_selection = [] outputs = glob.glob(pdf_path + "/*") - outputs = [x.rsplit('/', 1)[1] for x in outputs] + outputs = [x.rsplit("/", 1)[1] for x in outputs] for output in outputs: df_output = pd.read_csv(output_path + "/" + output) - pdf_name = df_output['pdf_name'].values[0] + pdf_name = df_output["pdf_name"].values[0] df_annotations_temp = df_annotations[df_annotations.source_file == pdf_name] - kpis_contained = [x for x in df_annotations_temp['kpi_id'].values if x in kpi_of_interest] + kpis_contained = [x for x in df_annotations_temp["kpi_id"].values if x in kpi_of_interest] if set(kpis_contained) == set(kpi_of_interest): - pdf_selection = select_options + [(f'DONE - ' + output, output)] + pdf_selection = select_options + [(f"DONE - " + output, output)] else: - pdf_selection = select_options + [(f'TODO ({len(set(kpis_contained))}/{len(kpi_of_interest)}) - ' + output, output)] + pdf_selection = select_options + [ + (f"TODO ({len(set(kpis_contained))}/{len(kpi_of_interest)}) - " + output, output) + ] return pdf_selection -def safe_annotations(path, df, replace = False): - path = path + '/annotations.xlsx' - if replace: + +def safe_annotations(path, df, replace=False): + path = path + "/annotations.xlsx" + if replace: df.to_excel(path) else: with pd.ExcelWriter(path, mode="a") as f: df.to_excel(f) - + + def lets_go(): - select_pdf_handler('value') - select_kpi_handler('value') - dropdown.observe(dropdown_kpi_eventhandler, names='value') - w_dropdown_paragraph.observe(dropdown_paragraph_eventhandler, names='value') - select_pdf.observe(select_pdf_handler, names='value') - select_kpi.observe(select_kpi_handler, names='value') + select_pdf_handler("value") + select_kpi_handler("value") + dropdown.observe(dropdown_kpi_eventhandler, names="value") + w_dropdown_paragraph.observe(dropdown_paragraph_eventhandler, names="value") + select_pdf.observe(select_pdf_handler, names="value") + select_kpi.observe(select_kpi_handler, names="value") button_next_kpi.on_click(button_next_kpi_handler) - select_kpi_handler('value') + select_kpi_handler("value") display(select_kpi) - print('Possible answers for the chosen KPI:') + print("Possible answers for the chosen KPI:") display(kpi_output) - #dl = widgets.dlink((dropdown, 'value'), (w_dropdown_paragraph, 'value')) - + # dl = widgets.dlink((dropdown, 'value'), (w_dropdown_paragraph, 'value')) - 
display(dropdown,correct_answer_input) + display(dropdown, correct_answer_input) with output_paragraph: display(w_dropdown_paragraph) @@ -295,116 +325,116 @@ def lets_go(): display(output) display(button_next_kpi) display(button_output) - + + def select_the_report_to_analyze(): select_pdf = widgets.Select( options=sorted(select_options, reverse=True), # rows=10, - description='Available PDFs:', - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='100%', height='150px'), - disabled=False - ) + description="Available PDFs:", + style={"description_width": "initial"}, + layout=widgets.Layout(width="100%", height="150px"), + disabled=False, + ) select_out = widgets.Output() display(wdg.select_pdf) - display(wdg.select_out) - - - + display(wdg.select_out) + + def select_your_kpi_of_interest(): kpi_mapping_df = pd.read_csv(kpi_mapping_fpath) - kpi_mapping_df['kpi_id'] = kpi_mapping_df['kpi_id'].map('{:g}'.format) + kpi_mapping_df["kpi_id"] = kpi_mapping_df["kpi_id"].map("{:g}".format) kpi_mapping_df.head(30) select_of_interest = widgets.Textarea( - value=', '.join(kpi_mapping_df['kpi_id'].values), - placeholder='0, 1, 2, 3.1, ...', - layout = widgets.Layout(width='50%'), - style = {'description_width': 'initial'}, - description='KPI of interest:', - disabled=False + value=", ".join(kpi_mapping_df["kpi_id"].values), + placeholder="0, 1, 2, 3.1, ...", + layout=widgets.Layout(width="50%"), + style={"description_width": "initial"}, + description="KPI of interest:", + disabled=False, ) display_df = kpi_mapping_df - print('Insert your kpi of interest as a comma seperated list.\n') + print("Insert your kpi of interest as a comma seperated list.\n") display(select_of_interest) - print('\n') + print("\n") display(kpi_mapping_df.head(30)) - + + def preload(): - outputs = glob.glob(output_path + "/*") - outputs = [x.rsplit('/', 1)[1] for x in outputs] - df_output=None + outputs = [x.rsplit("/", 1)[1] for x in outputs] + df_output = None select_options = [] for output in outputs: df_output = pd.read_csv(output_path + "/" + output) - pdf_name = df_output['pdf_name'].values[0] - df_annotations_temp = df_annotations[df_annotations["source_file"]==pdf_name] - kpis_contained = [x for x in df_annotations_temp['kpi_id'].values if x in kpi_of_interest] + pdf_name = df_output["pdf_name"].values[0] + df_annotations_temp = df_annotations[df_annotations["source_file"] == pdf_name] + kpis_contained = [x for x in df_annotations_temp["kpi_id"].values if x in kpi_of_interest] if set(kpis_contained) == set(kpi_of_interest): - select_options = select_options + [(f'DONE - ' + output, output)] + select_options = select_options + [(f"DONE - " + output, output)] else: - select_options = select_options + [(f'TODO ({len(set(kpis_contained))}/{len(kpi_of_interest)}) - ' + output, output)] + select_options = select_options + [ + (f"TODO ({len(set(kpis_contained))}/{len(kpi_of_interest)}) - " + output, output) + ] - options_kpi = [(x,x) for x in kpi_of_interest] + options_kpi = [(x, x) for x in kpi_of_interest] df_output = pd.read_csv(output_path + "/" + output_file) - df_output = df_output.drop(columns = 'Unnamed: 0') + df_output = df_output.drop(columns="Unnamed: 0") df_output_check = None - index_list = df_output_check.index.values + index_list = df_output_check.index.values select_list = [] for i, x in enumerate(index_list): - select_list = select_list + [(i+1,[i+1,x])] - select_list = select_list + [('no correct answer',[-1,-1])] - - + select_list = select_list + [(i + 1, [i + 1, x])] + select_list = 
select_list + [("no correct answer", [-1, -1])] + select_kpi = widgets.Dropdown( options=sorted(options_kpi), # rows=10, - value = min(kpi_of_interest), - description='Current KPI:', - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='100%') + value=min(kpi_of_interest), + description="Current KPI:", + style={"description_width": "initial"}, + layout=widgets.Layout(width="100%"), ) -dropdown = widgets.Dropdown(description="Choose rank of correct kpi:", - options=select_list, - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='50%'), - value = None) +dropdown = widgets.Dropdown( + description="Choose rank of correct kpi:", + options=select_list, + style={"description_width": "initial"}, + layout=widgets.Layout(width="50%"), + value=None, +) -w_dropdown_paragraph = widgets.Dropdown(description="Choose rank of correct pagraph:", - options=select_list, - style = {'description_width': 'initial'}, - layout = widgets.Layout(width='50%'), - value = None) +w_dropdown_paragraph = widgets.Dropdown( + description="Choose rank of correct pagraph:", + options=select_list, + style={"description_width": "initial"}, + layout=widgets.Layout(width="50%"), + value=None, +) button_next_kpi = widgets.Button( value=False, - description='Save', + description="Save", disabled=False, - button_style='', # 'success', 'info', 'warning', 'danger' or '' - tooltip='Save current annotation and switch to next KPI', - icon='check' # (FontAwesome names without the `fa-` prefix) + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Save current annotation and switch to next KPI", + icon="check", # (FontAwesome names without the `fa-` prefix) ) correct_answer_input = widgets.Textarea( - value='', - placeholder='Enter your correction.', - description='Correction:', - layout = layout_hidden, - disabled=False + value="", placeholder="Enter your correction.", description="Correction:", layout=layout_hidden, disabled=False ) correct_paragraph_input = widgets.Textarea( - value='', - placeholder='Enter your correction.', - layout = widgets.Layout(width='60%'), - description='Correction:', - disabled=False + value="", + placeholder="Enter your correction.", + layout=widgets.Layout(width="60%"), + description="Correction:", + disabled=False, ) - diff --git a/inception_transformer/inception_transformer.py b/inception_transformer/inception_transformer.py index 436e05e..2785ccf 100644 --- a/inception_transformer/inception_transformer.py +++ b/inception_transformer/inception_transformer.py @@ -15,10 +15,11 @@ import os import shutil from cassis import * + sys.dont_write_bytecode = True # Load settings data -with open('settings.json', "r", encoding="utf-8", errors='ignore') as settings_file: +with open("settings.json", "r", encoding="utf-8", errors="ignore") as settings_file: settings = json.load(settings_file) settings_file.close() @@ -34,7 +35,7 @@ def select_covering(cas, type_name, covered_annotation, overlap): overlap: Boolean if annotation are allowed to overlap multiple annotations Returns: A list of covering annotations - """ + """ c_begin = covered_annotation.begin c_end = covered_annotation.end @@ -43,9 +44,11 @@ def select_covering(cas, type_name, covered_annotation, overlap): if overlap: annotations_list = [] for annotation in cas._get_feature_structures(type_name): - if (annotation.begin <= c_begin <= annotation.end) \ - or (annotation.begin <= c_end <= annotation.end) \ - or (c_end >= annotation.end and c_begin <= annotation.begin): + if ( + 
(annotation.begin <= c_begin <= annotation.end) + or (annotation.begin <= c_end <= annotation.end) + or (c_end >= annotation.end and c_begin <= annotation.begin) + ): annotations_list.append(annotation) return annotations_list else: @@ -61,53 +64,66 @@ def specify_logger(logger): :return: specified logger object """ handler = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y%m%d %H:%M:%S') - logger_dir = settings['LogFilePath'] + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s", datefmt="%Y%m%d %H:%M:%S") + logger_dir = settings["LogFilePath"] initial_time = datetime.now().strftime("%Y%m%d%H%M%S") - fh = logging.FileHandler(logger_dir + "/" + settings['ProjectName'] + "_" + initial_time + ".log") + fh = logging.FileHandler(logger_dir + "/" + settings["ProjectName"] + "_" + initial_time + ".log") fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) logger.addHandler(fh) logger.addHandler(handler) - print("Logger information are stored in " + logger_dir + "/" + settings['ProjectName'] - + "_" + initial_time + ".log") + print( + "Logger information are stored in " + logger_dir + "/" + settings["ProjectName"] + "_" + initial_time + ".log" + ) return logger def get_uima_cas_xmi_output(logger): - """ This file extracts from the UIMA CAS XMI type via the dkpro-cassis package + """This file extracts from the UIMA CAS XMI type via the dkpro-cassis package the annotated answers and saves it into an excel sheet. :param logger: logging class element :return: None """ tic = time.time() - xmi_name = glob.glob(settings['InputPath'] + "/*.xmi")[0] - xml_name = glob.glob(settings['InputPath'] + "/*.xml")[0] + xmi_name = glob.glob(settings["InputPath"] + "/*.xmi")[0] + xml_name = glob.glob(settings["InputPath"] + "/*.xml")[0] - with open(xml_name, 'rb') as f: + with open(xml_name, "rb") as f: typesystem = load_typesystem(f) f.close() - with open(xmi_name, 'rb') as f: + with open(xmi_name, "rb") as f: cas = load_cas_from_xmi(f, typesystem=typesystem) f.close() logger.info("UIMA file loaded") - df_answers = pd.DataFrame(columns=['KPI', 'ANSWER', 'TYPE', 'ANSWER_X', 'ANSWER_Y', - 'PAGE', 'PAGE_WIDTH', 'PAGE_HEIGHT', 'PAGE_ORIENTATION', - 'COV_SENTENCES']) - - for page in cas.select('org.dkpro.core.api.pdf.type.PdfPage'): - for answer in cas.select_covered('webanno.custom.KPIAnswer', page): + df_answers = pd.DataFrame( + columns=[ + "KPI", + "ANSWER", + "TYPE", + "ANSWER_X", + "ANSWER_Y", + "PAGE", + "PAGE_WIDTH", + "PAGE_HEIGHT", + "PAGE_ORIENTATION", + "COV_SENTENCES", + ] + ) + + for page in cas.select("org.dkpro.core.api.pdf.type.PdfPage"): + for answer in cas.select_covered("webanno.custom.KPIAnswer", page): logger.info("---------------------------------------------------") - logger.info(f"Extracting information of kpi answer \"{answer.get_covered_text()}\".") + logger.info(f'Extracting information of kpi answer "{answer.get_covered_text()}".') - sentence_list_temp = select_covering(cas, 'de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence', - answer, True) + sentence_list_temp = select_covering( + cas, "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence", answer, True + ) - chunk_list_temp = select_covering(cas, 'org.dkpro.core.api.pdf.type.PdfChunk', answer, True) + chunk_list_temp = select_covering(cas, "org.dkpro.core.api.pdf.type.PdfChunk", answer, True) # Find the first chunk as chunks might be mixed in order min_idx = 0 for idx in range(len(chunk_list_temp)): @@ -130,20 +146,33 @@ def 
get_uima_cas_xmi_output(logger): # Collect all information we want to store for kpi in answer.KPI.elements: - df_answers.loc[len(df_answers)] = [kpi, answer.get_covered_text(), 'KPIAnswer', - answer_x, pdf_chunk.y, - int(page.pageNumber), int(page.width), int(page.height), - int(pdf_chunk.d), - " ".join([x.get_covered_text() for x in sentence_list_temp]) - ] + df_answers.loc[len(df_answers)] = [ + kpi, + answer.get_covered_text(), + "KPIAnswer", + answer_x, + pdf_chunk.y, + int(page.pageNumber), + int(page.width), + int(page.height), + int(pdf_chunk.d), + " ".join([x.get_covered_text() for x in sentence_list_temp]), + ] if len(df_answers) > 0: - df_answers['PDF_NAME'] = \ - cas.select('de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData')[0]['documentTitle'] + df_answers["PDF_NAME"] = cas.select("de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData")[0][ + "documentTitle" + ] logger.info("---------------------------------------------------") logger.info("All information have been extracted.") - df_answers.to_excel(settings['OutputPath'] + '/' + settings['InputPath'].split("/")[-1] - + "_uima_extraction_" + datetime.now().strftime("%Y%m%d%H%M%S") + '.xlsx') + df_answers.to_excel( + settings["OutputPath"] + + "/" + + settings["InputPath"].split("/")[-1] + + "_uima_extraction_" + + datetime.now().strftime("%Y%m%d%H%M%S") + + ".xlsx" + ) toc = time.time() logger.info(f"Answers have been saved to excel and it took {toc - tic} seconds.") @@ -154,15 +183,15 @@ def main(logger=logging.getLogger()): if not logger.handlers: logger = specify_logger(logger) - logger.info('---------------- SETTINGS DATA ----------------') + logger.info("---------------- SETTINGS DATA ----------------") for key in settings: logger.info(str(key) + ": " + str(settings[key])) - logger.info('------------------------------------------------') + logger.info("------------------------------------------------") try: input_path = settings["InputPath"] output_path = settings["OutputPath"] - for subfolder in [x.split("\\")[-1] for x in glob.glob(input_path + "/*") if x[-4:] != 'json']: + for subfolder in [x.split("\\")[-1] for x in glob.glob(input_path + "/*") if x[-4:] != "json"]: settings["InputPath"] = input_path + "/" + subfolder output_folder = output_path + "/" + subfolder if not os.path.isdir(output_folder): @@ -177,11 +206,13 @@ def main(logger=logging.getLogger()): pass settings["OutputPath"] = output_folder - logger.info(f'Start of UIMA file transformation for files in subfolder \"' + subfolder + '\".') + logger.info(f'Start of UIMA file transformation for files in subfolder "' + subfolder + '".') get_uima_cas_xmi_output(logger) - logger.info(f'UIMA file transformation for files in subfolder \"' + subfolder + '\" successfully made.' - ' Input files ' - 'and output file are stored in the folder ' + output_folder) + logger.info( + f'UIMA file transformation for files in subfolder "' + subfolder + '" successfully made.' + " Input files " + "and output file are stored in the folder " + output_folder + ) for file in glob.glob(settings["InputPath"] + "/*"): shutil.copyfile(file, file.replace(settings["InputPath"], settings["OutputPath"])) @@ -195,5 +226,5 @@ def main(logger=logging.getLogger()): logger.error(traceback.format_exc()) -if __name__ == '__main__': +if __name__ == "__main__": main()
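
The transformer patched above does all of its UIMA access through dkpro-cassis: a type system XML is loaded first, the XMI export is parsed against it, and typed annotations are then iterated page by page. A minimal sketch of that access pattern, using only the calls already present in inception_transformer.py; the file names TypeSystem.xml and document.xmi are placeholders, not the paths resolved from settings.json:

    from cassis import load_typesystem, load_cas_from_xmi

    # Placeholder file names; the real files come from settings["InputPath"] in inception_transformer.py.
    with open("TypeSystem.xml", "rb") as f:
        typesystem = load_typesystem(f)
    with open("document.xmi", "rb") as f:
        cas = load_cas_from_xmi(f, typesystem=typesystem)

    # Walk every annotated KPI answer on every PDF page, mirroring get_uima_cas_xmi_output().
    for page in cas.select("org.dkpro.core.api.pdf.type.PdfPage"):
        for answer in cas.select_covered("webanno.custom.KPIAnswer", page):
            print(int(page.pageNumber), answer.get_covered_text())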