-
Notifications
You must be signed in to change notification settings - Fork 77
/
fuzzy_dedup.py
1716 lines (1572 loc) · 64.9 KB
/
fuzzy_dedup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import math
import os
import time
import warnings
from itertools import pairwise
from typing import List, Optional, Tuple, Union
import cudf
import cugraph.dask as dcg
import cugraph.dask.comms.comms as Comms
import cupy as cp
import dask_cudf
import numpy as np
import pandas as pd
import pyarrow as pa
from cugraph import MultiGraph
from dask import dataframe as dd
from dask.utils import M
from tqdm import tqdm
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.meta import Sequential
from nemo_curator.utils.distributed_utils import (
get_current_client,
get_num_workers,
performance_report_if_with_ts_suffix,
)
from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str
from nemo_curator.utils.fuzzy_dedup_utils.io_utils import (
aggregated_anchor_docs_with_bk_read,
get_restart_offsets,
update_restart_offsets,
)
from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import (
extract_partitioning_index,
filter_text_rows_by_bucket_batch,
merge_left_to_shuffled_right,
)
from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import (
build_partition,
get_agg_text_bytes_df,
)
from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import (
text_bytes_aware_shuffle,
write_partitioned_file,
)
class MinHash:
    """
    Computes minhash signatures of a document corpus.

    Each document's text is turned into a fixed-length signature of
    ``num_hashes`` values, one per seed, using cuDF's string minhash
    routines over character n-grams.
    """

    def __init__(
        self,
        seed: int = 42,
        num_hashes: int = 260,
        char_ngrams: int = 5,
        use_64bit_hash: bool = False,
        logger: Union[logging.LoggerAdapter, str] = "./",
        id_field: str = "id",
        text_field: str = "text",
        profile_dir: str = None,
        cache_dir: str = None,
    ):
        """
        Parameters
        ----------
        seed: Seed used to derive the per-permutation seeds.
        num_hashes: Length of the minhash signature (number of permutations).
        char_ngrams: Width of the character window used while hashing.
        use_64bit_hash: Select the 64 bit hash routine instead of the 32 bit one.
        logger: An existing logger, or a directory in which "Minhash.log"
            is created.
        id_field: Column in the Dataset denoting document ID.
        text_field: Column in the Dataset denoting document content.
        profile_dir: str, Default None
            If specified, directory to write the dask profile to. A warning
            is emitted when this is set without ``cache_dir``.
        cache_dir: str, Default None
            If specified, id/minhash pairs are computed and written there.
        """
        self.id_field = id_field
        self.text_field = text_field
        self.cache_dir = cache_dir
        self.profile_dir = profile_dir

        self.num_hashes = num_hashes
        self.char_ngram = char_ngrams
        self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)

        if use_64bit_hash:
            self.minhash_method = self.minhash64
        else:
            self.minhash_method = self.minhash32

        # Profiling wraps the parquet write, which only happens with a cache_dir.
        if profile_dir is not None and cache_dir is None:
            warnings.warn(
                "cache_dir for intermediate outputs is required to generate profiles"
            )

        if not isinstance(logger, str):
            self._logger = logger
        else:
            self._logger = create_logger(
                rank=0,
                log_file=os.path.join(logger, "Minhash.log"),
                name="Minhash",
            )

    def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray:
        """
        Derive ``n_seeds`` integer seeds in [0, 1e6) from a single base seed.

        The same ``seed`` always yields the same array.
        """
        rng = np.random.RandomState(seed)
        return rng.randint(0, 1e6, size=n_seeds)

    def minhash32(
        self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
    ) -> cudf.Series:
        """
        Compute 32bit minhashes based on the MurmurHash3 algorithm.

        Raises TypeError when ``ser`` is not a cudf.Series.
        """
        if isinstance(ser, cudf.Series):
            seed_series = cudf.Series(seeds, dtype="uint32")
            return ser.str.minhash(seeds=seed_series, width=char_ngram)
        raise TypeError("Expected data of type cudf.Series")

    def minhash64(
        self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
    ) -> cudf.Series:
        """
        Compute 64bit minhashes based on the MurmurHash3 algorithm.

        Raises TypeError when ``ser`` is not a cudf.Series.
        """
        if isinstance(ser, cudf.Series):
            seed_series = cudf.Series(seeds, dtype="uint64")
            return ser.str.minhash64(seeds=seed_series, width=char_ngram)
        raise TypeError("Expected data of type cudf.Series")

    def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
        """
        Computes the MinHash Signatures for a given dataset.

        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to compute MinHashes for.

        Returns
        -------
        DocumentDataset containing the ID column and a "_minhash_signature"
        column. Without a cache_dir the lazy result is returned directly;
        with one, the result is written to "<cache_dir>/_minhashes.parquet"
        and the returned dataset is read back from that location.
        """
        signatures = dataset.df[[self.id_field]]
        text_column = dataset.df[self.text_field]
        signatures["_minhash_signature"] = text_column.map_partitions(
            self.minhash_method,
            seeds=self.seeds,
            char_ngram=self.char_ngram,
        )

        if self.cache_dir is None:
            return DocumentDataset(signatures)

        start_time = time.time()
        self._logger.info("Starting execution for Minhashes")

        write_path = os.path.join(self.cache_dir, "_minhashes.parquet")
        if os.path.exists(write_path):
            warnings.warn(
                f"Output path {write_path} already exists and will be overwritten"
            )

        with performance_report_if_with_ts_suffix(self.profile_dir, "minhash-profile"):
            signatures.to_parquet(write_path, write_index=False, overwrite=True)

        self._logger.info(
            f"Time taken for Minhash signature computation = {time.time() - start_time}s and output written at {write_path}"
        )

        cached = dask_cudf.read_parquet(
            write_path, blocksize="2GB", aggregate_files=True
        )
        return DocumentDataset(cached)
class LSH:
    """
    Performs LSH on a MinhashSignatures.

    The minhash signature of every document is split into ``num_buckets``
    bands; documents whose band hashes collide share a bucket. Only buckets
    that contain more than one document are written out.
    """

    def __init__(
        self,
        cache_dir: str,
        num_hashes: int,
        num_buckets: int,
        buckets_per_shuffle: int = 1,
        logger: Union[logging.LoggerAdapter, str] = "./",
        id_fields: Union[str, list] = "id",
        minhash_field: str = "_minhash_signature",
        profile_dir: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        cache_dir: str
            Needs to be specified, will compute & write duplicate id, bucket pairs to cache directory.
        num_hashes: Length of minhash signature
        num_buckets: Number of bands/buckets to create from the minhash signature.
            Hashes per band = num_hashes // num_buckets (any remainder hashes
            at the end of the signature are not used).
        buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
            Larger values process larger batches by processing multiple bands
            but might lead to memory pressures and related errors.
        logger: Existing logger to log to, or a path to a log directory.
        id_fields: Column(s) in the Dataset denoting document ID. A single
            string is wrapped into a one-element list.
        minhash_field: Column in the Dataset denoting minhash signature.
        profile_dir: str, Default None
            If specified directory to write dask profile

        Raises
        ------
        ValueError: if ``cache_dir`` is None.
        """
        self.num_hashes = num_hashes
        self.num_buckets = num_buckets
        self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
        self.minhash_field = minhash_field
        self.buckets_per_shuffle = buckets_per_shuffle
        self.bucket_ranges = self._generate_bucket_ranges(
            self.num_buckets, self.num_hashes
        )

        if cache_dir is None:
            raise ValueError(
                "cache_dir for intermediate outputs is required for this stage"
            )
        self.cache_dir = cache_dir
        self.profile_dir = profile_dir

        if isinstance(logger, str):
            self._logger = create_logger(
                rank=0,
                log_file=os.path.join(logger, "LSH.log"),
                name="LSH",
            )
        else:
            self._logger = logger

    def _generate_bucket_ranges(
        self, num_buckets: int, num_hashes: int
    ) -> List[List[int]]:
        """
        Generates a list of indices for the minhash ranges given num_buckets &
        num_hashes.
        eg: num_buckets=3, num_hashes=6
        [[0, 1], [2, 3], [4, 5]]
        """
        minhashes_per_bucket = num_hashes // num_buckets
        return [
            list(
                range(
                    bucket * minhashes_per_bucket, (bucket + 1) * minhashes_per_bucket
                )
            )
            for bucket in range(num_buckets)
        ]

    def minhash_to_buckets(
        self,
        df: cudf.DataFrame,
        bucket_ranges: List[List[int]],
    ) -> cudf.DataFrame:
        """
        Build one "_bucket_{i}" column per band.

        For band ``i`` the signature entries at ``bucket_ranges[i]`` are taken
        from the minhash list column, hashed with md5 and prefixed with
        "b{i}_" so that equal hashes from different bands never collide.
        Returns the id columns plus the new bucket columns.
        """
        df2 = df[self.id_fields]
        for i, h in enumerate(bucket_ranges):
            # Same index list repeated for every row, as required by list.take
            indices = cudf.Series([h]).repeat(len(df2))
            df2[f"_bucket_{i}"] = f"b{i}_" + df[self.minhash_field].list.take(
                indices
            ).hash_values(method="md5")
        return df2

    def bucket_id_to_int(
        self,
        bucket_ddf: dask_cudf.DataFrame,
        bucket_col_name: str = "bucket_id",
        start_id: int = 0,
    ) -> Tuple[dask_cudf.DataFrame, int]:
        """
        Maps bucket ids to a contiguous integer range starting from start_id.

        Returns the frame with the original bucket column replaced by a
        uint64 "_bucket_id" column, and the last integer id that was assigned
        (``start_id - 1`` when there are no buckets).

        Duplicates are dropped per partition only, so every distinct bucket id
        is expected to live in a single partition (``lsh`` shuffles on the
        bucket column before calling this).
        """
        unique_bucket_df = (
            bucket_ddf[[bucket_col_name]]
            .map_partitions(lambda x: x.drop_duplicates(ignore_index=True))
            .persist()
        )
        end_bucket_id = len(unique_bucket_df) - 1 + start_id
        # Running count 1..n over the unique buckets, shifted to start_id..end
        unique_bucket_df["bucket_int_id"] = np.uint64(1)
        unique_bucket_df["bucket_int_id"] = unique_bucket_df["bucket_int_id"].cumsum()
        unique_bucket_df["bucket_int_id"] = (
            unique_bucket_df["bucket_int_id"] - 1 + start_id
        )

        bucket_ddf = bucket_ddf.merge(unique_bucket_df, on=[bucket_col_name])
        bucket_ddf = bucket_ddf.drop(columns=[bucket_col_name])
        bucket_ddf = bucket_ddf.rename(columns={"bucket_int_id": "_bucket_id"})
        bucket_ddf["_bucket_id"] = bucket_ddf["_bucket_id"].astype(np.uint64)
        return (bucket_ddf, end_bucket_id)

    def _minhash_to_bucket_meta(self, df: dask_cudf.DataFrame) -> cudf.DataFrame:
        """
        Build the dask ``meta`` frame for ``minhash_to_buckets`` by running it
        on a small dummy frame with a placeholder signature column.
        """
        meta = df._meta_nonempty[self.id_fields]
        meta[self.minhash_field] = [np.ones(self.num_hashes)] * len(meta)
        return self.minhash_to_buckets(meta, self.bucket_ranges)

    def lsh(
        self,
        write_path: str,
        df: dask_cudf.DataFrame,
    ) -> None:
        """
        Computes buckets and writes them as parquet files to the write_path.

        Bands are processed in batches of ``buckets_per_shuffle``. Within a
        batch, rows are shuffled on the bucket hash, buckets with a single
        document are dropped, and the remaining bucket hashes are mapped to
        integer ids that continue from the previous batch. The first batch
        that actually produces rows overwrites ``write_path``; later batches
        append to it.
        """
        meta = self._minhash_to_bucket_meta(df)
        df = df.map_partitions(
            self.minhash_to_buckets,
            bucket_ranges=self.bucket_ranges,
            meta=meta,
        )

        bucket_start_id = 0
        # Tracks whether this run has already (re)created write_path. Keying
        # the overwrite on the batch index instead would append to stale
        # output whenever the first batch turns out to be empty.
        wrote_buckets = False
        for i in range(0, self.num_buckets, self.buckets_per_shuffle):
            batch_end = min(self.num_buckets, i + self.buckets_per_shuffle)
            value_vars = [f"_bucket_{bucket_idx}" for bucket_idx in range(i, batch_end)]
            df2 = df.melt(
                id_vars=self.id_fields, value_name="_bucket_id", value_vars=value_vars
            )[self.id_fields + ["_bucket_id"]]

            # Co-locate equal bucket hashes, then keep only buckets that
            # contain more than one row.
            df2 = df2.shuffle(
                on=["_bucket_id"],
                ignore_index=True,
                npartitions=max(1, 2 ** math.floor(math.log2(df2.npartitions))),
            ).map_partitions(lambda x: x[x["_bucket_id"].duplicated(keep=False)])
            df2 = df2.reset_index(drop=True)

            df2, end_id = self.bucket_id_to_int(
                df2, bucket_col_name="_bucket_id", start_id=bucket_start_id
            )
            # If bucketing return empty dataframe
            if end_id < bucket_start_id:
                self._logger.info(
                    f"No duplicate documents found for buckets: {value_vars}"
                )
                continue
            bucket_start_id = end_id + 1

            # Workaround for dtype mismatches with empty partitions
            dtypes = df2.dtypes.to_dict()
            df2 = df2.map_partitions(lambda x: x.astype(dtypes))

            if not wrote_buckets:
                if os.path.exists(write_path):
                    warnings.warn(
                        f"Output path {write_path} already exists and will be overwritten"
                    )
                df2.to_parquet(write_path, write_index=False, overwrite=True)
                wrote_buckets = True
            else:
                df2.to_parquet(write_path, write_index=False, append=True)

            self._logger.info(f"Wrote data for buckets: {value_vars}")

    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Run LSH over a dataset of minhash signatures.

        Buckets are written to "<cache_dir>/_buckets.parquet" and the returned
        DocumentDataset is read back from that path.
        """
        df = dataset.df

        write_path = os.path.join(self.cache_dir, "_buckets.parquet")
        t0 = time.time()
        with performance_report_if_with_ts_suffix(self.profile_dir, "lsh-profile"):
            self.lsh(write_path=write_path, df=df)
        self._logger.info(
            f"Time taken for LSH = {time.time() - t0}s and output written at {write_path}"
        )

        buckets_df = dask_cudf.read_parquet(write_path, split_row_groups=False)
        return DocumentDataset(buckets_df)
class FuzzyDuplicates:
    """
    End-to-end fuzzy (near) duplicate detection pipeline.

    Chains MinHash and LSH, then either a false positive check
    (map buckets -> shuffle -> Jaccard similarity) or a direct conversion of
    buckets to graph edges, and finally connected components to label groups
    of near duplicate documents.
    """

    def __init__(
        self,
        config: FuzzyDuplicatesConfig,
        logger: Union[logging.LoggerAdapter, str] = "./",
    ):
        """
        Parameters
        ----------
        config: FuzzyDuplicatesConfig,
            Config options for finding FuzzyDuplicates
        logger: Existing logger to log to, or a path to a log directory
            (in which case "FuzzyDuplicates.log" is created there).
        """
        if isinstance(logger, str):
            self._logger = create_logger(
                rank=0,
                log_file=os.path.join(logger, "FuzzyDuplicates.log"),
                name="FuzzyDuplicates",
            )
        else:
            self._logger = logger

        self.config = config
        self.minhash = MinHash(
            seed=self.config.seed,
            num_hashes=self.config.num_hashes,
            char_ngrams=self.config.char_ngrams,
            use_64bit_hash=self.config.use_64_bit_hash,
            logger=self._logger,
            id_field=self.config.id_field,
            text_field=self.config.text_field,
            profile_dir=self.config.profile_dir,
            cache_dir=self.config.cache_dir,
        )
        self.lsh = LSH(
            cache_dir=self.config.cache_dir,
            num_hashes=self.config.num_hashes,
            num_buckets=self.config.num_buckets,
            buckets_per_shuffle=self.config.buckets_per_shuffle,
            logger=self._logger,
            id_fields=[self.config.id_field],
            profile_dir=self.config.profile_dir,
        )

        if self.config.false_positive_check:
            self.map_buckets = _MapBuckets(
                id_fields=[self.config.id_field],
                text_field=self.config.text_field,
                logger=self._logger,
                num_anchors=self.config.num_anchors,
            )
            self.jaccard_shuffle = _Shuffle(
                id_fields=[self.config.id_field],
                text_field=self.config.text_field,
                logger=self._logger,
                profile_dir=self.config.profile_dir,
            )
            self.jaccard_compute = JaccardSimilarity(
                id_field=self.config.id_field,
                text_field=self.config.text_field,
                ngram_width=self.config.char_ngrams,
                anchor_id_fields=[
                    f"anchor_{i}_{self.config.id_field}"
                    for i in range(self.config.num_anchors)
                ],
            )
        else:
            self.buckets_to_edges = BucketsToEdges(
                cache_dir=self.config.cache_dir,
                id_fields=self.config.id_field,
                logger=self._logger,
                profile_dir=self.config.profile_dir,
            )

        # Connected components consumes whichever pair file the chosen path
        # produces: Jaccard results with the false positive check, raw edges
        # without it.
        jaccard_pairs_fname = (
            "jaccard_similarity_results.parquet"
            if self.config.false_positive_check
            else "_edges.parquet"
        )
        self.connected_components = ConnectedComponents(
            cache_dir=self.config.cache_dir,
            jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
            id_column=self.config.id_field,
            convert_str_ids=False,
            jaccard_threshold=self.config.jaccard_threshold,
            logger=self._logger,
            profile_dir=self.config.profile_dir,
        )

    def __call__(self, dataset: DocumentDataset):
        """
        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to compute FuzzyDuplicates. Must contain a text and unique id field.

        Returns
        -------
        DocumentDataset containing IDs of all documents and the corresponding duplicate group
        they belong to. Documents in the same group are near duplicates.
        The result is read from "<cache_dir>/connected_components.parquet".
        """
        # Minhash + LSH
        stage_num = 1
        print(f"Stage{stage_num}: Starting Minhash + LSH computation")
        minhashLSH = Sequential([self.minhash, self.lsh])
        buckets_df = minhashLSH(dataset)
        print(f"Stage{stage_num}: Minhash + LSH complete!")
        stage_num += 1

        if self.config.false_positive_check:
            # Map buckets to lower cardinality distribution
            print(f"Stage{stage_num} (False Positive Check): Starting Map_Buckets")
            t0 = time.time()
            mapped_buckets_w_anchors_path = os.path.join(
                self.config.cache_dir, "anchor_docs_with_bk.parquet"
            )
            with performance_report_if_with_ts_suffix(
                self.config.profile_dir,
                "map_buckets",
            ):
                ddf_mapped_buckets_w_anchors = (
                    self.map_buckets.map_buckets_with_anchors(
                        documents_df=dataset.df, buckets_df=buckets_df.df
                    )
                )
                ddf_mapped_buckets_w_anchors.to_parquet(
                    mapped_buckets_w_anchors_path, write_index=False
                )
            self._logger.info(
                f"Time taken for Map_buckets : {time.time() - t0}s and output written at {mapped_buckets_w_anchors_path}"
            )
            print(f"Stage{stage_num} (False Positive Check): Map_Buckets Complete!")
            stage_num += 1

            # Shuffle documents based on mapped buckets
            print(f"Stage{stage_num} (False Positive Check): Shuffle docs")
            shuffled_docs_path = os.path.join(
                self.config.cache_dir, "shuffled_docs.parquet"
            )
            self.jaccard_shuffle.shuffle_docs_on_buckets(
                documents_df=dataset.df,
                bucket_w_anchors_path=mapped_buckets_w_anchors_path,
                output_shuffled_docs_path=shuffled_docs_path,
                bucket_mapping_df_blocksize=self.config.bucket_mapping_blocksize,
                parts_per_worker=self.config.parts_per_worker,
                bucket_parts_per_worker=self.config.bucket_parts_per_worker,
            )
            print(f"Stage{stage_num} (False Positive Check): Shuffle docs complete!")
            stage_num += 1

            # Jaccard comparison within buckets
            print(
                f"Stage{stage_num} (False Positive Check): Jaccard Similarity in Buckets"
            )
            jaccard_pairs_path = os.path.join(
                self.config.cache_dir, "jaccard_similarity_results.parquet"
            )
            t0 = time.time()
            with performance_report_if_with_ts_suffix(
                self.config.profile_dir,
                "jaccard-similarity",
            ):
                jaccard_pairs_df = self.jaccard_compute.jaccard_compute(
                    shuffled_docs_path=shuffled_docs_path
                )
                jaccard_pairs_df.to_parquet(
                    jaccard_pairs_path,
                    write_index=False,
                    write_metadata_file=False,
                )
            self._logger.info(
                f"Time taken for Jaccard Similarity = {time.time()-t0}s and output written at {jaccard_pairs_path}"
            )
            print(
                f"Stage{stage_num} (False Positive Check): Jaccard Similarity in Buckets Complete!"
            )
            stage_num += 1
        else:
            # Without the false positive check, documents sharing a bucket
            # are treated as duplicates: convert buckets straight to edges.
            print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist")
            self.buckets_to_edges(buckets_df)
            print(f"Stage{stage_num}: LSH Buckets to Graph edgelist complete!")
            stage_num += 1

        # Connected components across buckets
        print(f"Stage{stage_num}: Connected Components across buckets")
        cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet")
        self.connected_components.cc_workflow(cc_path)
        print(f"Stage{stage_num}: Connected Components across buckets complete!")
        return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
class BucketsToEdges:
"""
Maps buckets generated from LSH into an edgelist that
can be processed further by Connected Components to find duplicate
documents
"""
def __init__(
self,
cache_dir: str = None,
id_fields: Union[list, str] = "id",
str_id_name: str = "id",
bucket_field: str = "_bucket_id",
logger: Union[logging.LoggerAdapter, str] = "./",
profile_dir: Optional[str] = None,
):
"""
Parameters
----------
cache_dir: str or None
If specified, will compute & write the edgelist to a file
id_fields: list or str
id fields of documents in buckets_df
str_id_name: str
Ignored if there is a single id field. Multiple id fields
will be combined into a single id field with the given name.
bucket_field: str
Column denoting bucket ID
num_buckets: Number of bands/buckets to create from the minhash signature.
Hashes_per_signature = num_hashes / num_buckets
"""
self.cache_dir = cache_dir
self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
self.str_id_name = str_id_name if len(self.id_fields) > 1 else self.id_fields[0]
self.output_ids = [f"{self.str_id_name}_x", f"{self.str_id_name}_y"]
self.bucket_field = bucket_field
self.profile_dir = profile_dir
if isinstance(logger, str):
self._logger = create_logger(
rank=0,
log_file=os.path.join(logger, "Buckets_to_Edges.log"),
name="Buckets_to_Edges",
)
else:
self._logger = logger
@staticmethod
def _combine_multiple_ids(
input_df: cudf.DataFrame, input_id_fields: list, output_id_field: str
) -> cudf.DataFrame:
if output_id_field in input_df.columns:
raise ValueError(
f"Input df already contains column named: {output_id_field}"
)
output_df = input_df.copy()[input_df.columns.difference(input_id_fields)]
output_df[output_id_field] = input_df[input_id_fields[0]].astype(str)
for input_field in input_id_fields[1:]:
output_df[output_id_field] = output_df[output_id_field] = (
input_df[input_id_fields[0]].astype(str)
+ "-"
+ input_df[input_field].astype(str)
)
return output_df
def buckets_to_edges(
self,
buckets_df: cudf.DataFrame,
) -> cudf.DataFrame:
grouped_buckets = (
buckets_df.groupby(self.bucket_field)[self.str_id_name]
.agg(list)
.list.sort_values()
)
bucket_docs = grouped_buckets.to_arrow().to_pylist()
edges = []
# Create pairs of all documents within a bucket since they are near duplicates
# Effectively create a edge list of all near duplicate documents
for bucket_doc in bucket_docs:
edges.extend(pairwise(bucket_doc))
edges = pd.DataFrame(edges, columns=self.output_ids)
edges = pa.Table.from_pandas(edges)
result_df = cudf.DataFrame.from_arrow(edges)
del edges
result_df = result_df.drop_duplicates(self.output_ids).reset_index(drop=True)
result_df["jaccard"] = np.float32(1.0)
return result_df
def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
buckets_df = dataset.df
if len(self.id_fields) > 1:
buckets_df = buckets_df.map_partitions(
BucketsToEdges._combine_multiple_ids,
input_id_fields=self.id_fields,
output_id_field=self.str_id_name,
)
meta = [(output_id, str) for output_id in self.output_ids]
meta.append(("jaccard", np.float32))
edges_df = buckets_df.map_partitions(self.buckets_to_edges, meta=meta)
if self.cache_dir is None:
return DocumentDataset(edges_df)
write_path = os.path.join(self.cache_dir, "_edges.parquet")
if os.path.exists(write_path):
warnings.warn(
f"Output path {write_path} already exists and will be overwritten"
)
t0 = time.time()
with performance_report_if_with_ts_suffix(
self.profile_dir,
"bucket-to-edges",
):
edges_df.to_parquet(write_path, write_index=False, overwrite=True)
self._logger.info(
f"Time taken for Converted Buckets To Edgelist = {time.time() - t0}s and output written at {write_path}"
)
return DocumentDataset(
dask_cudf.read_parquet(write_path, split_row_groups=False)
)
class _MapBuckets:
    """
    Maps LSH buckets to logical output partitions and attaches anchor documents.

    Combines buckets generated from LSH (typically high cardinality)
    to more coarse lower cardinality bucket groups by mapping multiple buckets
    to a logical partition using document length information and a modified bin
    packing algorithm.
    Only needed if running the False Positive check to remove false positives.
    """

    def __init__(
        self,
        id_fields: Union[list, str] = "id",
        text_field: str = "text",
        bucket_field: str = "_bucket_id",
        num_anchors: int = 2,
        logger: Union[logging.LoggerAdapter, str] = "./",
    ):
        """
        Parameters
        ----------
        id_fields: list or str
            id fields of df. A single string is wrapped into a one-element list.
        text_field: str
            Column holding the document text.
        bucket_field: str
            Column holding the bucket id.
        num_anchors: int
            Number of anchor documents selected per bucket.
        logger: Existing logger to log to, or a path to a log directory
            (in which case "Map_Buckets.log" is created there).
        """
        self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
        self.text_field = text_field
        self.num_anchors = num_anchors
        self.bucket_field = bucket_field
        if isinstance(logger, str):
            self._logger = create_logger(
                rank=0,
                log_file=os.path.join(logger, "Map_Buckets.log"),
                name="Map_Buckets",
            )
        else:
            self._logger = logger

    @staticmethod
    def _get_output_part_ids_with_approx_equal_sum(
        bucket_text_bytes_df: cudf.DataFrame,
        max_text_bytes_per_part: int,
        buckets_column: str,
        bytes_column: str,
        output_partition_column: str,
    ) -> cudf.DataFrame:
        """
        Assign an output partition id to every bucket row of one partition.

        The per-bucket byte counts in ``bytes_column`` are handed to
        ``build_partition`` together with ``max_text_bytes_per_part``; its
        result is stored in ``output_partition_column`` next to the bucket id.
        Presumably ``build_partition`` packs buckets so that each output id
        holds roughly equal bytes below the maximum - confirm against the
        helper.

        Returns a new frame with only ``buckets_column`` and
        ``output_partition_column``.
        """
        sizes = bucket_text_bytes_df[bytes_column].values
        # `.get()` hands build_partition a host-side copy of the sizes
        # (assumes `.values` is a device array here - TODO confirm)
        bucket_output_ar = build_partition(
            sizes=sizes.get(), max_size=max_text_bytes_per_part
        )
        df = cudf.DataFrame()
        df[buckets_column] = bucket_text_bytes_df[buckets_column]
        df[output_partition_column] = bucket_output_ar
        return df

    def _get_output_map_from_text_bytes_per_bucket(
        self,
        ddf_bk_text_bytes,
        bytes_column,
        output_partition_column="_output_partition_id",
    ):
        """
        Build the bucket -> output partition id mapping for the whole frame.

        Step 1 computes partition ids independently inside every dask
        partition. Step 2 offsets the ids of partition ``i`` (i >= 1) by the
        cumulative ``max id + 1`` of all preceding partitions so that ids
        from different dask partitions do not overlap.

        Returns the persisted mapping frame with ``self.bucket_field`` and
        ``output_partition_column``.
        """
        # String bytes limit for cuDF
        # https://github.com/rapidsai/cudf/issues/13733
        max_text_bytes_per_part = int(np.iinfo(np.int32).max * 3)
        self._logger.info(f"max_text_bytes_per_part = {max_text_bytes_per_part}")
        # Increasing in an attempt to prevent hitting
        # ulimits
        output_map_df_meta = cudf.DataFrame(
            {self.bucket_field: [0], output_partition_column: [1]}
        )
        output_map_df_meta = output_map_df_meta.astype(
            {self.bucket_field: np.uint64, output_partition_column: np.int32}
        )
        output_map_df = ddf_bk_text_bytes.map_partitions(
            _MapBuckets._get_output_part_ids_with_approx_equal_sum,
            max_text_bytes_per_part=max_text_bytes_per_part,
            buckets_column=self.bucket_field,
            bytes_column=bytes_column,
            output_partition_column=output_partition_column,
            meta=output_map_df_meta,
        )
        output_map_df = output_map_df.persist()
        # len() forces the persisted graph to be computed before continuing
        self._logger.info(
            f"Step 1 of output_map_df of len: {len(output_map_df)} computed"
        )
        # Number of output ids used by each dask partition (max id + 1) ...
        lower_bounds = (
            output_map_df[output_partition_column]
            .map_partitions(lambda s: (s.max() + 1))
            .compute()
        )
        # ... turned into running totals: entry i-1 is the offset for partition i
        lower_bounds = np.cumsum(lower_bounds)

        def update_id(df, lower_bound):
            # Shifts the partition-local ids in place by the given offset
            df[output_partition_column] += lower_bound
            return df

        updated_parts = [
            output_map_df.get_partition(i).map_partitions(
                update_id, lower_bounds[i - 1]
            )
            for i in range(1, len(lower_bounds))
        ]
        # Partition 0 needs no offset; it is appended last, so the partition
        # order of the result differs from the input order.
        updated_parts.append(output_map_df.get_partition(0))
        output_map_df = dask_cudf.concat(updated_parts)
        output_map_df = output_map_df.persist()
        self._logger.info(
            f"All steps of output_map_df of len: {len(output_map_df)} computed"
        )
        return output_map_df

    def _get_output_map_based_on_str_bytes(
        self, buckets_df, documents_df, bytes_column="_text_bytes"
    ):
        """
        Compute the bucket -> output partition id mapping from text sizes.

        Adds a byte-count column for the text of every document, merges it
        onto ``buckets_df``, aggregates it per bucket through
        ``get_agg_text_bytes_df`` and derives the output partition mapping
        from the aggregated frame.
        """
        documents_df = documents_df.copy()
        documents_df[bytes_column] = documents_df[self.text_field].map_partitions(
            lambda s: s.str.byte_count()
        )
        n_partitions = buckets_df.npartitions
        # The text itself is no longer needed once its size is known
        documents_df = documents_df.drop(columns=[self.text_field]).repartition(
            npartitions=n_partitions
        )
        buckets_df = buckets_df.merge(documents_df).repartition(
            npartitions=n_partitions
        )
        del documents_df
        ddf_bk_text_bytes, agg_df_len = get_agg_text_bytes_df(
            df=buckets_df,
            agg_column=self.bucket_field,
            bytes_column=bytes_column,
            n_partitions=n_partitions,
            shuffle=True,
        )
        self._logger.info(f"Agg_df computed of length = {agg_df_len}")
        del buckets_df
        output_map_df = self._get_output_map_from_text_bytes_per_bucket(
            ddf_bk_text_bytes=ddf_bk_text_bytes,
            bytes_column=bytes_column,
        )
        return output_map_df

    def _random_select_anchor(self, buckets_df, n=2):
        """
        Select `n` anchors from each bucket.

        The selection is pseudo-random but deterministic: rows are ordered by
        a hash of the id fields within each bucket and the first `n` are
        kept. The result contains only anchor rows, with an ``is_anchor``
        column and one boolean ``is_anchor_id_{i}`` column per anchor slot.
        """
        buckets_df = buckets_df.copy()
        buckets_df["_id_hash"] = buckets_df[self.id_fields].hash_values()
        buckets_df = buckets_df.sort_values([self.bucket_field, "_id_hash"])
        buckets_df["_order_in_bucket"] = buckets_df.groupby(
            self.bucket_field
        ).cumcount()
        buckets_df["is_anchor"] = buckets_df["_order_in_bucket"] < n
        for i in range(0, n):
            buckets_df[f"is_anchor_id_{i}"] = buckets_df["_order_in_bucket"] == i
        buckets_df = buckets_df.drop(columns=["_id_hash", "_order_in_bucket"], axis=1)
        buckets_df = buckets_df.reset_index(drop=True)
        buckets_df = buckets_df[buckets_df.is_anchor]
        return buckets_df

    def _add_anchor_docs(self, buckets_df, num_anchors):
        """
        Get anchor documents for each bucket.

        For every anchor slot ``i`` the id columns of the chosen anchor are
        renamed to ``anchor_{i}_{id}`` and joined per bucket; the combined
        anchor columns are then merged back onto every row of ``buckets_df``.
        All merges are inner joins on the bucket column, so a bucket that
        lacks any one of the anchor slots does not appear in the result.
        """
        df_anchor_bk = self._random_select_anchor(buckets_df=buckets_df, n=num_anchors)
        df_anchor_docs = None
        for i in range(num_anchors):
            df_anchor_bk_i = df_anchor_bk[df_anchor_bk[f"is_anchor_id_{i}"]][
                [self.bucket_field] + self.id_fields
            ].reset_index(drop=True)
            column_mapping = {id: f"anchor_{i}_{id}" for id in self.id_fields}
            df_anchor_bk_i = df_anchor_bk_i.rename(columns=column_mapping)
            if i == 0:
                df_anchor_docs = df_anchor_bk_i
            else:
                df_anchor_docs = df_anchor_bk_i.merge(
                    df_anchor_docs, on=[self.bucket_field], how="inner"
                )
        df_anchor_docs_with_bk = buckets_df.merge(
            df_anchor_docs, on=[self.bucket_field], how="inner"
        )
        return df_anchor_docs_with_bk

    def map_buckets_with_anchors(
        self,
        documents_df: dask_cudf.DataFrame,
        buckets_df: dask_cudf.DataFrame,
        shuffle_type: Union[str, bool, None] = "tasks",
    ) -> dask_cudf.DataFrame:
        """
        Get anchor docs with output partition info.

        Args:
            documents_df: documents frame; its text column is used to size buckets
            buckets_df: frame of document ids with their bucket id
            shuffle_type: shuffle method passed to the dask shuffle on the id fields
        Returns:
            ddf_anchor_docs_with_bk: one row per document with its anchor id
            columns and output partition id. The bucket column is dropped and
            duplicate rows are removed.
        """
        output_map_df = self._get_output_map_based_on_str_bytes(
            buckets_df=buckets_df, documents_df=documents_df
        )
        ddf_anchor_docs_with_bk = buckets_df.map_partitions(
            self._add_anchor_docs, num_anchors=self.num_anchors
        )
        self._logger.info("output_map_df is based on string bytes")
        ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.merge(
            output_map_df, on=self.bucket_field
        )
        # Bucket is no longer needed
        ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.drop(
            columns=[self.bucket_field]
        )
        # Below removes any duplicates lying around after dropping buckets
        ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.map_partitions(
            M.drop_duplicates,
            meta=ddf_anchor_docs_with_bk._meta,
            enforce_metadata=False,
            transform_divisions=False,
            align_dataframes=False,
        )
        # Shuffle on the id fields so equal ids share a partition, then drop
        # duplicates again within each partition.
        ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.shuffle(
            self.id_fields,
            ignore_index=True,
            shuffle_method=shuffle_type,
        ).map_partitions(
            M.drop_duplicates,
            meta=ddf_anchor_docs_with_bk._meta,
            enforce_metadata=False,
            transform_divisions=False,
            align_dataframes=False,
        )
        del output_map_df
        return ddf_anchor_docs_with_bk
class _Shuffle:
def __init__(
self,
id_fields: Union[str, list] = "id",
text_field: str = "text",
logger: Union[logging.LoggerAdapter, str] = "./",
profile_dir: str = None,
int_to_str_id: str = None,
):
if isinstance(logger, str):
self._logger = create_logger(
rank=0,
log_file=os.path.join(logger, "LSH.log"),
name="LSH",
)
else:
self._logger = logger
self.id_fields = id_fields
self.text_field = text_field
self.profile_dir = profile_dir
self.int_to_str_id = int_to_str_id
def shuffle_docs_on_buckets(
self,
documents_df: dask_cudf.DataFrame,
bucket_w_anchors_path: str,
output_shuffled_docs_path: str,
bucket_mapping_df_blocksize,
parts_per_worker: int = 1,
bucket_parts_per_worker: int = 8,
partition_on: str = "_output_partition_id",
):
ddf_anchor_docs_with_bk, bk_mapping = aggregated_anchor_docs_with_bk_read(
path=bucket_w_anchors_path,
blocksize=bucket_mapping_df_blocksize,
)