-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_wvs_data.py
More file actions
98 lines (81 loc) · 3.47 KB
/
plot_wvs_data.py
File metadata and controls
98 lines (81 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import argparse
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from multicultural_alignment.constants import COUNTRY_MAP, OUTPUT_DIR, PLOT_DIR
def _remove_leading_zero(val: float) -> str:
formatted = f"{val:.2f}"
if formatted.startswith("0."):
return formatted[1:]
elif formatted.startswith("-0."):
return "-" + formatted[2:]
return formatted
def main(label_size: int = 15, font_scale: float = 1.8) -> None:
    """Load the ground-truth CSV tables from ``OUTPUT_DIR`` and render every figure.

    The per-country table and the global table are stacked into one frame (the
    global rows are tagged with ``"global"`` as both country and language), the
    ``cntry_an`` values are passed through ``COUNTRY_MAP``, and the three plotting
    routines are then called in turn.

    Args:
        label_size: Font size for the tick labels of the country correlation heatmap.
        font_scale: Global seaborn font scale applied before plotting.
    """
    sns.set_theme(font_scale=font_scale, style="white")

    # Per-country scores; the language column is renamed to the name used downstream.
    per_country = pl.read_csv(OUTPUT_DIR / "ground_truth_every_country.csv")
    per_country = per_country.rename({"lnge_iso": "language"})

    # Global scores get placeholder country/language values so they can be stacked.
    worldwide = pl.read_csv(OUTPUT_DIR / "ground_truth_global.csv")
    worldwide = worldwide.with_columns(
        pl.lit("global").alias("cntry_an"),
        pl.lit("global").alias("language"),
    )

    # Align the global columns with the per-country layout before concatenating.
    stacked = pl.concat([per_country, worldwide.select(per_country.columns)])
    country_scores = stacked.with_columns(pl.col("cntry_an").replace(COUNTRY_MAP))

    language_scores = pl.read_csv(OUTPUT_DIR / "ground_truth_per_language.csv")

    plot_country_correlations(country_scores, font_size=label_size)
    plot_country_distributions(country_scores)
    plot_language_correlations(language_scores)
def plot_language_correlations(wvs_language_data: pl.DataFrame):
    """Save a heatmap of the correlations between per-language ``pro_score`` columns.

    The frame is sorted by ``lnge_iso`` and pivoted to one column per language
    (one row per ``question_key``); the pairwise correlations of the numeric
    columns are drawn without cell annotations and written to
    ``PLOT_DIR / "wvs_language_correlations.png"``.

    Args:
        wvs_language_data: Long-format frame with ``question_key``, ``lnge_iso``
            and ``pro_score`` columns.
    """
    ordered = wvs_language_data.sort(by="lnge_iso")
    wide = ordered.pivot(index="question_key", on="lnge_iso", values="pro_score")
    # numeric_only drops non-numeric columns; min_periods=1 keeps sparsely observed pairs.
    correlations = wide.to_pandas().corr(numeric_only=True, min_periods=1)

    heatmap_options = {
        "annot": False,
        "cmap": "coolwarm",
        "vmin": -1,
        "vmax": 1,
        "center": 0,
        "square": True,
        "linewidths": 0.5,
        "cbar_kws": {"shrink": 0.5},
    }

    plt.figure(figsize=(12, 10))
    sns.heatmap(correlations, **heatmap_options)
    plt.title("Correlation Heatmap of Pro Scores by Language")
    plt.tight_layout()

    destination = PLOT_DIR / "wvs_language_correlations.png"
    plt.savefig(destination, bbox_inches="tight")
def plot_country_correlations(wvs_data: pl.DataFrame, font_size: int = 15, annotate: bool = True) -> None:
    """Save a lower-triangle heatmap of the correlations between countries' ``pro_score``.

    The frame is pivoted to one column per ``cntry_an`` value (one row per
    ``question_key``, columns sorted), the pairwise correlations of the numeric
    columns are computed, and the result is written to
    ``PLOT_DIR / "country_correlations.png"``. The current figure is cleared
    afterwards.

    Args:
        wvs_data: Long-format frame with ``question_key``, ``cntry_an`` and
            ``pro_score`` columns.
        font_size: Font size of the x and y tick labels.
        annotate: When true, each visible cell is labelled with its correlation
            formatted by ``_remove_leading_zero``; when false, cells are unlabelled.
    """
    pivoted = wvs_data.pivot(index="question_key", on="cntry_an", values="pro_score", sort_columns=True)
    # numeric_only drops non-numeric columns; min_periods=1 keeps sparsely observed pairs.
    corr_matrix = pivoted.to_pandas().corr(numeric_only=True, min_periods=1)

    # Mask the upper triangle (diagonal included) so each pair is shown only once.
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    # Build the text labels only when they will be drawn: this skips needless
    # work and keeps np.vectorize from failing on an empty matrix when no
    # annotation was requested. Done before opening the figure so a failure
    # here does not leave a stray empty figure behind.
    annot = np.vectorize(_remove_leading_zero)(corr_matrix) if annotate else False

    plt.figure(figsize=(12, 10))
    heatmap = sns.heatmap(
        corr_matrix,
        mask=mask,
        annot=annot,
        cmap="coolwarm",
        vmin=-1,
        vmax=1,
        center=0,
        square=True,
        linewidths=0.75,
        fmt="",  # labels are pre-formatted strings, so no numeric format is applied
        cbar_kws={"shrink": 1},
    )
    heatmap.set_xticklabels(heatmap.get_xmajorticklabels(), fontsize=font_size)
    heatmap.set_yticklabels(heatmap.get_ymajorticklabels(), fontsize=font_size)
    plt.tight_layout()
    plt.savefig(PLOT_DIR / "country_correlations.png", bbox_inches="tight")
    # Clear the figure so later pyplot calls do not draw on top of this heatmap.
    plt.clf()
def plot_country_distributions(wvs_data) -> None:
    """Save violin plots of ``pro_score`` per country, one panel per language.

    The figure is written to ``PLOT_DIR / "country_distributions.png"`` and the
    current figure is cleared afterwards.

    Args:
        wvs_data: Frame with ``cntry_an``, ``pro_score`` and ``language`` columns.
    """
    violin_options = {
        "x": "cntry_an",
        "y": "pro_score",
        "kind": "violin",
        "col": "language",
        # Each language panel lists only its own countries on the x axis.
        "sharex": False,
    }
    sns.catplot(data=wvs_data, **violin_options)

    destination = PLOT_DIR / "country_distributions.png"
    plt.savefig(destination, bbox_inches="tight")
    plt.clf()
if __name__ == "__main__":
    # Command-line entry point: expose the two sizing knobs of main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--label_size", default=15, type=int)
    cli.add_argument("--font_scale", default=1.8, type=float)
    cli_options = cli.parse_args()
    main(font_scale=cli_options.font_scale, label_size=cli_options.label_size)