merge.py
import os
import json
import logging
import argparse
from datetime import datetime

import pandas as pd
from datasets import load_from_disk, Dataset, concatenate_datasets

from dataset_loader import load_dual_config_dataset
from pipelines import deduplicate_df

logger = logging.getLogger("dataset_merger")
logging.basicConfig(level=logging.INFO)

DATA_DIR = "preprocess_outputs/en-am/include"


def deduplicate_against_test(
    ds: Dataset,
    config: dict,
    src_col: str,
    tgt_col: str,
    logger=None
) -> Dataset:
    """
    Remove rows from ds that overlap with the test set: a row is dropped if its
    (src, tgt) pair, its src, or its tgt appears in the test set.
    """
    # Load the test set using the pipeline config
    source_list, target_list = load_dual_config_dataset(
        config["test_dataset"],
        dataset_cache=config["download"].get("dataset_cache", "dataset_cache/"))

    # Build lookup sets of (src, tgt) pairs, sources, and targets from the test set
    test_pairs = set(zip(source_list, target_list))
    test_srcs = set(source_list)
    test_tgts = set(target_list)

    def not_in_test(example):
        src = example[src_col]
        tgt = example[tgt_col]
        return (src, tgt) not in test_pairs and src not in test_srcs and tgt not in test_tgts

    filtered_ds = ds.filter(not_in_test)
    if logger:
        logger.info(f"Removed {len(ds) - len(filtered_ds)} rows present in test set (by pair, src, or tgt)")
    return filtered_ds
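
# Illustrative shape of the `config` dict used above. This is an assumption inferred from
# the keys the function accesses, not a documented schema: "test_dataset" is whatever spec
# load_dual_config_dataset expects for the held-out set, and "download" may carry a
# "dataset_cache" path.
#
#   config = {
#       "test_dataset": {...},                              # passed to load_dual_config_dataset
#       "download": {"dataset_cache": "dataset_cache/"},
#   }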


def deduplicate_hf_dataset(
    ds: Dataset,
    src_col: str = "Source",
    tgt_col: str = "Target",
    logger=None
) -> Dataset:
    """
    Deduplicate a Hugging Face Dataset by:
    1. Dropping rows where src == tgt
    2. Dropping duplicate (src, tgt) pairs
    3. Dropping duplicate src values (keep first occurrence)
    4. Dropping duplicate tgt values (keep first occurrence)
    """
    # Convert to pandas for convenient duplicate handling
    df = ds.to_pandas()
    before = df.shape[0]

    # (1) Drop rows where src == tgt
    df = df[df[src_col] != df[tgt_col]]
    if logger:
        logger.info(f"Step 1: Drop identical src==tgt rows → {len(df)} rows")

    # (2) Drop duplicate (src, tgt) pairs
    df = df.drop_duplicates(subset=[src_col, tgt_col])
    if logger:
        logger.info(f"Step 2: Drop duplicate (src, tgt) pairs → {len(df)} rows")

    # (3) Drop duplicate sources
    df = df.drop_duplicates(subset=[src_col])
    if logger:
        logger.info(f"Step 3: Drop duplicate sources → {len(df)} rows")

    # (4) Drop duplicate targets
    df = df.drop_duplicates(subset=[tgt_col])
    after = df.shape[0]
    if logger:
        logger.info(f"Step 4: Drop duplicate targets → {after} rows (Removed {before - after})")

    # Convert back to an HF Dataset
    deduped = Dataset.from_pandas(df, preserve_index=False)
    return deduped
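
# Worked example of the four steps above on hypothetical (src, tgt) rows, for illustration only:
#
#   [("hi", "hi"), ("hi", "selam"), ("hi", "selam"), ("hi", "tena"), ("hello", "selam")]
#   step 1 (src == tgt)        -> drops ("hi", "hi")
#   step 2 (duplicate pairs)   -> drops the second ("hi", "selam")
#   step 3 (duplicate sources) -> drops ("hi", "tena"), keeping the first row with src "hi"
#   step 4 (duplicate targets) -> drops ("hello", "selam"), keeping the first row with tgt "selam"
#   result: [("hi", "selam")]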


def merge_and_deduplicate_filtered(data_dir, logger, config, src_col, tgt_col, dedup=True, dedup_against_test=True) -> Dataset:
    """
    Merge all datasets found under data_dir for the given language pair, optionally
    deduplicate them (within the merge and against the test set), and save the result
    to disk along with a metadata.json. No additional filtering is applied.
    """
    logger.info(f"🔎 Looking for datasets in: {data_dir}")
    datasets = []
    included_datasets = []
    lang_pair = f"{src_col}-{tgt_col}"
    merge_path = os.path.join(data_dir, lang_pair)

    for root, dirs, files in os.walk(merge_path):
        if "metadata.json" in files:
            meta_path = os.path.join(root, "metadata.json")
            try:
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                ds = load_from_disk(root)
                logger.info(f"✅ Including {meta.get('dataset_name', root)} → {len(ds)} rows")
                datasets.append(ds)
                included_datasets.append({
                    "dataset_name": meta.get("dataset_name", root),
                    "rows": len(ds),
                    "quality_score": meta.get("quality_score")
                })
            except Exception as e:
                logger.warning(f"⚠ Error reading {meta_path}: {e}")

    if not datasets:
        logger.warning(f"No datasets found for merge in {data_dir}")
        return None
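
    # The walk above assumes a layout like the following (inferred from this code, not a
    # documented convention): each filtered dataset lives in its own directory under
    # <data_dir>/<src_col>-<tgt_col>/, saved with Dataset.save_to_disk and accompanied by a
    # metadata.json holding at least "dataset_name" and "quality_score".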
    # Merge all datasets
    merged = concatenate_datasets(datasets)
    merged_size = len(merged)
    logger.info(f"📦 Merged dataset size: {merged_size} rows")

    if dedup:
        logger.info("🧹 Starting deduplication process...")
        deduped = deduplicate_hf_dataset(merged, src_col=src_col, tgt_col=tgt_col, logger=logger)
        deduped_size = len(deduped)
        logger.info(f"✨ Deduplicated dataset size: {deduped_size} rows")
        merged = deduped
    else:
        deduped_size = merged_size
        logger.info("⚠ Deduplication skipped as per configuration.")

    if dedup_against_test:
        logger.info("🧹 Deduplicating against test set...")
        test_config = config.get("test_dataset")
        if test_config:
            merged = deduplicate_against_test(
                merged,
                config=config,
                src_col=src_col,
                tgt_col=tgt_col,
                logger=logger
            )
            final_size = len(merged)
            logger.info(f"✨ Final dataset size after test deduplication: {final_size} rows")
        else:
            logger.warning("⚠ No test set configuration found; skipping test deduplication.")
    merged_path = os.path.join(
        data_dir, f"merged_{src_col}-{tgt_col}"
    )
    merged.save_to_disk(merged_path)
    logger.info(f"💾 Saved merged dataset → {merged_path}")

    merged_metadata = {
        "dataset_name": "merged",
        "lang_pair": f"{src_col}-{tgt_col}",
        "included_datasets": included_datasets,
        "total_rows_before_dedup": merged_size,
        "total_rows_after_dedup": deduped_size,
        "processed_at": datetime.utcnow().isoformat()
    }
    with open(os.path.join(merged_path, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(merged_metadata, f, ensure_ascii=False, indent=2)
    logger.info(f"📝 Metadata written for merged dataset → {os.path.join(merged_path, 'metadata.json')}")
    return merged
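

# Hypothetical command-line entry point (illustrative sketch, not part of the original file).
# argparse is imported above but never used here, so the flag names, defaults, and the
# assumption that the pipeline config is a JSON file with "test_dataset" and "download"
# sections are guesses, not this project's actual interface.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Merge and deduplicate preprocessed parallel datasets")
    parser.add_argument("--config", required=True, help="Path to a JSON pipeline config (assumed format)")
    parser.add_argument("--data-dir", default=DATA_DIR, help="Directory containing per-dataset outputs")
    parser.add_argument("--src-col", default="en", help="Source column / language code (assumed default)")
    parser.add_argument("--tgt-col", default="am", help="Target column / language code (assumed default)")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    merge_and_deduplicate_filtered(args.data_dir, logger, cfg, args.src_col, args.tgt_col)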