diff --git a/README.md b/README.md index 8b055cf..a5a817f 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,31 @@ # Microimpute -Microimpute enables variable imputation through a variety of statistical methods. By providing a consistent interface across different imputation techniques, it allows researchers and data scientists to easily compare and benchmark different approaches using quantile loss and log loss calculations to determine the method providing most accurate results. +Microimpute is a Python package for imputing variables from one survey dataset onto another. It wraps five imputation methods behind a common interface so you can benchmark them on your data and pick the one that works best, rather than defaulting to a single approach. -## Features +## Methods -### Multiple imputation methods -- **Statistical Matching**: Distance-based matching for finding similar observations -- **Ordinary Least Squares (OLS)**: Linear regression-based imputation -- **Quantile Regression**: Distribution-aware regression imputation -- **Quantile Random Forests (QRF)**: Non-parametric forest-based approach -- **Mixture Density Networks (MDN)**: Neural network with Gaussian mixture approximation head +- **Statistical Matching**: distance-based matching to find similar donor observations +- **Ordinary Least Squares (OLS)**: linear regression imputation +- **Quantile Regression**: models conditional quantiles instead of the conditional mean +- **Quantile Random Forests (QRF)**: non-parametric, tree-based quantile estimation +- **Mixture Density Networks (MDN)**: neural network with a Gaussian mixture output -### Automated method selection -- **AutoImpute**: Automatically compares and selects the best imputation method for your data -- **Cross-validation**: Built-in evaluation using quantile loss (numerical) and log loss (categorical) -- **Variable type support**: Handles numerical, categorical, and boolean variables +## Autoimpute -### Developer-friendly design -- 
**Consistent API**: Standardized `fit()` and `predict()` interface across all models -- **Extensible architecture**: Easy to implement custom imputation methods -- **Weighted data handling**: Preserve data distributions with sample weights -- **Input validation**: Automatic parameter and data validation +The `autoimpute` function tunes hyperparameters, runs cross-validation across all five methods, and selects the best performer based on quantile loss (for numerical targets) or log loss (for categorical targets). It handles numerical, categorical, and boolean variables. -### Interactive dashboard -- **Visual exploration**: Analyze imputation results through interactive charts at https://microimpute-dashboard.vercel.app/ -- **GitHub integration**: Load artifacts directly from CI/CD workflows -- **Multiple data sources**: File upload, URL loading and sample data +## API + +All models follow a `fit()` / `predict()` interface. The package supports sample weights to account for survey design, and validates inputs automatically. Adding a custom imputation method is straightforward since new models just need to implement the same interface. + +## Documentation and paper + +- [Documentation](https://policyengine.github.io/microimpute/) with examples and interactive notebooks +- [Paper](https://github.com/PolicyEngine/microimpute/blob/main/paper/main.pdf) presenting microimpute and demonstrating it for SCF-to-CPS net worth imputation + +## Dashboard + +An interactive dashboard for exploring imputation results is available at https://microimpute-dashboard.vercel.app/. It supports file upload, URL loading, direct GitHub artifact integration, and sample data. 
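The quantile loss that `autoimpute` uses to rank methods on numerical targets is the standard pinball loss. As a rough illustration of the metric (this is a generic sketch, not microimpute's internal implementation — see the package documentation for the exact scoring code):

```python
def quantile_loss(y_true, y_pred, q):
    """Pinball loss: under-prediction is penalized by q, over-prediction by (1 - q)."""
    losses = []
    for yt, yp in zip(y_true, y_pred):
        diff = yt - yp
        # diff >= 0 means the prediction fell below the true value (under-prediction)
        losses.append(q * diff if diff >= 0 else (q - 1) * diff)
    return sum(losses) / len(losses)

# At q = 0.5 the loss is symmetric (half the absolute error);
# at q = 0.9 under-predicting costs nine times more than over-predicting.
print(quantile_loss([1.0, 2.0, 3.0], [1.5, 2.0, 2.0], 0.5))  # 0.25
```

Averaging this loss over a grid of quantiles (as in the benchmarking notebook) rewards methods that capture the whole conditional distribution, not just the mean.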
## Installation @@ -33,16 +33,12 @@ Microimpute enables variable imputation through a variety of statistical methods pip install microimpute ``` -For image export functionality (PNG/JPG), install with: +For image export (PNG/JPG): ```bash pip install microimpute[images] ``` -## Examples and documentation - -For detailed examples and interactive notebooks, see the [documentation](https://policyengine.github.io/microimpute/). - ## Contributing -Contributions are welcome to the project. Please feel free to submit a Pull Request with your improvements. +Pull requests are welcome. If you find a bug or have a feature idea, open an issue or submit a PR. diff --git a/changelog.d/maria-paper_review.fixed.md b/changelog.d/maria-paper_review.fixed.md new file mode 100644 index 0000000..2354004 --- /dev/null +++ b/changelog.d/maria-paper_review.fixed.md @@ -0,0 +1 @@ +Updated paper and package documentation with latest changes. Fix pandas 2.x compatibility for Arrow string types and dtype checks. diff --git a/docs/_toc.yml b/docs/_toc.yml index dc6d677..cffb6d3 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -38,5 +38,3 @@ parts: - caption: Use cases chapters: - file: use_cases/index - sections: - - file: use_cases/scf_to_cps/imputing-from-scf-to-cps diff --git a/docs/autoimpute/autoimpute.ipynb b/docs/autoimpute/autoimpute.ipynb index 988a9aa..59fde5a 100644 --- a/docs/autoimpute/autoimpute.ipynb +++ b/docs/autoimpute/autoimpute.ipynb @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -293,13 +293,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0fb0cd8da8ad4d65bf84cb943ebaf1c8", + "model_id": "08c1a0e8ab514a84a14be53225051e6b", "version_major": 2, "version_minor": 0 }, @@ -315,18 +315,11 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using 
backend LokyBackend with 8 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 7.3s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.0s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "R callback write-console: Warning: \n", - "R callback write-console: failed to download mirrors file (cannot open URL 'https://cran.r-project.org/CRAN_mirrors.csv'); using local file '/opt/homebrew/Cellar/r/4.5.0/lib/R/doc/CRAN_mirrors.csv'\n", - " \n", - "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 1.6min\n", - "R callback write-console: Warning: \n", - "R callback write-console: failed to download mirrors file (cannot open URL 'https://cran.r-project.org/CRAN_mirrors.csv'); using local file '/opt/homebrew/Cellar/r/4.5.0/lib/R/doc/CRAN_mirrors.csv'\n", - " \n", - "R callback write-console: Warning: \n", - "R callback write-console: failed to download mirrors file (cannot open URL 'https://cran.r-project.org/CRAN_mirrors.csv'); using local file '/opt/homebrew/Cellar/r/4.5.0/lib/R/doc/CRAN_mirrors.csv'\n", - " \n" + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.8s finished\n" ] }, { @@ -341,7 +334,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 5.2min finished\n" + "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 1.1s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.2s finished\n" ] } ], @@ -375,7 +369,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -384,17 +378,17 @@ "text": [ "Cross-validation results for different imputation methods:\n", "\n", - "Model: QRF\n", - "quantile loss results: 0.0155\n", - "\n", "Model: OLS\n", - "quantile loss results: 0.0124\n", + "quantile loss results: 0.0126\n", "\n", "Model: QuantReg\n", - "quantile loss results: 0.0125\n", + "quantile 
loss results: 0.0126\n", + "\n", + "Model: QRF\n", + "quantile loss results: 0.0160\n", "\n", "Model: Matching\n", - "quantile loss results: 0.0231\n" + "quantile loss results: 0.0233\n" ] } ], @@ -415,14 +409,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Best performing method: \n" + "Best performing method: \n" ] } ], @@ -442,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -454,307 +448,285 @@ "data": [ { "alignmentgroup": "True", - "hovertemplate": "Method=QRF
Quantiles=%{x}
Test Quantile loss=%{y}", - "legendgroup": "QRF", + "error_y": { + "array": [ + 0.0000367764891179606, + 0.00005051876144414505, + 0.00018829257114425592, + 0.00037848750942190044, + 0.0005477958061141973, + 0.0006245017815431527, + 0.0006789805698947316, + 0.0007155653018236752, + 0.0007295643674175897, + 0.0007569494312623643, + 0.0008639491258964105, + 0.0009801026362712217, + 0.0010518961452707697, + 0.0010297648072906266, + 0.0009590446159744011, + 0.0008660289152915666, + 0.0007846310723713395, + 0.00065956869485099, + 0.0003346140844985172 + ] + }, + "hovertemplate": "Method=OLS
Quantiles=%{x}
Quantile loss=%{y}", + "legendgroup": "OLS", "marker": { - "color": "#636EFA", + "color": "#88CCEE", "pattern": { "shape": "" } }, - "name": "QRF", - "offsetgroup": "QRF", + "name": "OLS", + "offsetgroup": "OLS", "orientation": "v", "showlegend": true, "textposition": "auto", "type": "bar", "x": [ "0.05", - "0.05", - "0.1", "0.1", "0.15", - "0.15", - "0.2", "0.2", "0.25", - "0.25", "0.3", - "0.3", - "0.35", "0.35", "0.4", - "0.4", - "0.45", "0.45", "0.5", - "0.5", "0.55", - "0.55", - "0.6", "0.6", "0.65", - "0.65", "0.7", - "0.7", - "0.75", "0.75", "0.8", - "0.8", "0.85", - "0.85", - "0.9", "0.9", - "0.95", "0.95" ], "xaxis": "x", "y": [ - 0.004631754564085741, - 0.004631754564085741, - 0.007341856096256823, - 0.007341856096256823, - 0.010956305360759139, - 0.010956305360759139, - 0.013853899253214152, - 0.013853899253214152, - 0.016765778862585467, - 0.016765778862585467, - 0.018170938524875734, - 0.018170938524875734, - 0.020431459700043803, - 0.020431459700043803, - 0.021268421552075877, - 0.021268421552075877, - 0.02152068721079699, - 0.02152068721079699, - 0.02150832954854852, - 0.02150832954854852, - 0.019861473819837334, - 0.019861473819837334, - 0.01995636925773867, - 0.01995636925773867, - 0.01894943399516724, - 0.01894943399516724, - 0.017987324791786673, - 0.017987324791786673, - 0.016858787546663905, - 0.016858787546663905, - 0.014928394018716465, - 0.014928394018716465, - 0.012899519430338449, - 0.012899519430338449, - 0.009831864043753134, - 0.009831864043753134, - 0.006811614968932836, - 0.006811614968932836 + 0.0037857585632115076, + 0.0064035381731015984, + 0.008682116948116709, + 0.010637028223496603, + 0.01228354449019281, + 0.013704743173963509, + 0.014829116089757849, + 0.01572984140710315, + 0.01635768192573994, + 0.016757835444730255, + 0.01698876716674359, + 0.016908544804384324, + 0.016466981551585432, + 0.01573902508984663, + 0.014655697638186978, + 0.013182752515157806, + 0.011274496133844888, + 0.008723043411546087, + 
0.005424527941128667 ], "yaxis": "y" }, { "alignmentgroup": "True", - "hovertemplate": "Method=OLS
Quantiles=%{x}
Test Quantile loss=%{y}", - "legendgroup": "OLS", + "error_y": { + "array": [ + 0.00011053862485353776, + 0.00013152308789946208, + 0.00007328304807248766, + 0.00013761220727722502, + 0.000225390831825253, + 0.0002387781342265084, + 0.0003690300025792442, + 0.0004971448458544792, + 0.0006841372848715391, + 0.0007855762966044914, + 0.0008657286666497124, + 0.0009609612674453248, + 0.0010760799108166322, + 0.0011252152270869854, + 0.0012113725779785698, + 0.0011552041991586672, + 0.0008205546613073608, + 0.0006875490202660445, + 0.000559862252788657 + ] + }, + "hovertemplate": "Method=QuantReg
Quantiles=%{x}
Quantile loss=%{y}", + "legendgroup": "QuantReg", "marker": { - "color": "#EF553B", + "color": "#CC6677", "pattern": { "shape": "" } }, - "name": "OLS", - "offsetgroup": "OLS", + "name": "QuantReg", + "offsetgroup": "QuantReg", "orientation": "v", "showlegend": true, "textposition": "auto", "type": "bar", "x": [ - "0.05", "0.05", "0.1", - "0.1", - "0.15", "0.15", "0.2", - "0.2", - "0.25", "0.25", "0.3", - "0.3", "0.35", - "0.35", - "0.4", "0.4", "0.45", - "0.45", "0.5", - "0.5", - "0.55", "0.55", "0.6", - "0.6", "0.65", - "0.65", - "0.7", "0.7", "0.75", - "0.75", - "0.8", "0.8", "0.85", - "0.85", "0.9", - "0.9", - "0.95", "0.95" ], "xaxis": "x", "y": [ - 0.003808647070946412, - 0.003808647070946412, - 0.006490379484883273, - 0.006490379484883273, - 0.008807739971118509, - 0.008807739971118509, - 0.010699017371221815, - 0.010699017371221815, - 0.012339891153334075, - 0.012339891153334075, - 0.013697475098254898, - 0.013697475098254898, - 0.014736873094883992, - 0.014736873094883992, - 0.015550928634621153, - 0.015550928634621153, - 0.016134574039820177, - 0.016134574039820177, - 0.016485493155635602, - 0.016485493155635602, - 0.016654564673178742, - 0.016654564673178742, - 0.016593511197972737, - 0.016593511197972737, - 0.016203732622264988, - 0.016203732622264988, - 0.015451587766895878, - 0.015451587766895878, - 0.014415220561443473, - 0.014415220561443473, - 0.013044914508801573, - 0.013044914508801573, - 0.011156959998495038, - 0.011156959998495038, - 0.008585027733878748, - 0.008585027733878748, - 0.005229999830606041, - 0.005229999830606041 + 0.003648297363886403, + 0.006220739206137793, + 0.008560805562643258, + 0.010594467066356217, + 0.012137407181277772, + 0.013580247763948514, + 0.014709634995905748, + 0.015584271813924223, + 0.01615515457661314, + 0.01660627579116468, + 0.016924948899226498, + 0.017013821272759678, + 0.016792947494266755, + 0.01604088566988471, + 0.01484899400756596, + 0.01347829835941078, + 0.011713354867882659, + 0.008831687333185652, 
+ 0.005731739757859993 ], "yaxis": "y" }, { "alignmentgroup": "True", - "hovertemplate": "Method=QuantReg
Quantiles=%{x}
Test Quantile loss=%{y}", - "legendgroup": "QuantReg", + "error_y": { + "array": [ + 0.0002099742280199984, + 0.0005577204009778338, + 0.0005913787202045731, + 0.0006465761742329032, + 0.0006427657425320314, + 0.0009665112262689732, + 0.0004039564380447954, + 0.00046810196513831623, + 0.000920067680617316, + 0.0004569072602015561, + 0.0007209820742233905, + 0.0009019129537077271, + 0.000654822071277212, + 0.0002964891583124209, + 0.0008697756394143753, + 0.000869354334057907, + 0.0007434922481074839, + 0.0008949128843092795, + 0.0005736016727487648 + ] + }, + "hovertemplate": "Method=QRF
Quantiles=%{x}
Quantile loss=%{y}", + "legendgroup": "QRF", "marker": { - "color": "#00CC96", + "color": "#DDCC77", "pattern": { "shape": "" } }, - "name": "QuantReg", - "offsetgroup": "QuantReg", + "name": "QRF", + "offsetgroup": "QRF", "orientation": "v", "showlegend": true, "textposition": "auto", "type": "bar", "x": [ "0.05", - "0.05", - "0.1", "0.1", "0.15", - "0.15", - "0.2", "0.2", "0.25", - "0.25", "0.3", - "0.3", - "0.35", "0.35", "0.4", - "0.4", "0.45", - "0.45", - "0.5", "0.5", "0.55", - "0.55", - "0.6", "0.6", "0.65", - "0.65", "0.7", - "0.7", - "0.75", "0.75", "0.8", - "0.8", "0.85", - "0.85", - "0.9", "0.9", - "0.95", "0.95" ], "xaxis": "x", "y": [ - 0.0037797191400323325, - 0.0037797191400323325, - 0.0065355258283926075, - 0.0065355258283926075, - 0.00877891950537431, - 0.00877891950537431, - 0.010730211151855228, - 0.010730211151855228, - 0.012356819692361727, - 0.012356819692361727, - 0.013725116901874718, - 0.013725116901874718, - 0.014793067048809186, - 0.014793067048809186, - 0.015515637986475386, - 0.015515637986475386, - 0.015911871896009084, - 0.015911871896009084, - 0.016247057308316024, - 0.016247057308316024, - 0.016486402970914003, - 0.016486402970914003, - 0.016660653134295333, - 0.016660653134295333, - 0.016356605390546602, - 0.016356605390546602, - 0.015780547245353785, - 0.015780547245353785, - 0.014807938631402215, - 0.014807938631402215, - 0.01340137722886329, - 0.01340137722886329, - 0.01153643264671414, - 0.01153643264671414, - 0.009016559672302424, - 0.009016559672302424, - 0.005650809365013223, - 0.005650809365013223 + 0.004958981353614306, + 0.007709876078177458, + 0.010610160150823023, + 0.013147576784275966, + 0.016230808947783, + 0.01765523354548008, + 0.020390817639546265, + 0.021631607060529953, + 0.022142205928372478, + 0.022393833653832424, + 0.021357150137205804, + 0.020907723382855792, + 0.019930544841328238, + 0.019095495355673266, + 0.01816255322567264, + 0.016320905888505866, + 0.01317617842711581, + 0.010650497941205097, + 
0.0072062531829631795 ], "yaxis": "y" }, { "alignmentgroup": "True", - "hovertemplate": "Method=Matching
Quantiles=%{x}
Test Quantile loss=%{y}", + "error_y": { + "array": [ + 0.0019368529298846857, + 0.001756890524910662, + 0.0015769649797891588, + 0.0013970905356792654, + 0.0012172898458174732, + 0.0010376012524176477, + 0.0008580951787241413, + 0.0006789164152700594, + 0.0005004166749212909, + 0.0003237211525341078, + 0.00015512287742088833, + 0.00009066430412933424, + 0.0002393033878359685, + 0.0004137649009508685, + 0.0005916298348888364, + 0.0007705449475824869, + 0.0009499170265880006, + 0.0011295283905246845, + 0.0013092805654417116 + ] + }, + "hovertemplate": "Method=Matching
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "Matching", "marker": { - "color": "#AB63FA", + "color": "#117733", "pattern": { "shape": "" } @@ -766,85 +738,47 @@ "textposition": "auto", "type": "bar", "x": [ - "0.05", "0.05", "0.1", - "0.1", - "0.15", "0.15", "0.2", - "0.2", "0.25", - "0.25", - "0.3", "0.3", "0.35", - "0.35", "0.4", - "0.4", - "0.45", "0.45", "0.5", - "0.5", "0.55", - "0.55", - "0.6", "0.6", "0.65", - "0.65", - "0.7", "0.7", "0.75", - "0.75", "0.8", - "0.8", - "0.85", "0.85", "0.9", - "0.9", - "0.95", "0.95" ], "xaxis": "x", "y": [ - 0.023393616982767946, - 0.023393616982767946, - 0.02335740713728864, - 0.02335740713728864, - 0.023321197291809342, - 0.023321197291809342, - 0.02328498744633005, - 0.02328498744633005, - 0.023248777600850746, - 0.023248777600850746, - 0.02321256775537145, - 0.02321256775537145, - 0.023176357909892153, - 0.023176357909892153, - 0.023140148064412853, - 0.023140148064412853, - 0.023103938218933556, - 0.023103938218933556, - 0.02306772837345426, - 0.02306772837345426, - 0.02303151852797496, - 0.02303151852797496, - 0.022995308682495663, - 0.022995308682495663, - 0.022959098837016367, - 0.022959098837016367, - 0.022922888991537067, - 0.022922888991537067, - 0.02288667914605777, - 0.02288667914605777, - 0.022850469300578474, - 0.022850469300578474, - 0.022814259455099174, - 0.022814259455099174, - 0.022778049609619877, - 0.022778049609619877, - 0.02274183976414058, - 0.02274183976414058 + 0.02420659950459554, + 0.02410222421800216, + 0.023997848931408782, + 0.023893473644815394, + 0.023789098358222013, + 0.023684723071628624, + 0.023580347785035247, + 0.023475972498441865, + 0.02337159721184848, + 0.02326722192525509, + 0.02316284663866171, + 0.023058471352068326, + 0.022954096065474938, + 0.02284972077888156, + 0.022745345492288172, + 0.022640970205694794, + 0.02253659491910141, + 0.022432219632508025, + 0.022327844345914644 ], "yaxis": "y" } @@ -858,51 +792,51 @@ }, "tracegroupgap": 0 }, - "paper_bgcolor": "#F0F0F0", - 
"plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "shapes": [ { "line": { - "color": "#636EFA", + "color": "#88CCEE", "dash": "dot", "width": 2 }, - "name": "QRF Mean", + "name": "OLS Mean", "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.015501800660325096, - "y1": 0.015501800660325096 + "y0": 0.012554475825886228, + "y1": 0.012554475825886228 }, { "line": { - "color": "#EF553B", + "color": "#CC6677", "dash": "dot", "width": 2 }, - "name": "OLS Mean", + "name": "QuantReg Mean", "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.012425607261487217, - "y1": 0.012425607261487217 + "y0": 0.01258810415704739, + "y1": 0.01258810415704739 }, { "line": { - "color": "#00CC96", + "color": "#DDCC77", "dash": "dot", "width": 2 }, - "name": "QuantReg Mean", + "name": "QRF Mean", "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.012530066986573978, - "y1": 0.012530066986573978 + "y0": 0.01598307386973477, + "y1": 0.01598307386973477 }, { "line": { - "color": "#AB63FA", + "color": "#117733", "dash": "dot", "width": 2 }, @@ -910,8 +844,8 @@ "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 0.02306772837345426, - "y1": 0.02306772837345426 + "y0": 0.023267221925255092, + "y1": 0.023267221925255092 } ], "template": { @@ -1743,7 +1677,11 @@ 0, 1 ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "font": { "size": 12 @@ -1758,12 +1696,16 @@ 0, 1 ], - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "font": { "size": 12 }, - "text": "Test Quantile loss" + "text": "Quantile loss" }, "zeroline": false } @@ -1806,7 +1748,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -1844,28 +1786,28 @@ " \n", " \n", " 0\n", - " 0.020909\n", - " 0.036174\n", + " 0.015232\n", + " 0.013791\n", " \n", " \n", " 1\n", - " 0.004626\n", - " 
-0.021702\n", + " -0.015421\n", + " -0.006742\n", " \n", " \n", " 2\n", - " -0.018766\n", - " -0.004827\n", + " -0.005829\n", + " 0.008123\n", " \n", " \n", " 3\n", - " -0.025080\n", - " -0.045854\n", + " -0.027782\n", + " -0.007257\n", " \n", " \n", " 4\n", - " 0.000522\n", - " 0.021657\n", + " -0.029530\n", + " -0.002555\n", " \n", " \n", "\n", @@ -1873,14 +1815,14 @@ ], "text/plain": [ " s1 s4\n", - "0 0.020909 0.036174\n", - "1 0.004626 -0.021702\n", - "2 -0.018766 -0.004827\n", - "3 -0.025080 -0.045854\n", - "4 0.000522 0.021657" + "0 0.015232 0.013791\n", + "1 -0.015421 -0.006742\n", + "2 -0.005829 0.008123\n", + "3 -0.027782 -0.007257\n", + "4 -0.029530 -0.002555" ] }, - "execution_count": 23, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -1896,7 +1838,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -1942,68 +1884,68 @@ " \n", " \n", " 0\n", - " 0.085299\n", + " 0.063504\n", " 0.050680\n", - " 0.044451\n", - " -0.005670\n", - " -0.034194\n", - " -0.032356\n", - " 0.002861\n", - " -0.025930\n", - " 0.020909\n", - " 0.036174\n", + " -0.001895\n", + " 0.066629\n", + " 0.108914\n", + " 0.022869\n", + " -0.035816\n", + " 0.003064\n", + " 0.015232\n", + " 0.013791\n", " \n", " \n", " 1\n", - " 0.005383\n", + " -0.070900\n", " -0.044642\n", - " -0.036385\n", - " 0.021872\n", - " 0.015596\n", - " 0.008142\n", - " -0.031988\n", - " -0.046641\n", - " 0.004626\n", - " -0.021702\n", + " 0.039062\n", + " -0.033213\n", + " -0.034508\n", + " -0.024993\n", + " 0.067737\n", + " -0.013504\n", + " -0.015421\n", + " -0.006742\n", " \n", " \n", " 2\n", - " -0.045472\n", + " 0.005383\n", " 0.050680\n", - " -0.047163\n", - " -0.015999\n", - " -0.024800\n", - " 0.000779\n", - " -0.062917\n", - " -0.038357\n", - " -0.018766\n", - " -0.004827\n", + " -0.001895\n", + " 0.008101\n", + " -0.015719\n", + " -0.002903\n", + " 0.038394\n", + " -0.013504\n", + " -0.005829\n", + " 0.008123\n", " 
\n", " \n", " 3\n", - " -0.096328\n", - " -0.044642\n", - " -0.083808\n", - " 0.008101\n", - " -0.090561\n", - " -0.013948\n", - " -0.062917\n", - " -0.034215\n", - " -0.025080\n", - " -0.045854\n", + " -0.085430\n", + " 0.050680\n", + " -0.022373\n", + " 0.001215\n", + " -0.026366\n", + " 0.015505\n", + " -0.072133\n", + " -0.017646\n", + " -0.027782\n", + " -0.007257\n", " \n", " \n", " 4\n", - " 0.027178\n", + " -0.067268\n", " 0.050680\n", - " 0.017506\n", - " -0.033213\n", - " 0.045972\n", - " -0.065491\n", - " -0.096435\n", - " -0.059067\n", - " 0.000522\n", - " 0.021657\n", + " -0.012673\n", + " -0.040099\n", + " 0.004636\n", + " -0.058127\n", + " 0.019196\n", + " -0.034215\n", + " -0.029530\n", + " -0.002555\n", " \n", " \n", "\n", @@ -2011,14 +1953,14 @@ ], "text/plain": [ " age sex bmi bp s2 s3 s5 s6 s1 s4\n", - "0 0.085299 0.050680 0.044451 -0.005670 -0.034194 -0.032356 0.002861 -0.025930 0.020909 0.036174\n", - "1 0.005383 -0.044642 -0.036385 0.021872 0.015596 0.008142 -0.031988 -0.046641 0.004626 -0.021702\n", - "2 -0.045472 0.050680 -0.047163 -0.015999 -0.024800 0.000779 -0.062917 -0.038357 -0.018766 -0.004827\n", - "3 -0.096328 -0.044642 -0.083808 0.008101 -0.090561 -0.013948 -0.062917 -0.034215 -0.025080 -0.045854\n", - "4 0.027178 0.050680 0.017506 -0.033213 0.045972 -0.065491 -0.096435 -0.059067 0.000522 0.021657" + "0 0.063504 0.050680 -0.001895 0.066629 0.108914 0.022869 -0.035816 0.003064 0.015232 0.013791\n", + "1 -0.070900 -0.044642 0.039062 -0.033213 -0.034508 -0.024993 0.067737 -0.013504 -0.015421 -0.006742\n", + "2 0.005383 0.050680 -0.001895 0.008101 -0.015719 -0.002903 0.038394 -0.013504 -0.005829 0.008123\n", + "3 -0.085430 0.050680 -0.022373 0.001215 -0.026366 0.015505 -0.072133 -0.017646 -0.027782 -0.007257\n", + "4 -0.067268 0.050680 -0.012673 -0.040099 0.004636 -0.058127 0.019196 -0.034215 -0.029530 -0.002555" ] }, - "execution_count": 24, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -2040,7 +1982,7 
@@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -2194,138 +2136,138 @@ 132 ], "y": [ - -0.04559945128264711, - 0.003934851612593237, - -0.04009563984984263, - -0.10338947132709418, - -0.007072771253015731, + 0.09061988167926385, + -0.012576582685820214, -0.004320865536613489, - 0.017694380194604446, - 0.024574144485610048, - 0.0342058144930179, - -0.0029449126784123676, - -0.0029449126784123676, - -0.04284754556624487, + -0.037343734133440394, + -0.015328488402222454, + -0.04009563984984263, -0.038719686991641515, - -0.04422349842444599, - -0.0249601584096303, - -0.04972730985725048, - 0.020446285911006685, + -0.0579830270064572, + 0.0025588987543921156, + -0.046975404140848234, 0.0025588987543921156, - -0.04147159270804375, 0.01219056876179996, - -0.0318399227006359, - 0.06447677737344255, + 0.020446285911006685, + -0.007072771253015731, + -0.08962994274508297, + 0.01219056876179996, -0.04422349842444599, - -0.011200629827619093, - -0.02083229983502694, + 0.005310804470794357, -0.005696818394814609, - -0.033215875558837024, 0.03833367306762126, + -0.011200629827619093, -0.12678066991651324, - -0.06761469701386505, - 0.0025588987543921156, 0.010814615903598841, -0.05935897986465832, + -0.051103262715451604, + 0.06310082451524143, -0.0318399227006359, -0.030463969842434782, - -0.00019300696201012598, - 0.06034891879883919, - 0.001182945896190995, - -0.019456346976825818, - 0.0080627101871966, + 0.005310804470794357, + 0.02182223876920781, + 0.041085578784023497, + -0.04559945128264711, + 0.014942474478202204, + 0.06447677737344255, + 0.02319819162740893, 0.04383748450042574, - -0.033215875558837024, - -0.0029449126784123676, - -0.019456346976825818, - 0.01219056876179996, + 0.03833367306762126, + -0.005696818394814609, + 0.00943866304539772, -0.015328488402222454, - -0.0318399227006359, - -0.030463969842434782, + -0.034591828417038145, + 0.0342058144930179, 0.020446285911006685, 0.014942474478202204, 
- -0.02083229983502694, - 0.027326050202012293, - 0.02182223876920781, - -0.09650970703608859, - 0.02319819162740893, - -0.07587041416307178, - -0.07999827273767514, - -0.051103262715451604, + -0.008448724111216851, + 0.039709625925822375, + 0.001182945896190995, + -0.037343734133440394, + 0.03833367306762126, + -0.0318399227006359, + -0.011200629827619093, + -0.019456346976825818, + 0.039709625925822375, + 0.0741084473808504, + 0.06998058880624704, 0.03282986163481677, 0.05209320164963247, - 0.01219056876179996, + 0.126394655992493, + -0.027712064126032544, + 0.024574144485610048, 0.017694380194604446, + -0.051103262715451604, + 0.04934129593323023, + -0.001568959820211247, + 0.04658939021682799, + -0.005696818394814609, + -0.04835135699904936, + -0.015328488402222454, + -0.04009563984984263, -0.0249601584096303, - 0.053469154507833586, - -0.05385516843185383, -0.00019300696201012598, 0.15391371315651542, - 0.039709625925822375, - -0.02358420555142918, - 0.07961225881365488, + 0.006686757328995478, + -0.05935897986465832, -0.035967781275239266, - -0.04972730985725048, - 0.0025588987543921156, - 0.04934129593323023, - -0.007072771253015731, - -0.013952535544021335, - 0.04246153164222462, - -0.04284754556624487, - -0.02220825269322806, - 0.12501870313429186, - -0.05385516843185383, - 0.053469154507833586, - 0.06172487165704031, - -0.04972730985725048, - -0.016704441260423575, + 0.1277706088506941, + 0.006686757328995478, + 0.03558176735121902, + 0.016318427336403322, + 0.020446285911006685, + 0.0741084473808504, -0.026336111267831423, + 0.08786797596286161, + -0.04422349842444599, + -0.06623874415566393, + 0.04383748450042574, + -0.04972730985725048, + 0.020446285911006685, + -0.005696818394814609, 0.07686035309725264, - -0.07311850844666953, - -0.009824676969417972, - 0.030077955918414535, + -0.008448724111216851, + 0.03282986163481677, + 0.03145390877661565, 0.05484510736603471, - 0.07823630595545376, - -0.007072771253015731, - -0.027712064126032544, + 
0.024574144485610048, + 0.06034891879883919, + 0.07273249452264928, + 0.024574144485610048, + -0.029088016984233665, + -0.0029449126784123676, + -0.011200629827619093, 0.001182945896190995, - 0.09199583453746497, - 0.07548440023905152, -0.007072771253015731, - 0.1277706088506941, - -0.012576582685820214, + 0.04246153164222462, + -0.004320865536613489, 0.10988322169407955, - 0.0080627101871966, - -0.013952535544021335, - -0.04284754556624487, - -0.046975404140848234, + -0.037343734133440394, + -0.016704441260423575, + 0.006686757328995478, 0.00943866304539772, 0.041085578784023497, - 0.030077955918414535, - 0.02182223876920781, + -0.001568959820211247, -0.06623874415566393, - 0.11951489170148738, + 0.03558176735121902, + 0.017694380194604446, + 0.02182223876920781, + 0.041085578784023497, + -0.034591828417038145, + -0.037343734133440394, + 0.03145390877661565, + 0.05622106022423583, 0.014942474478202204, - -0.005696818394814609, - -0.060734932722859444, - -0.02358420555142918, - -0.06348683843926169, - -0.04972730985725048, - -0.011200629827619093, - -0.046975404140848234, - 0.016318427336403322, - -0.018080394118624697, 0.010814615903598841, - -0.0029449126784123676, - -0.06623874415566393, + 0.10988322169407955, + 0.09887559882847057, + 0.001182945896190995, 0.07823630595545376, - 0.03833367306762126, -0.10063756561069194, - 0.03558176735121902, 0.04658939021682799, + 0.02182223876920781, 0.05759701308243695, - -0.005696818394814609, - -0.037343734133440394, + 0.04658939021682799, 0.08374011738825825 ] }, @@ -2473,139 +2415,139 @@ 132 ], "y": [ - 0.02090924803691378, - 0.004626123494339312, - -0.018766371041145515, - -0.025079773212822724, - 0.0005221962276462163, - 0.0023782190339952007, - 0.008649160755021803, - 0.014061356781762173, - 0.0254923032629858, - -0.011912383665013354, - -0.018407832152535906, - 0.024214427396466377, - -0.02256780214714429, - 0.004919903247914598, - -0.008170622618974065, - -0.018548330310735238, - -0.007105185961827832, - 
0.012008661347633834, - 0.018618084713260215, - 0.011580347815175038, - -0.007449829402368672, - 0.01264639252518127, - -0.019867269093045897, - -0.015250402274648698, - -0.010913040915993408, - 0.003691389849940784, - -0.0024928165805605546, - -0.0007776304487895633, - 0.006325872630235647, - -0.005115114618450097, - -0.028534716572419568, - -0.005768925548395457, - -0.013047736065762188, - -0.028820647738861192, - -0.0035482860212044933, - -0.018530818319814676, - 0.004165370131410912, - 0.006749368070504823, - 0.0019226569832983451, - 0.02264996119206698, - 0.020792167630816705, - -0.0038184998660977126, - 0.011053015302265596, - -0.037005112578249275, - -0.01585513668744488, - 0.02560374181865984, - 0.02982157455613024, - 0.016710154185276348, - 0.006970480052674291, - -0.004867785078419127, - -0.021188768112257218, - 0.015228956794648838, - 0.031766761148240404, - -0.004424832962744514, - 0.012040504341003278, - -0.021868959740455263, - -0.014030859836463884, - -0.012065482400146238, - -0.012421658895909645, - 0.008993450695372609, - 0.014045154218359003, - 0.026886444708028364, - 0.03098052571069148, - 0.013277525924978313, - -0.0016929388165091135, - 0.02107173313979645, - -0.013609680442188583, - -0.0005291128611427387, - 0.012633195111147563, - 0.01585177092183963, - 0.01592201625116141, - -0.013505588816878359, - -0.029989934542940452, - 0.024035107762849686, - -0.0036208937433724725, - 0.00986821795135186, - -0.0025563066174634427, - 0.015401158704997121, - -0.0017121976478600984, - 0.0191111920844166, - -0.010060858295794301, - -0.020294974572644696, - 0.03240917262316084, - -0.002681574554502042, - 0.00011636986553689114, - 0.005574322230719957, - 0.013739407286284734, - -0.0066654621851517035, - -0.001037285439235573, - 0.028032372378139492, - 0.04316051018347186, - 0.012379703775175534, - 0.002297693317367038, - -0.0097459261449615, - -0.002440354209744388, - -0.00032389027157234823, - -0.00040304651144529, - 0.023079074995717545, - 
0.0069802916130410056, - -0.008488594622839591, - 0.0011337923289514538, - -0.00845554673950999, - 0.012251890585047658, - 0.016676498249086746, - -0.01474459352050577, - 0.0050141833128581155, - -0.0021218581167275433, - 0.014514906599180545, - -0.0023141889635125527, - -0.01592533089818831, - 0.007626497719330275, - 0.007189881329180314, - 0.0014891685535230936, - 0.019136067420616056, - 0.015459647789467304, - 0.010780314394988014, - -0.019002304466773874, - -0.017332340171545398, - 0.009044765347958866, - 0.022799569535975638, - 0.0174367303520582, - -0.013712887832598355, - 0.021126195472609306, - -0.015975502568585168, - -0.0028846246835958628, - 0.0020442019813950094, - -0.03266377372798207, - 0.03286550024138621, - -0.013566458567642754, - 0.008925403172101965, - 0.02019472487087253, - 0.009589831559328486, - -0.028183039372319 + 0.015231685126095648, + -0.015421076522561732, + -0.005828944561737012, + -0.027782168227463364, + -0.0295299254606757, + -0.005044688697438994, + -0.027131385829086006, + -0.014967919098312174, + -0.018483150909256105, + -0.01901799646632892, + 0.0019843111862389486, + 0.006493388405783137, + -0.006212982280698071, + -0.02150559103893539, + -0.02260473952565378, + 0.0049281533846481825, + -0.02442570645715391, + -0.01064156355108878, + -0.004259935401689829, + -0.008386295263997364, + -0.018015115574451734, + 0.0007205377946915463, + -0.013088402487870042, + -0.014816805743643826, + -0.03829371114792571, + 0.00926588506800341, + -0.03464258186869676, + -0.008082039196945013, + 0.011364869658653243, + 0.0009835337264478081, + 0.011194546620714103, + -0.03700414327288421, + 0.008013545770845095, + -0.014911745323678609, + -0.016557630340603, + 0.015794064646605606, + -0.003318999321802166, + -0.023694428109623075, + 0.0037830236465225455, + 0.022283588751957895, + 0.020259778298767666, + -0.007782211098720479, + 0.001633485541469888, + -0.00336730837946193, + -0.00196576220854981, + -0.0037485742600287713, + -0.023103696888923084, + 
-0.02538158359937403, + 0.010877528444466239, + 0.003396870126202189, + -0.00931576420516173, + -0.02393598933865847, + -0.022060080451219353, + 0.007751428520500037, + -0.017131675982750533, + -0.016893899230157834, + 0.0016655963268504305, + 0.014312087552817005, + 0.01972723784114364, + -0.005721194774450792, + 0.022818103401221923, + 0.008053488228865244, + 0.018756167519343916, + -0.013320754164637287, + 0.034324781399503214, + -0.024136787863848923, + -0.009247336090314272, + -0.01918941549182599, + 0.010341943198896466, + -0.038897981926843614, + 0.010881436382772332, + -0.01566089827165408, + 0.020615764532841765, + 0.013009030249356294, + 0.0073269433667188175, + 0.006450426551069444, + -0.002280843953683123, + 0.020357206907443796, + 0.020805411773395996, + 0.02742199422986559, + 0.011743537467864502, + -0.0027340598135190738, + 0.011738907246611663, + -0.016267946667548862, + 0.002143508330794054, + 0.007558658619120678, + -0.007106225052230407, + -0.0011209748205877786, + -0.019965748788626093, + 0.003937343011525476, + -0.0213359342001851, + 0.01935601032943401, + -0.010330455512844257, + 0.03782723369570142, + 0.0146552688808932, + 0.013230530856790384, + -0.004987096013140689, + 0.0333764888489525, + -0.00968681496343167, + 0.004587070452381084, + 0.005656357169409813, + -0.023245933365898277, + 0.015162951494691587, + 0.0026953573522491715, + 0.0017848887739108383, + -0.00334948187707321, + -0.021721428589336345, + -0.010634357965284093, + 0.0055700317473139855, + -0.0009479228487899657, + -0.0020297213245332613, + -0.013546639944096802, + -0.018097376394800455, + 0.01808774523221887, + 0.012340078071508173, + -0.020561983875713168, + -0.0017113426742070572, + -0.01261277549337654, + -0.03618135352270031, + 0.02271639240732281, + -0.010989417897220211, + 0.03516725603493444, + -0.015797770140132998, + 0.00854976213952673, + -0.0064595754436547435, + -0.0008266335032858316, + -0.005144033539497858, + -0.03487386294915826, + -0.017215007399405, + 
-0.01813056307774026, + 0.007110999760004685, + -0.01887035231907848, + -0.032472689994759454 ] } ], @@ -3602,138 +3544,138 @@ 132 ], "y": [ + 0.01770335448356722, -0.002592261998183278, -0.002592261998183278, -0.03949338287409329, + 0.03430885887772673, + -0.03949338287409329, -0.0763945037500033, - 0.07120997975363674, -0.03949338287409329, + -0.002592261998183278, -0.03949338287409329, -0.03949338287409329, - 0.03430885887772673, + -0.002592261998183278, -0.03949338287409329, -0.03949338287409329, -0.0763945037500033, -0.0763945037500033, - -0.0763945037500033, - -0.03949338287409329, - 0.01585829843977173, - -0.0763945037500033, - -0.03949338287409329, - -0.002592261998183278, - -0.002592261998183278, - 0.0029429061332032365, - -0.002592261998183278, 0.07120997975363674, - -0.03949338287409329, 0.07120997975363674, -0.0018542395806650938, - -0.002592261998183278, 0.03430885887772673, - -0.047980640675552584, - -0.002592261998183278, -0.002592261998183278, + -0.047980640675552584, 0.03430885887772673, -0.03949338287409329, + -0.002592261998183278, + -0.03949338287409329, -0.03949338287409329, -0.0763945037500033, -0.03949338287409329, - -0.002592261998183278, - 0.05017634085436802, + 0.03430885887772673, + 0.03430885887772673, -0.03949338287409329, 0.03430885887772673, + 0.10811110062954676, + 0.03430885887772673, 0.07120997975363674, - -0.002592261998183278, - -0.002592261998183278, -0.03949338287409329, -0.03949338287409329, + 0.03430885887772673, -0.002592261998183278, -0.002592261998183278, - -0.002592261998183278, + 0.13025177315509276, 0.07120997975363674, 0.02545258986750832, - 0.07120997975363674, -0.03949338287409329, - 0.03910600459159503, - -0.03949338287409329, - 0.08006624876385515, - -0.0763945037500033, + 0.10811110062954676, + -0.007020396503292483, -0.03949338287409329, + 0.03430885887772673, -0.03949338287409329, + 0.03430885887772673, + 0.03430885887772673, + 0.10811110062954676, + -0.002592261998183278, + 0.07120997975363674, 
-0.002592261998183278, -0.021411833644897377, + 0.03430885887772673, + 0.03430885887772673, + 0.03430885887772673, + 0.03430885887772673, + 0.03430885887772673, -0.03949338287409329, + -0.03949338287409329, + -0.002592261998183278, 0.03430885887772673, -0.03949338287409329, - 0.14501222150545676, + -0.002592261998183278, + 0.03430885887772673, -0.03949338287409329, -0.05056371913686628, 0.07194800217115493, - 0.056080520194513636, - -0.03949338287409329, + -0.012555564634678981, + 0.012906208769698923, + 0.07120997975363674, + 0.10811110062954676, + 0.08080427118137334, + -0.0763945037500033, 0.03430885887772673, 0.07120997975363674, - -0.03949338287409329, - -0.03949338287409329, 0.07120997975363674, -0.03949338287409329, - 0.005156973385757823, - -0.0763945037500033, - -0.002592261998183278, - -0.002592261998183278, + 0.03430885887772673, + -0.03949338287409329, -0.002592261998183278, - -0.0763945037500033, - 0.07120997975363674, -0.002592261998183278, -0.0708593356186168, - -0.002592261998183278, + 0.0003598276718895252, -0.03949338287409329, 0.07120997975363674, - -0.06938329078358041, - -0.002592261998183278, + 0.07120997975363674, + 0.03430885887772673, -0.03949338287409329, 0.14132210941786577, + 0.15534453535071155, 0.10811110062954676, - 0.03430885887772673, - -0.03949338287409329, - -0.03949338287409329, - 0.14501222150545676, + 0.08486339447772344, -0.002592261998183278, -0.002592261998183278, 0.07120997975363674, + 0.03430885887772673, + -0.002592261998183278, -0.002592261998183278, 0.03430885887772673, + -0.03949338287409329, + 0.03430885887772673, -0.002592261998183278, -0.002592261998183278, - -0.0763945037500033, - 0.028404679537581124, + 0.03430885887772673, -0.002592261998183278, 0.07120997975363674, - 0.03430885887772673, - 0.03430885887772673, -0.03949338287409329, - 0.08670845052151895, - 0.07120997975363674, + -0.03949338287409329, + -0.002592261998183278, 0.03430885887772673, - -0.03395821474270679, -0.002592261998183278, - 
-0.03949338287409329, -0.002592261998183278, + -0.0763945037500033, -0.03949338287409329, - -0.02583996815000658, + 0.02360753382371283, + 0.07120997975363674, -0.002592261998183278, -0.03949338287409329, - -0.03949338287409329, - 0.003311917341962329, - -0.03949338287409329, + 0.09187460744414634, -0.002592261998183278, - 0.10811110062954676, - -0.0763945037500033, 0.03430885887772673, + -0.002592261998183278, + -0.0763945037500033, -0.024732934523729287, + 0.03430885887772673, 0.023238522614953735, - -0.002592261998183278, - -0.011079519799642579, + -0.03949338287409329, -0.03949338287409329 ] }, @@ -3881,139 +3823,139 @@ 132 ], "y": [ - 0.036173881319049764, - -0.021702339029765656, - -0.004827344613356406, - -0.04585371324963408, - 0.021657333531173127, - -0.019557492599216835, - -0.016391484441786622, - 0.004826645265102516, - 0.026677561541063308, - -0.020494778395885635, - -0.019045046146368566, - 0.020546868283924196, - -0.03767286475161845, - -0.025381010739657484, - -0.010055873741307193, - 0.006350417632745559, - -0.019896188609062136, - 0.006649615916558177, - 0.02621139123628575, - -0.014090746522772122, - -0.001262154321621097, - 0.011773637398311868, - 0.011258135449664754, - 0.008944049563999953, - 0.006999038596481625, - 0.012591280717550547, - 0.006376848293006614, - 0.009152159296335255, - 0.024846242702600232, - -0.030649754557607656, - -0.03226696508115764, - 0.007357869247591211, - -0.012233157667270823, - -0.005235643701531674, - -0.019963033597623834, - -0.02294552944805546, - -0.004610187191359073, - 0.014874207825065582, - -0.02541069014530403, - 0.0007426496564170359, - 0.04671645093525994, - 0.011725793349133122, - 0.046377763208378976, - -0.04412613967921996, - -0.003307935367806759, - 0.010059493995214983, - 0.057837475697029336, - 0.02464075426683707, - 0.03232886503700636, - 0.015181239790241672, - 0.007730755650970042, - -0.01288826786920192, - 0.019209494823279768, - 0.017901355466069357, - 0.0010382296741955493, - 
-0.03525362630963111, - -0.0010484738172200827, - -0.034411594412093106, - -0.029311390546604275, - -0.026183464604745584, - 0.006211365730593833, - 0.007955487839617569, - 0.009444435976199097, - 0.029891333441255042, - 0.006322089058982078, - 0.01240042724866708, - 0.03372764125476694, - 0.01563560428520054, - 0.004682485638364525, - -0.005096275037787555, - 0.016834842476349577, - 0.02777791931447148, - -0.04590318168669977, - 0.039194821530438365, - -0.03277665344423763, - -0.02034178179038435, - -0.021864092307613023, - 0.022562752824849954, - 0.011182533788756978, - 0.0013496505057414734, - -0.03312248173854206, - 0.026222471799173384, - 0.0035463695703147743, - -0.018847397811071878, - 0.006249373998074355, - -0.027683935756017426, - 0.00739858101231083, - -0.02736892056924086, - 0.0003022669656364365, - -0.0007816308585437103, - 0.01700147809620549, - -0.0061187072597464025, - 0.014402384577943744, - -0.0046736485052005, - -0.02700606133328556, - 0.014764184029720314, - -0.013134114176875222, - 0.028808489012539, - 0.02254331768624029, - 0.005034238597544158, - 0.02928319889004928, - 0.0029064231086462057, - -0.001784491148696758, - -0.00893612903310391, - 0.0030842824733004984, - -0.021545185329186856, - 0.056749951384627524, - 0.07189639275286117, - 0.015165622525579, - -0.02023890492078882, - 0.011598392971903482, - 0.020595548628953693, - 0.01685751810776878, - 0.0014485791272968842, - 0.012291320055589793, - 0.012358662118871731, - -0.012061731114809418, - 0.011744440395706587, - -0.009350081063623469, - 0.02917756393068547, - 0.006130601782959003, - -0.01613858615607142, - 0.017689170808935114, - -0.032720269822326155, - 0.03713214633579796, - 0.003601191679832815, - -0.03355297405509526, - 0.051788117379947894, - -0.02597487243898064, - 0.008109666350814736, - 0.0271125648455035, - 0.013573789266825956, - -0.04185854704123922 + 0.013791087535585032, + -0.0067415479818104754, + 0.008122595600109256, + -0.007257066606169001, + -0.0025554928421562782, + 
-0.01730642961557491, + -0.03973989461680715, + -0.035215263664104976, + 0.012493718436498158, + -0.023285725079843646, + -0.000379727084753736, + -0.019743775483154393, + -0.009392833506987743, + -0.035302615781131336, + -0.03905225690190696, + -0.03256533884487734, + 0.005676995172456719, + -0.003969962854634164, + 0.006605608754211003, + 0.0022337245636956634, + -0.003705667926297255, + 0.0160131771773543, + 0.0012497484695890783, + -0.014183640472615826, + -0.024897592873238887, + -0.01193835158341511, + -0.011352501675463188, + -0.02211270494407038, + -0.01327963980677072, + 0.014679377827269066, + 0.028201762472852287, + -0.0497156513407594, + 0.015564039073432956, + 0.0059047509621594245, + -0.0002891958481750804, + 0.0413585013037903, + 0.010549489171225434, + -0.04074701058422851, + -0.00726694018568288, + 0.0049352105579747434, + 0.010367920997501279, + 0.00007617849551178838, + 0.0272796380000965, + 0.013649300291033686, + -0.015212767758098487, + 0.00543914128208987, + -0.03867304908811106, + -0.04146073514129962, + 0.037180730183873835, + -0.021689236815534446, + 0.006731851807441467, + -0.004709125236394509, + 0.015675703756764758, + -0.014059228838315884, + 0.006718223688295505, + -0.0325135193563347, + -0.029345864135223824, + 0.016299251273056356, + 0.020736592779718496, + 0.01752768688371003, + 0.0036211377320435026, + 0.00582584141816508, + -0.006378574065545362, + -0.03888936237339751, + 0.0046740506114529325, + -0.04102493363171138, + -0.0036541432594371158, + -0.031746617901640996, + 0.02287841546841343, + -0.017455500382433535, + 0.004170520880805757, + 0.02950946182329947, + 0.004286409933083225, + 0.00949468220189915, + 0.00848214300974643, + -0.0003506567655354102, + 0.00622718589199493, + -0.023382695759762345, + 0.052460227354301635, + 0.04570683846266361, + 0.008142368340912764, + 0.003042472748240509, + -0.009428181419382162, + -0.03414234589319587, + 0.004322787221553486, + 0.008760128728215626, + -0.021232834359493004, + 
0.012613683896552527, + 0.016963752411511428, + -0.0011747844791662968, + -0.006854960778212722, + 0.03271464639099751, + -0.020651434806323598, + 0.013655494578349706, + 0.0325822587161095, + 0.01240140962718438, + 0.004465601030109772, + 0.033314444821553636, + 0.005823579974981913, + 0.02011640909906851, + 0.006129851307818842, + -0.03152688781818483, + 0.02140947382021927, + 0.014218926009913056, + -0.018311539156434627, + 0.02234257426494738, + -0.001616018921944109, + 0.0205886244053183, + 0.013896438117549436, + -0.024496748572054607, + 0.05235423105172521, + -0.032992757155550216, + -0.024303522197554835, + 0.01212643354855765, + 0.03392301443916518, + -0.006049284904496647, + -0.007631149465924967, + -0.028761809138011133, + -0.04800248232671432, + -0.02509379567625947, + -0.00009153975376653429, + 0.023345773261553425, + -0.01910895717320799, + 0.017869477687563542, + -0.02194514854495601, + -0.024607711275376646, + 0.029117388128151366, + -0.035875224977058126, + -0.029496018124353707, + -0.03515240882947335, + 0.006314919004553546, + -0.03774485237314226, + -0.044343789309302856 ] } ], @@ -4919,13 +4861,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5dfb811d4a2b457c80f89229f35c4d88", + "model_id": "4798c8e161a24d56ac7449decf7a734f", "version_major": 2, "version_minor": 0 }, @@ -4941,11 +4883,12 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 5.7s finished\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.6s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 2.9s finished\n", - "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 4.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 10.3s 
finished\n" + "[Parallel(n_jobs=-1)]: Batch computation too fast (0.1265852451324463s.) Setting batch_size=2.\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 0.1s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] }, { @@ -4976,28 +4919,28 @@ " \n", " \n", " 0\n", - " 0.020909\n", - " 0.036174\n", + " 0.015232\n", + " 0.013791\n", " \n", " \n", " 1\n", - " 0.004626\n", - " -0.021702\n", + " -0.015421\n", + " -0.006742\n", " \n", " \n", " 2\n", - " -0.018766\n", - " -0.004827\n", + " -0.005829\n", + " 0.008123\n", " \n", " \n", " 3\n", - " -0.025080\n", - " -0.045854\n", + " -0.027782\n", + " -0.007257\n", " \n", " \n", " 4\n", - " 0.000522\n", - " 0.021657\n", + " -0.029530\n", + " -0.002555\n", " \n", " \n", " ...\n", @@ -5006,28 +4949,28 @@ " \n", " \n", " 128\n", - " -0.013566\n", - " -0.025975\n", + " -0.017215\n", + " -0.029496\n", " \n", " \n", " 129\n", - " 0.008925\n", - " 0.008110\n", + " -0.018131\n", + " -0.035152\n", " \n", " \n", " 130\n", - " 0.020195\n", - " 0.027113\n", + " 0.007111\n", + " 0.006315\n", " \n", " \n", " 131\n", - " 0.009590\n", - " 0.013574\n", + " -0.018870\n", + " -0.037745\n", " \n", " \n", " 132\n", - " -0.028183\n", - " -0.041859\n", + " -0.032473\n", + " -0.044344\n", " \n", " \n", "\n", @@ -5036,22 +4979,22 @@ ], "text/plain": [ " s1 s4\n", - "0 0.020909 0.036174\n", - "1 0.004626 -0.021702\n", - "2 -0.018766 -0.004827\n", - "3 -0.025080 -0.045854\n", - "4 0.000522 0.021657\n", + "0 0.015232 0.013791\n", + "1 -0.015421 -0.006742\n", + "2 -0.005829 0.008123\n", + "3 -0.027782 -0.007257\n", + "4 -0.029530 -0.002555\n", ".. ... 
...\n", - "128 -0.013566 -0.025975\n", - "129 0.008925 0.008110\n", - "130 0.020195 0.027113\n", - "131 0.009590 0.013574\n", - "132 -0.028183 -0.041859\n", + "128 -0.017215 -0.029496\n", + "129 -0.018131 -0.035152\n", + "130 0.007111 0.006315\n", + "131 -0.018870 -0.037745\n", + "132 -0.032473 -0.044344\n", "\n", "[133 rows x 2 columns]" ] }, - "execution_count": 26, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -5093,13 +5036,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "99f2a175dec14d80be25b0042cf5fe39", + "model_id": "4e0b2ffc1515464dbc2e2f23be9f28ff", "version_major": 2, "version_minor": 0 }, @@ -5115,20 +5058,21 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.2s finished\n", + "[Parallel(n_jobs=-1)]: Batch computation too fast (0.1340169906616211s.) 
Setting batch_size=2.\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.1s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.2s finished\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.4s finished\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 1.0s finished\n", - "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 3.4s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 9.3s finished\n" + "[Parallel(n_jobs=-1)]: Done 3 out of 3 | elapsed: 0.8s finished\n", + "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Imputation results available for models: dict_keys(['best_method', 'QRF', 'QuantReg', 'Matching'])\n", + "Imputation results available for models: dict_keys(['best_method', 'QuantReg', 'QRF', 'Matching'])\n", "The best performing model is: OLSResults\n" ] } @@ -5157,7 +5101,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/autoimpute/index.md b/docs/autoimpute/index.md index 1600464..d85a15c 100644 --- a/docs/autoimpute/index.md +++ b/docs/autoimpute/index.md @@ -1,5 +1,5 @@ # Autoimpute -This chapter describes how the `autoimpute` function works to automize the entire method comparison, selection, and imputation pipeline in a single function. +The `autoimpute` function automates the entire method comparison, selection, and imputation pipeline in a single call. -The pipeline begins with input validation to ensure all necessary columns exist and quantiles are properly specified. It then preprocesses the donor and receiver datasets to prepare them for model training and evaluation. 
The function supports imputing numerical, categorical and boolean variable types, internally selecting the method corresponding to each variable type. At its core, `autoimpute` employs cross-validation on the donor data to evaluate multiple imputation methods. Each model is assessed on its ability to accurately predict known values using two different metrics: quantile loss for numerical imputation and log loss for categorical imputation. The method with the lowest average loss (with different metrics combined with a weighted-rank approach) across target variables is automatically selected as the optimal approach for the specific dataset and imputation task. The chosen model is then trained on the complete donor dataset and applied to generate imputations for the missing values in the receiver data. Finally, the pipeline reintegrates these imputed values back into the original receiver dataset, producing a complete dataset ready for downstream analysis. +The pipeline begins with input validation, then preprocesses the donor and receiver datasets for model training and evaluation. It supports numerical, categorical, and boolean variable types, selecting the appropriate method for each. At its core, `autoimpute` runs cross-validation on the donor data to evaluate multiple imputation methods. Each model is scored using quantile loss for numerical variables and log loss for categorical variables. The method with the lowest average loss (combining different metrics via a weighted-rank approach) across target variables is selected automatically. That model is then trained on the full donor dataset and applied to generate imputations for the receiver. The result is an `AutoImputeResult` object containing the imputations, the augmented receiver dataset, fitted models, and cross-validation results. 
diff --git a/docs/imputation-benchmarking/benchmarking-methods.ipynb b/docs/imputation-benchmarking/benchmarking-methods.ipynb index 9f9c3aa..6881559 100644 --- a/docs/imputation-benchmarking/benchmarking-methods.ipynb +++ b/docs/imputation-benchmarking/benchmarking-methods.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "45bc35df", "metadata": {}, "outputs": [ @@ -29,10 +29,10 @@ "data": [ { "alignmentgroup": "True", - "hovertemplate": "Method=QRF
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=QRF
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "QRF", "marker": { - "color": "#636EFA", + "color": "#88CCEE", "pattern": { "shape": "" } @@ -166,10 +166,10 @@ }, { "alignmentgroup": "True", - "hovertemplate": "Method=OLS
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=OLS
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "OLS", "marker": { - "color": "#EF553B", + "color": "#CC6677", "pattern": { "shape": "" } @@ -303,10 +303,10 @@ }, { "alignmentgroup": "True", - "hovertemplate": "Method=QuantReg
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=QuantReg
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "QuantReg", "marker": { - "color": "#00CC96", + "color": "#DDCC77", "pattern": { "shape": "" } @@ -440,10 +440,10 @@ }, { "alignmentgroup": "True", - "hovertemplate": "Method=Matching
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=Matching
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "Matching", "marker": { - "color": "#AB63FA", + "color": "#117733", "pattern": { "shape": "" } @@ -585,12 +585,12 @@ }, "tracegroupgap": 0 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "shapes": [ { "line": { - "color": "#636EFA", + "color": "#88CCEE", "dash": "dot", "width": 2 }, @@ -603,7 +603,7 @@ }, { "line": { - "color": "#EF553B", + "color": "#CC6677", "dash": "dot", "width": 2 }, @@ -616,7 +616,7 @@ }, { "line": { - "color": "#00CC96", + "color": "#DDCC77", "dash": "dot", "width": 2 }, @@ -629,7 +629,7 @@ }, { "line": { - "color": "#AB63FA", + "color": "#117733", "dash": "dot", "width": 2 }, @@ -1470,7 +1470,11 @@ 0, 1 ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "font": { "size": 12 @@ -1485,12 +1489,16 @@ 0, 1 ], - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "font": { "size": 12 }, - "text": "Test Quantile loss" + "text": "Quantile loss" }, "zeroline": false } @@ -1516,7 +1524,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "757ba887bc594faba40db29be449fc8a", + "model_id": "e01f756472f74ff7b937e4cb30273e97", "version_major": 2, "version_minor": 0 }, @@ -1530,7 +1538,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c4ac70f725d34a4ab138535bfefbf164", + "model_id": "c2b8d34026e0466893d9871472ee7088", "version_major": 2, "version_minor": 0 }, @@ -1685,10 +1693,10 @@ "data": [ { "alignmentgroup": "True", - "hovertemplate": "Method=QRF
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=QRF
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "QRF", "marker": { - "color": "#636EFA", + "color": "#88CCEE", "pattern": { "shape": "" } @@ -1784,10 +1792,10 @@ }, { "alignmentgroup": "True", - "hovertemplate": "Method=OLS
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=OLS
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "OLS", "marker": { - "color": "#EF553B", + "color": "#CC6677", "pattern": { "shape": "" } @@ -1883,10 +1891,10 @@ }, { "alignmentgroup": "True", - "hovertemplate": "Method=QuantReg
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=QuantReg
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "QuantReg", "marker": { - "color": "#00CC96", + "color": "#DDCC77", "pattern": { "shape": "" } @@ -1982,10 +1990,10 @@ }, { "alignmentgroup": "True", - "hovertemplate": "Method=Matching
Quantiles=%{x}
Test Quantile loss=%{y}", + "hovertemplate": "Method=Matching
Quantiles=%{x}
Quantile loss=%{y}", "legendgroup": "Matching", "marker": { - "color": "#AB63FA", + "color": "#117733", "pattern": { "shape": "" } @@ -2038,44 +2046,44 @@ ], "xaxis": "x", "y": [ - 12257049.472826088, - 12257049.472826088, - 11844892.22826087, - 11844892.22826087, - 11432734.983695652, - 11432734.983695652, - 11020577.739130436, - 11020577.739130436, - 10608420.494565217, - 10608420.494565217, - 10196263.25, - 10196263.25, - 9784106.005434783, - 9784106.005434783, - 9371948.760869564, - 9371948.760869564, - 8959791.516304348, - 8959791.516304348, - 8547634.27173913, - 8547634.27173913, - 8135477.027173913, - 8135477.027173913, - 7723319.782608695, - 7723319.782608695, - 7311162.5380434785, - 7311162.5380434785, - 6899005.293478261, - 6899005.293478261, - 6486848.048913044, - 6486848.048913044, - 6074690.804347826, - 6074690.804347826, - 5662533.559782608, - 5662533.559782608, - 5250376.315217392, - 5250376.315217392, - 4838219.070652175, - 4838219.070652175 + 12257052.081521738, + 12257052.081521738, + 11844897.445652174, + 11844897.445652174, + 11432742.80978261, + 11432742.80978261, + 11020588.173913043, + 11020588.173913043, + 10608433.538043479, + 10608433.538043479, + 10196278.902173912, + 10196278.902173912, + 9784124.266304348, + 9784124.266304348, + 9371969.630434783, + 9371969.630434783, + 8959814.994565217, + 8959814.994565217, + 8547660.358695652, + 8547660.358695652, + 8135505.722826087, + 8135505.722826087, + 7723351.0869565215, + 7723351.0869565215, + 7311196.451086956, + 7311196.451086956, + 6899041.815217392, + 6899041.815217392, + 6486887.179347826, + 6486887.179347826, + 6074732.543478261, + 6074732.543478261, + 5662577.907608695, + 5662577.907608695, + 5250423.271739131, + 5250423.271739131, + 4838268.635869565, + 4838268.635869565 ], "yaxis": "y" } @@ -2089,12 +2097,12 @@ }, "tracegroupgap": 0 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "shapes": [ { "line": { - 
"color": "#636EFA", + "color": "#88CCEE", "dash": "dot", "width": 2 }, @@ -2107,7 +2115,7 @@ }, { "line": { - "color": "#EF553B", + "color": "#CC6677", "dash": "dot", "width": 2 }, @@ -2120,7 +2128,7 @@ }, { "line": { - "color": "#00CC96", + "color": "#DDCC77", "dash": "dot", "width": 2 }, @@ -2133,7 +2141,7 @@ }, { "line": { - "color": "#AB63FA", + "color": "#117733", "dash": "dot", "width": 2 }, @@ -2141,8 +2149,8 @@ "type": "line", "x0": -0.5, "x1": 18.5, - "y0": 8547634.27173913, - "y1": 8547634.27173913 + "y0": 8547660.35869565, + "y1": 8547660.35869565 } ], "template": { @@ -2974,7 +2982,11 @@ 0, 1 ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "font": { "size": 12 @@ -2989,12 +3001,16 @@ 0, 1 ], - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "font": { "size": 12 }, - "text": "Test Quantile loss" + "text": "Quantile loss" }, "zeroline": false } @@ -3019,7 +3035,7 @@ "0 networth wasserstein_distance 2.818652e+07\n", "Model: Matching, distribution similarity: \n", " Variable Metric Distance\n", - "0 networth wasserstein_distance 1.798087e+07\n" + "0 networth wasserstein_distance 1.798092e+07\n" ] } ], @@ -3232,7 +3248,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/imputation-benchmarking/cross-validation.md b/docs/imputation-benchmarking/cross-validation.md index 5e440e8..8da745f 100644 --- a/docs/imputation-benchmarking/cross-validation.md +++ b/docs/imputation-benchmarking/cross-validation.md @@ -1,14 +1,11 @@ # Cross-validation and model imputation comparison -This page documents the cross-validation utilities for evaluating imputation model performance. 
Cross-validation provides robust -estimates of how well a model will generalize to unseen data by training and testing on multiple data splits. Functions like `get_imputations`, will then build upon it, to standardize evaluation for all models, ensuring possible through a consistent experimental setup. +This page documents the cross-validation utilities for evaluating imputation model performance. Cross-validation estimates how well a model generalizes to unseen data by training and testing on multiple splits. The `get_imputations` function builds on this to standardize evaluation across all models. -Microimpute's cross-validation automatically selects the appropriate metric based on variable type. Numerical variables are evaluated using quantile loss, which measures prediction accuracy across the conditional distribution. Categorical variables are evaluated using log loss (cross-entropy), which penalizes confident but incorrect predictions, see the [Metrics page](./metrics.md) for more details. +Microimpute's cross-validation selects the metric based on variable type: quantile loss for numerical variables, log loss for categorical variables. See the [Metrics page](./metrics.md) for details. ## Cross-validation -Cross-validation provides robust estimates of how well a model will generalize to unseen data by training and testing on multiple data splits. Microimpute's cross-validation automatically selects the appropriate metric based on variable type: quantile loss for numerical variables and log loss for categorical variables. - ### cross_validate_model ```python @@ -98,7 +95,7 @@ For model selection, focus on the test loss (`mean_test`). When comparing multip ## Imputation generation for model comparison -The `get_imputations` function generates imputations using cross-validation for multiple model classes in a single call, organizing results in a consistent format for downstream comparison and evaluation. 
+The `get_imputations` function generates imputations for multiple model classes in a single call, organizing results in a consistent format for comparison. ### get_imputations diff --git a/docs/imputation-benchmarking/index.md b/docs/imputation-benchmarking/index.md index 337e1c5..d8b9075 100644 --- a/docs/imputation-benchmarking/index.md +++ b/docs/imputation-benchmarking/index.md @@ -1,5 +1,5 @@ # Benchmarking different imputation methods -This chapter describes how the Microimpute package allows you to compare different imputation methods using preprocessing, cross-validation, metric comparison, and evaluation tools. +This chapter describes how microimpute lets you compare imputation methods using preprocessing, cross-validation, metric comparison, and evaluation tools. -The benchmarking functionality enables systematically comparing multiple imputation models using a common dataset, allowing for robust evaluation of their performance. It supports cross-validation to diagnose overfitting and measure performance on training data leveraging the availability of ground truth. By assessing accuracy of numeric imputation across various quantiles, it is possible to gain a more comprehensive understanding of how each method performs across different levels of the distribution. Categorical imputation is assessed with log loss. This process is further supported by visualizations that highlight differences between approaches, making it easy to identify which imputation methods perform best under specific conditions. Predictor evaluation tools are also available to inform decision-making when setting up the imputation task. +The benchmarking functionality supports systematic comparison of multiple models on a common dataset. Cross-validation diagnoses overfitting and measures performance on held-out data where ground truth is available. Numerical imputation accuracy is assessed across quantiles via quantile loss; categorical imputation uses log loss. 
Visualizations show differences between methods, and predictor evaluation tools help inform variable selection when setting up the imputation task. diff --git a/docs/imputation-benchmarking/metrics.md b/docs/imputation-benchmarking/metrics.md index cf2f6a3..32f429f 100644 --- a/docs/imputation-benchmarking/metrics.md +++ b/docs/imputation-benchmarking/metrics.md @@ -1,14 +1,14 @@ # Metrics and evaluation -This page documents the evaluation metrics and predictor analysis tools available for assessing imputation quality. These utilities help understand model performance, compare methods, and analyze the contribution of individual predictors. +This page documents the evaluation metrics and predictor analysis tools for assessing imputation quality. ## Loss metrics -Microimpute employs evaluation metrics tailored to the type of variable being imputed. The framework automatically selects the appropriate metric based on whether the imputed variable is numerical or categorical, ensuring meaningful performance assessment across different data types. +Microimpute selects the evaluation metric based on whether the imputed variable is numerical or categorical. ### Quantile loss -Quantile loss assesses imputation quality for numerical variables. This approach provides a more nuanced evaluation than traditional metrics like mean squared error, particularly for capturing performance across different parts of the distribution. +Quantile loss assesses imputation quality for numerical variables. It captures performance across different parts of the distribution, unlike mean squared error which only measures average accuracy. The quantile loss implements the standard pinball loss formulation: @@ -62,7 +62,7 @@ When predictions are class labels rather than probabilities, the function conver ### compute_loss -A unified function that selects the appropriate loss metric based on the specified type, providing a consistent interface for both numerical and categorical evaluation. 
+A unified function that selects the loss metric based on the specified type. ```python def compute_loss( @@ -86,7 +86,7 @@ Returns a tuple of `(element_wise_losses, mean_loss)`. ### compare_metrics -Compares metrics across multiple imputation methods, automatically detecting variable types and applying the appropriate metric. For models that handle both numerical and categorical variables, the evaluation produces separate results for each metric type. +Compares metrics across multiple imputation methods, detecting variable types and applying the appropriate metric. For models that handle both numerical and categorical variables, results are produced separately for each metric type. ```python def compare_metrics( @@ -106,7 +106,7 @@ Returns a DataFrame with columns `Method`, `Imputed Variable`, `Percentile`, `Lo ## Distribution comparison -Beyond point-wise loss metrics, evaluating how well imputed values preserve distributional characteristics provides insight into whether the imputation maintains the statistical properties of the original data. +Evaluating how well imputed values preserve distributional characteristics tells you whether the imputation maintains the statistical properties of the original data. ### Wasserstein distance @@ -126,7 +126,7 @@ $$D_{KL}(P||Q) = \sum_{x \in \mathcal{X}} P(x) \log\left(\frac{P(x)}{Q(x)}\right where $P$ is the reference distribution (original data), $Q$ is the approximation (imputed data), and $\mathcal{X}$ is the set of all possible categorical values. KL divergence measures how much information is lost when using the imputed distribution to approximate the true distribution. Lower values indicate better preservation of the original categorical distribution. -When sample weights are provided, the probability distributions are computed as weighted proportions rather than simple counts, ensuring proper comparison of weighted survey data. 
+When sample weights are provided, the probability distributions are computed as weighted proportions rather than simple counts, so that weighted survey data can be compared correctly. ### kl_divergence @@ -178,11 +178,11 @@ Note that data must not contain null or infinite values. If your data contains s ## Predictor analysis -Understanding which predictors contribute most to imputation quality helps with feature selection and model interpretation. These tools analyze predictor-target relationships and evaluate sensitivity to predictor selection. +These tools analyze which predictors contribute most to imputation quality, helping with feature selection and model interpretation. ### Mutual information -Mutual information measures the reduction in uncertainty about one variable given knowledge of another. Unlike correlation coefficients that capture only linear relationships, mutual information detects any statistical dependency, making it valuable for mixed data types. +Mutual information measures the reduction in uncertainty about one variable given knowledge of another. Unlike correlation coefficients, which capture only linear relationships, mutual information detects any statistical dependency. For discrete random variables $X$ and $Y$: @@ -200,21 +200,23 @@ where $H(X)$ and $H(Y)$ are the entropies of $X$ and $Y$ respectively. 
Normalize def compute_predictor_correlations( data: pd.DataFrame, predictors: List[str], - imputed_variables: List[str], + imputed_variables: Optional[List[str]] = None, + method: str = "all", ) -> Dict[str, pd.DataFrame] ``` -| Parameter | Type | Description | -|-----------|------|-------------| -| data | pd.DataFrame | Dataset containing predictors and target variables | -| predictors | List[str] | Column names of predictor variables | -| imputed_variables | List[str] | Column names of target variables | +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| data | pd.DataFrame | - | Dataset containing predictors and target variables | +| predictors | List[str] | - | Column names of predictor variables | +| imputed_variables | List[str] | None | Column names of target variables | +| method | str | "all" | Which correlation method to use: "all", "mi" (mutual information), "pearson", or "spearman" | -Returns a dictionary containing `predictor_target_mi` DataFrame with mutual information scores. +Returns a dictionary containing DataFrames with correlation scores (e.g. `predictor_target_mi` for mutual information). ### Leave-one-out analysis -Leave-one-out predictor analysis evaluates model performance when each predictor is excluded. By comparing loss with and without each predictor, you can assess its contribution to imputation quality. Predictors whose removal causes large increases in loss are most important, while those with minimal impact might be candidates for removal to simplify the model. +Leave-one-out predictor analysis evaluates model performance when each predictor is excluded. Predictors whose removal causes large increases in loss are the most important, while those with minimal impact might be candidates for removal. 
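
The leave-one-out idea described above can be sketched outside the package. This is a minimal illustration only — it uses a NumPy least-squares fit and mean absolute error as stand-ins, not microimpute's models or its quantile-loss machinery:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: y depends strongly on x0, weakly on x1, not at all on x2
X = rng.normal(size=(500, 3))
y = 2.0 * X[:, 0] + 0.3 * X[:, 1] + rng.normal(scale=0.5, size=500)

def fit_predict_loss(X_train, y_train, X_test, y_test):
    """Least-squares fit with intercept; return mean absolute error."""
    A = np.column_stack([np.ones(len(X_train)), X_train])
    coef, *_ = np.linalg.lstsq(A, y_train, rcond=None)
    pred = np.column_stack([np.ones(len(X_test)), X_test]) @ coef
    return np.mean(np.abs(y_test - pred))

split = 400
baseline = fit_predict_loss(X[:split], y[:split], X[split:], y[split:])

# Drop each predictor in turn and record the loss increase
increases = {}
for j in range(X.shape[1]):
    X_drop = np.delete(X, j, axis=1)
    loss = fit_predict_loss(X_drop[:split], y[:split], X_drop[split:], y[split:])
    increases[f"x{j}"] = loss - baseline

print(max(increases, key=increases.get))  # 'x0' — the strongest predictor
```

Dropping `x0` produces by far the largest loss increase, while dropping the irrelevant `x2` leaves the loss essentially unchanged — exactly the signal used to rank predictor importance.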
### leave_one_out_analysis @@ -223,9 +225,13 @@ def leave_one_out_analysis( data: pd.DataFrame, predictors: List[str], imputed_variables: List[str], - model_class: Type, - quantiles: Optional[List[float]] = QUANTILES, -) -> Dict[str, Any] + model_class: Type[Imputer], + weight_col: Optional[Union[str, np.ndarray, pd.Series]] = None, + quantiles: List[float] = QUANTILES, + train_size: float = TRAIN_SIZE, + n_jobs: int = 1, + random_state: int = RANDOM_STATE, +) -> pd.DataFrame ``` | Parameter | Type | Default | Description | @@ -233,14 +239,18 @@ def leave_one_out_analysis( | data | pd.DataFrame | - | Complete dataset | | predictors | List[str] | - | Column names of predictor variables | | imputed_variables | List[str] | - | Column names of variables to impute | -| model_class | Type | - | Imputer class to evaluate | +| model_class | Type[Imputer] | - | Imputer class to evaluate | +| weight_col | str, np.ndarray, or pd.Series | None | Sample weights column name or array | | quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles to evaluate | +| train_size | float | 0.8 | Proportion of data for training | +| n_jobs | int | 1 | Number of parallel jobs | +| random_state | int | 42 | Random seed for reproducibility | -Returns a dictionary containing loss increase and relative impact for each predictor. +Returns a DataFrame containing loss increase and relative impact for each predictor. ### Progressive predictor inclusion -Progressive inclusion analysis adds predictors one at a time in order of their mutual information with the target. This greedy forward selection reveals the optimal inclusion order, marginal contribution of each predictor, and the minimal set of predictors achieving near-optimal performance. Diminishing returns in loss reduction indicate when additional predictors provide negligible improvement. +Progressive inclusion adds predictors one at a time, ordered by their mutual information with the target. 
This greedy forward selection reveals the optimal inclusion order and the marginal contribution of each predictor. Diminishing returns in loss reduction indicate when additional predictors add little. ### progressive_predictor_inclusion @@ -249,8 +259,12 @@ def progressive_predictor_inclusion( data: pd.DataFrame, predictors: List[str], imputed_variables: List[str], - model_class: Type, + model_class: Type[Imputer], + weight_col: Optional[Union[str, np.ndarray, pd.Series]] = None, quantiles: Optional[List[float]] = QUANTILES, + train_size: Optional[float] = TRAIN_SIZE, + max_predictors: Optional[int] = None, + random_state: Optional[int] = RANDOM_STATE, ) -> Dict[str, Any] ``` @@ -259,8 +273,12 @@ def progressive_predictor_inclusion( | data | pd.DataFrame | - | Complete dataset | | predictors | List[str] | - | Column names of predictor variables | | imputed_variables | List[str] | - | Column names of variables to impute | -| model_class | Type | - | Imputer class to evaluate | +| model_class | Type[Imputer] | - | Imputer class to evaluate | +| weight_col | str, np.ndarray, or pd.Series | None | Sample weights column name or array | | quantiles | List[float] | [0.05 to 0.95 in steps of 0.05] | Quantiles to evaluate | +| train_size | float | 0.8 | Proportion of data for training | +| max_predictors | int | None | Maximum number of predictors to include (None for all) | +| random_state | int | 42 | Random seed for reproducibility | Returns a dictionary containing `inclusion_order` (list of predictors in optimal order) and `predictor_impacts` (list of dicts with predictor name and loss reduction). 
@@ -282,7 +300,7 @@ metrics_df = compare_metrics( "QRF": qrf_imputations, "OLS": ols_imputations, }, - imputed_variables=imputed_variables + imputed_variables=imputed_variables, ) # Evaluate distributional match with survey weights @@ -295,7 +313,13 @@ dist_df_weighted = compare_distributions( ) # Analyze predictor importance -mi_scores = compute_predictor_correlations(data, predictors, imputed_variables) -loo_results = leave_one_out_analysis(data, predictors, imputed_variables, QRF) -inclusion_results = progressive_predictor_inclusion(data, predictors, imputed_variables, QRF) +mi_scores = compute_predictor_correlations( + data, predictors, imputed_variables, method="mi" +) +loo_results = leave_one_out_analysis( + data, predictors, imputed_variables, QRF, weight_col="wgt" +) +inclusion_results = progressive_predictor_inclusion( + data, predictors, imputed_variables, QRF, weight_col="wgt" +) ``` diff --git a/docs/imputation-benchmarking/preprocessing.md b/docs/imputation-benchmarking/preprocessing.md index caa352c..736d73e 100644 --- a/docs/imputation-benchmarking/preprocessing.md +++ b/docs/imputation-benchmarking/preprocessing.md @@ -1,12 +1,12 @@ # Data preprocessing -Preprocessing transformations can improve model performance by normalizing scale differences or handling skewed distributions. These are supported by `preprocess_data` and transformation-specific functions. +Preprocessing transformations can improve model performance by normalizing scale differences or handling skewed distributions. The main entry point is `preprocess_data`, with transformation-specific functions also available. ## Transformation options Microimpute supports three transformation types that can be applied to numeric columns before training. Each transformation automatically excludes categorical and boolean columns to prevent encoding issues. -**Normalization (z-score)** standardizes data to have mean 0 and standard deviation 1. 
This transformation is useful when predictors have different scales, ensuring that all features contribute equally to distance-based or gradient-based models. +**Normalization (z-score)** standardizes data to have mean 0 and standard deviation 1. This is useful when predictors have different scales, so that all features contribute equally to distance-based or gradient-based models. **Log transformation** applies the natural logarithm to values. This is effective for right-skewed distributions common in financial data like income or wealth. The transformation requires all values to be strictly positive. diff --git a/docs/imputation-benchmarking/visualizations.md b/docs/imputation-benchmarking/visualizations.md index d6c9450..21dde02 100644 --- a/docs/imputation-benchmarking/visualizations.md +++ b/docs/imputation-benchmarking/visualizations.md @@ -159,8 +159,10 @@ For quantile loss, the plot shows train and test loss across quantiles as groupe from microimpute.visualizations import model_performance_results # Visualize cross-validation results for a single model +# Pass the full cv_results dict (not just the inner DataFrame) +# to preserve error bar data from cross-validation folds perf_viz = model_performance_results( - results=cv_results["quantile_loss"]["results"], + results=cv_results, model_name="QRF", method_name="Cross-validation", metric="quantile_loss" diff --git a/docs/index.md b/docs/index.md index 354005d..979a4c6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,21 +1,21 @@ # Microimpute -Microimpute is a powerful framework that enables variable imputation through a variety of statistical methods. By providing a consistent interface across different imputation techniques, it allows researchers and data scientists to easily compare and benchmark different approaches using quantile loss and log loss calculations to determine the method provding most accurate results. 
Thus, Microimpute provides two main uses: imputing one or multiple variables with one of the methods available, and comparing and benchmarking different methods to inform a method's choice. +Microimpute is a Python package for imputing variables from one survey dataset onto another. It wraps five imputation methods behind a common interface so you can benchmark them on your data and pick the one that works best: impute one or multiple variables with any of the available methods, and compare their performance using quantile loss and log loss. -The framework currently supports the following imputation methods: -- Statistical Matching -- Ordinary Least Squares Linear Regression -- Quantile Random Forests +The package currently supports: +- Hot Deck Matching +- Ordinary Least Squares (OLS) Linear Regression +- Quantile Random Forests (QRF) - Quantile Regression -- Mixture Density Networks +- Mixture Density Networks (MDN) -This is a work in progress that may evolve over time, including new statistical imputation methods and features. +This is a work in progress and may evolve over time with new methods and features. ## Microimputation dashboard Users can visualize imputation and benchmarking results at https://microimpute-dashboard.vercel.app/. 
-To use the dashboard for visualization, CSV files must contain the following columns in this exact order: +To use the dashboard, CSV files must contain the following columns in this exact order: - `type`: Type of metric (e.g., "benchmark_loss", "distribution_distance", "predictor_correlation") - `method`: Imputation method name (e.g., "QRF", "OLS", "QuantReg", "Matching", "MDN") - `variable`: Variable being imputed or analyzed @@ -25,4 +25,4 @@ To use the dashboard for visualization, CSV files must contain the following col - `split`: Data split indicator (e.g., "train", "test", "full") - `additional_info`: JSON-formatted string with additional metadata -Users can use the `format_csv()` function from `microimpute.utils` to automatically format imputation and benchmarking results into the correct structure for dashboard visualization. This function accepts outputs from various analysis functions (autoimpute results, comparison metrics, distribution comparisons, etc.) and returns a properly formatted DataFrame. \ No newline at end of file +The `format_csv()` function from `microimpute.utils` formats imputation and benchmarking results into the correct structure for the dashboard. It accepts outputs from various analysis functions (autoimpute results, comparison metrics, distribution comparisons) and returns a properly formatted DataFrame. diff --git a/docs/models/imputer/implement-new-model.md b/docs/models/imputer/implement-new-model.md index 839da1d..edf2125 100644 --- a/docs/models/imputer/implement-new-model.md +++ b/docs/models/imputer/implement-new-model.md @@ -1,15 +1,15 @@ # Creating a new imputer model -This document demonstrates how to create a new imputation model by extending the `Imputer` and `ImputerResults` abstract base classes in Microimpute. +This document shows how to create a new imputation model by extending the `Imputer` and `ImputerResults` abstract base classes. 
-## Understanding the Microimpute architecture +## Architecture -Microimpute uses a two-class architecture for imputation models: +Microimpute uses a two-class architecture: -1. **Imputer**: The base model class that handles model initialization and fitting +1. **Imputer**: Handles model initialization and fitting 2. **ImputerResults**: Represents a fitted model and handles prediction -This separation provides a clean distinction between the model definition and the fitted model instance, similar to statsmodels' approach. Remember to check how currently supported models have been implemented if you would like to ensure full compatibility. +This separation is similar to statsmodels' approach. Look at the existing model implementations for reference. ```python from typing import Dict, List, Optional, Any @@ -178,11 +178,9 @@ The new `NewModel` model is then ready to be integrated into the Microimpute ben ```python from microimpute.models import OLS, QRF -from microimpute.comparisons import ( - get_imputations, - compare_quantile_loss, -) -from microimpute.visualizations.plotting import method_comparison_results +from microimpute.comparisons import get_imputations +from microimpute.comparisons.metrics import compare_metrics +from microimpute.visualizations import method_comparison_results # Define models to compare model_classes = [NewModel, OLS, QRF] @@ -195,45 +193,33 @@ method_imputations = get_imputations( model_classes, X_train, X_test, predictors, imputed_variables ) -# Compare quantile loss -loss_comparison_df = compare_quantile_loss(Y_test, method_imputations, imputed_variables) +# Compare metrics across methods +loss_comparison_df = compare_metrics(Y_test, method_imputations, imputed_variables) # Plot the comparison comparison_viz = method_comparison_results( - data=loss_comparison_df, - metric_name="Test Quantile Loss", - data_format="long", - ) -fig = comparison_viz.plot( - show_mean=True, + data=loss_comparison_df, + metric="quantile_loss", + 
data_format="long", ) +fig = comparison_viz.plot(show_mean=True) fig.show() ``` -## Best practices for implementing new models - -When implementing a new imputation model for Microimpute, adhering to certain best practices will ensure your model integrates seamlessly with the framework and provides a consistent experience for users. +## Best practices ### Architecture -The two-class architecture forms the foundation of a well-designed imputation model. You should create an `Imputer` subclass that handles model definition and fitting operations, establishing the core functionality of your approach. This class should be complemented by an `ImputerResults` subclass that represents the fitted model state and handles all prediction-related tasks. This separation of concerns creates a clean distinction between the fitting and prediction phases of your model's lifecycle. - -Within these classes, you must implement the required abstract methods to fulfill the contract with the base classes. Your `Imputer` subclass should provide a thorough implementation of the `_fit()` method that handles the training process for your specific algorithm. Similarly, your `ImputerResults` subclass needs to implement the `_predict()` method that applies the fitted model to new data and generates predictions at requested quantiles. Check how currently supported models have been implemented if you would like to ensure iterative imputation for multiple target variables fully compatible with how other models do it. For example, to be able to compute quantile loss accross imputed quantiles and variables and compare it with different methods you must ensure iterative imputation within the fitting and predicting methods. +Create an `Imputer` subclass for fitting and an `ImputerResults` subclass for prediction. Implement `_fit()` in the former and `_predict()` in the latter. 
Look at existing models to see how they handle iterative imputation across multiple target variables, which is needed for cross-method comparison via quantile loss. ### Error handling -Robust error handling is crucial for creating reliable imputation models. Your implementation should wrap model fitting and prediction operations in appropriate try/except blocks to capture and handle potential errors gracefully. When exceptions occur, provide informative error messages that help users understand what went wrong and how to address the issue. Use appropriate error types such as ValueError for input validation failures and RuntimeError for operational failures during model execution. - -Effective logging complements good error handling by providing visibility into the model's operation. Use the self.logger instance consistently throughout your code to record important information about the model's state and progress. Log significant events like the start and completion of fitting operations, parameter values, and any potential issues or warnings that arise during execution. +Wrap fitting and prediction in try/except blocks. Use `ValueError` for bad inputs, `RuntimeError` for operational failures, and include informative messages. Use `self.logger` for logging significant events (fitting start/end, parameter values, warnings). ### Parameters and validation -Type safety and parameter validation enhance the usability and reliability of your model. Add comprehensive type hints to all methods and parameters to enable better IDE support and make your code more self-documenting. Apply the `validate_call` decorator with the standard VALIDATE_CONFIG configuration to method signatures to enforce parameter validation consistently. - -Your implementation should thoughtfully support model-specific parameters that may be needed to control the behavior of your algorithm. 
Design your `_fit()` and `_predict()` methods to accept and properly utilize these parameters, ensuring they affect the model's operation as intended. Document all parameters clearly in your docstrings, explaining their purpose, expected values, and default behavior to guide users in effectively configuring your model. - -### Documentation +Add type hints to all methods. Use the `@validate_call(config=VALIDATE_CONFIG)` decorator for parameter validation. Document model-specific parameters in docstrings with their purpose, expected values, and defaults. -Comprehensive documentation makes your model accessible to others. Include detailed class-level docstrings that explain your model's theoretical approach, strengths, limitations, and appropriate use cases. Document all methods with properly structured docstring sections covering arguments, return values, and potential exceptions. Where appropriate, provide usage examples that demonstrate how to initialize, train, and use your model for prediction tasks. +### Testing -The documentation should be complemented by thorough unit tests that verify your implementation works correctly. Create tests that check both basic interface compliance (ensuring your model adheres to the expected API) and model-specific functionality (validating that your algorithm produces correct results). Comprehensive testing helps catch issues early and provides confidence that your implementation will work reliably in production environments. +Write tests for both interface compliance (does your model follow the expected API?) and model-specific correctness (does it produce sensible results?). diff --git a/docs/models/imputer/index.md b/docs/models/imputer/index.md index 3cf0dc0..26e20d8 100644 --- a/docs/models/imputer/index.md +++ b/docs/models/imputer/index.md @@ -1,11 +1,11 @@ # The Imputer class -The `Imputer` class serves as an abstract base class that defines the common interface for all imputation models within the Microimpute framework. 
It establishes a structure with essential methods for data validation, model fitting, and prediction. Every specialized imputation model in the system inherits from this class and implements the required abstract methods to provide its unique functionality. +The `Imputer` class is the abstract base class that defines the common interface for all imputation models in microimpute. Every model inherits from it and implements the required abstract methods for fitting and prediction. ## Key features -The Imputer architecture provides numerous benefits to the overall system design. It defines a consistent API with standardized `fit()` and `predict()` methods, ensuring that all models can be used interchangeably regardless of their underlying implementation details. This uniformity makes it straightforward to swap imputation techniques within your workflow. Thus, all imputers will share basic functionality like the handling of weighted data, using a "weights" column for sampling training data, to preserve data distributions better. +All models share standardized `fit()` and `predict()` methods, so they can be used interchangeably regardless of underlying implementation. All imputers also share functionality like weighted data handling through a `weight_col` parameter. -The design carefully enforces proper usage by ensuring no model can call `predict()` without first fitting the model to the data. This logical constraint helps prevent common errors and makes the API more intuitive to use. Additionally, the base implementation handles validation of parameters and input data, reducing code duplication across different model implementations and ensuring that all models perform appropriate validation checks. +The design enforces that `predict()` cannot be called before `fit()`. The base implementation also handles parameter and input data validation, so individual models don't need to duplicate those checks. 
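
The two-class pattern can be illustrated with a toy stand-in. The `MeanImputer`/`MeanImputerResults` names and the trivial mean-based logic below are hypothetical, not microimpute's real classes; the point is the structure: `predict()` lives only on the results object returned by `fit()`, so an unfitted model has no `predict()` to call, and input validation happens once before fitting:

```python
import pandas as pd

class MeanImputer:
    """Toy imputer following the two-class pattern."""

    def fit(self, train: pd.DataFrame, predictors, imputed_variables):
        # Base-class-style validation before any model work
        missing = [c for c in predictors + imputed_variables
                   if c not in train.columns]
        if missing:
            raise ValueError(f"columns not in data: {missing}")
        # "Fitting" here is just storing the training means
        return MeanImputerResults(train[imputed_variables].mean())

class MeanImputerResults:
    """Represents a fitted model; the only place predict() exists."""

    def __init__(self, means: pd.Series):
        self._means = means

    def predict(self, test: pd.DataFrame) -> pd.DataFrame:
        # Broadcast the fitted means to every test row
        return pd.DataFrame({v: [m] * len(test)
                             for v, m in self._means.items()})

train = pd.DataFrame({"age": [30, 40, 50], "income": [10.0, 20.0, 30.0]})
test = pd.DataFrame({"age": [35, 45]})
fitted = MeanImputer().fit(train, ["age"], ["income"])
print(fitted.predict(test)["income"].tolist())  # [20.0, 20.0]
```

Because fitting returns a separate results object, the "predict before fit" error is impossible by construction rather than guarded by a runtime flag.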
-When using the different imputers in isolation, and not as part of wider pipeline functions like `autoimpute` preprocessing is supported by `preprocess_data` which can help normalize the data and split it into train and test splits. For an example of how to integrate them see [matching-imputation.ipynb](../matching/matching-imputation.ipynb). +When using imputers in isolation (not through `autoimpute`), preprocessing is available via `preprocess_data`, which can normalize the data and split it into train/test sets. See [matching-imputation.ipynb](../matching/matching-imputation.ipynb) for an example. diff --git a/docs/models/matching/index.md b/docs/models/matching/index.md index 849d851..ca58c4f 100644 --- a/docs/models/matching/index.md +++ b/docs/models/matching/index.md @@ -1,27 +1,23 @@ # Hot-Deck Matching -The `Matching` model implements imputation through an elegant nearest neighbor distance hot deck matching approach. This technique draws from the principles of statistical matching, using existing complete records (donors) to provide values for records with missing data (recipients) by establishing meaningful connections based on similarities in predictor variables. +The `Matching` model imputes missing values using nearest neighbor distance hot deck matching. It finds donor records that are similar to each recipient based on predictor variables and transfers the donor's observed values. ## Variable type support -The matching model can handle any variable type—numerical, categorical, boolean, or mixed. Since it transfers actual observed values from similar records rather than generating predictions, it naturally preserves the original data type and distribution of each variable. +Matching handles any variable type: numerical, categorical, boolean, or mixed. Because it transfers actual observed values rather than generating predictions, it preserves the original data type and distribution of each variable. 
## How it works -Statistical or hot-deck matching in Microimpute builds upon the foundation of R's StatMatch package, accessed through the rpy2 interface to provide a seamless integration of R's statistical power with Python's flexibility. The implementation leverages the well-established nearest neighbor distance hot deck matching algorithm, which has a strong theoretical foundation in statistical literature. +The implementation builds on R's StatMatch package, accessed through the rpy2 interface. -During the fitting phase, the model carefully preserves both the complete donor dataset and the relevant variable names that will guide the matching process. This stored information becomes the knowledge base from which the model will draw when making imputations. +During fitting, the model stores the complete donor dataset and the relevant variable names. During prediction, each record in the test dataset (the recipients) is compared against the stored donors using distance calculations on the predictor variables. The algorithm finds the closest donor for each recipient and transfers the target variable values. -The prediction stage initiates a deliberate matching process where each record in the test dataset (the recipients) is systematically compared with the stored donor records. The comparison calculates similarity distances based on the predictor variables, identifying the donor records that most closely resemble each recipient. The matching algorithm efficiently navigates the multidimensional space defined by the predictor variables to find optimal donor-recipient pairs. - -Once the matching is complete, the model transfers the values from the matched donors to the recipients for the specified imputed variable. This transfer preserves the natural relationships and patterns present in the original data, as the values being imputed were actually observed rather than synthetically generated. 
+Because the imputed values are drawn from actually observed records, the natural relationships in the original data are preserved. ## Key features -The statistical matching imputer offers a truly non-parametric approach that operates without imposing restrictive assumptions about the underlying data distribution. This distribution-free nature makes it particularly valuable in scenarios where the data doesn't conform to common statistical assumptions or when the relationships are too complex to model parametrically. - -One of the most compelling advantages of this method is its ability to preserve the empirical distribution of the imputed variables. Since the imputed values come directly from observed data points, the resulting dataset maintains the natural structure, variability, and relationships present in the original data. This preservation extends to features like multimodality, skewness, and natural bounds that might be lost in model-based approaches. +Matching is non-parametric: it makes no assumptions about the data distribution. This makes it useful when the data doesn't fit standard parametric models, or when the relationships between predictors and targets are hard to specify in closed form. -The technique demonstrates versatility in handling complex relationships between variables, particularly when there exists a good match across datasets. Without requiring explicit specification of interaction terms or functional forms, it naturally captures the intricate dependencies that exist in the data through the matching process. This makes it especially valuable for datasets where the relationships are not well understood or are difficult to express mathematically. +The method preserves the empirical distribution of the imputed variables. Since values come directly from observed data points, features like multimodality, skewness, and natural bounds are maintained. A model-based approach might smooth these away. 
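+The matching procedure described above can be sketched in a few lines of plain numpy. This is an illustrative toy under simplifying assumptions (Euclidean distance, unconstrained matching), not microimpute's StatMatch-backed implementation; all function and variable names here are hypothetical:

```python
import numpy as np

def hot_deck_match(donors_X, donors_y, recipients_X):
    """Transfer each recipient's value from its nearest donor.

    donors_X: (n_donors, n_predictors) predictors for complete records
    donors_y: (n_donors,) observed target values to transfer
    recipients_X: (n_recipients, n_predictors) predictors for records to impute
    """
    # Distance from every recipient to every donor in predictor space
    dists = np.linalg.norm(
        recipients_X[:, None, :] - donors_X[None, :, :], axis=2
    )
    # Index of the closest donor for each recipient
    nearest = dists.argmin(axis=1)
    # Imputed values are actual observed donor values, so the
    # empirical distribution of donors_y is preserved
    return donors_y[nearest]

donors_X = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
donors_y = np.array([10.0, 20.0, 30.0])
recipients_X = np.array([[0.9, 1.1], [4.8, 5.2]])
print(hot_deck_match(donors_X, donors_y, recipients_X))  # → [20. 30.]
```

+Because the returned values are actual donor observations, the imputed column can only contain values that occur in the donor data — which is exactly why the empirical distribution is preserved.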
-Perhaps most distinctively, the statistical matching approach returns actual observed values rather than modeled estimates. This characteristic ensures that the imputed values are realistic and plausible, as they represent real observations from similar data points. The method essentially says, "We have seen this pattern before, and here's what the missing values looked like in that situation," providing a grounded approach to filling in missing information. +One limitation is that Matching does not incorporate quantile information. It matches donor and receiver units identically regardless of the quantile being predicted, which means it cannot distinguish between different parts of the conditional distribution. It may also fail to capture non-linear predictor-target relationships despite producing a plausible marginal distribution. diff --git a/docs/models/matching/matching-imputation.ipynb b/docs/models/matching/matching-imputation.ipynb index 1cdace2..df9d7d0 100644 --- a/docs/models/matching/matching-imputation.ipynb +++ b/docs/models/matching/matching-imputation.ipynb @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -161,7 +161,7 @@ " -0.002592\n", " 0.019907\n", " -0.017646\n", - " False\n", + " True\n", " 1\n", " \n", " \n", @@ -175,7 +175,7 @@ " -0.039493\n", " -0.068332\n", " -0.092204\n", - " True\n", + " False\n", " 2\n", " \n", " \n", @@ -189,7 +189,7 @@ " -0.002592\n", " 0.002861\n", " -0.025930\n", - " False\n", + " True\n", " 3\n", " \n", " \n", @@ -245,7 +245,7 @@ " -0.002592\n", " 0.031193\n", " 0.007207\n", - " True\n", + " False\n", " 438\n", " \n", " \n", @@ -273,7 +273,7 @@ " -0.011080\n", " -0.046883\n", " 0.015491\n", - " False\n", + " True\n", " 440\n", " \n", " \n", @@ -287,7 +287,7 @@ " 0.026560\n", " 0.044529\n", " 
-0.025930\n", - " True\n", + " False\n", " 441\n", " \n", " \n", @@ -301,7 +301,7 @@ " -0.039493\n", " -0.004222\n", " 0.003064\n", - " True\n", + " False\n", " 442\n", " \n", " \n", @@ -311,22 +311,22 @@ ], "text/plain": [ " age sex bmi bp s1 ... s4 s5 s6 bool wgt\n", - "0 0.038076 0.050680 0.061696 0.021872 -0.044223 ... -0.002592 0.019907 -0.017646 False 1\n", - "1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 ... -0.039493 -0.068332 -0.092204 True 2\n", - "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 ... -0.002592 0.002861 -0.025930 False 3\n", + "0 0.038076 0.050680 0.061696 0.021872 -0.044223 ... -0.002592 0.019907 -0.017646 True 1\n", + "1 -0.001882 -0.044642 -0.051474 -0.026328 -0.008449 ... -0.039493 -0.068332 -0.092204 False 2\n", + "2 0.085299 0.050680 0.044451 -0.005670 -0.045599 ... -0.002592 0.002861 -0.025930 True 3\n", "3 -0.089063 -0.044642 -0.011595 -0.036656 0.012191 ... 0.034309 0.022688 -0.009362 True 4\n", "4 0.005383 -0.044642 -0.036385 0.021872 0.003935 ... -0.002592 -0.031988 -0.046641 False 5\n", ".. ... ... ... ... ... ... ... ... ... ... ...\n", - "437 0.041708 0.050680 0.019662 0.059744 -0.005697 ... -0.002592 0.031193 0.007207 True 438\n", + "437 0.041708 0.050680 0.019662 0.059744 -0.005697 ... -0.002592 0.031193 0.007207 False 438\n", "438 -0.005515 0.050680 -0.015906 -0.067642 0.049341 ... 0.034309 -0.018114 0.044485 False 439\n", - "439 0.041708 0.050680 -0.015906 0.017293 -0.037344 ... -0.011080 -0.046883 0.015491 False 440\n", - "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 ... 0.026560 0.044529 -0.025930 True 441\n", - "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 ... -0.039493 -0.004222 0.003064 True 442\n", + "439 0.041708 0.050680 -0.015906 0.017293 -0.037344 ... -0.011080 -0.046883 0.015491 True 440\n", + "440 -0.045472 -0.044642 0.039062 0.001215 0.016318 ... 0.026560 0.044529 -0.025930 False 441\n", + "441 -0.045472 -0.044642 -0.073030 -0.081413 0.083740 ... 
-0.039493 -0.004222 0.003064 False 442\n", "\n", "[442 rows x 12 columns]" ] }, - "execution_count": 7, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -477,7 +477,7 @@ "max 1.107267e-01 5.068012e-02 1.705552e-01 1.320436e-01 1.539137e-01 1.852344e-01 442.000000" ] }, - "execution_count": 8, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -501,7 +501,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -540,7 +540,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -643,7 +643,7 @@ "73 0.012648 0.050680 -0.020218 -0.002228 NaN NaN NaN 74" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -672,7 +672,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -691,7 +691,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -710,7 +710,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -744,31 +744,31 @@ " 287\n", " 0.024574\n", " -0.039493\n", - " True\n", + " False\n", " \n", " \n", " 211\n", " 0.030078\n", " -0.039493\n", - " True\n", + " False\n", " \n", " \n", " 72\n", " 0.038334\n", " -0.039493\n", - " False\n", + " True\n", " \n", " \n", " 321\n", " -0.013953\n", " -0.002592\n", - " True\n", + " False\n", " \n", " \n", " 73\n", " -0.031840\n", " -0.039493\n", - " False\n", + " True\n", " \n", " \n", "\n", @@ -776,14 +776,14 @@ ], "text/plain": [ " s1 s4 bool\n", - "287 0.024574 -0.039493 True\n", - "211 0.030078 -0.039493 True\n", - "72 0.038334 -0.039493 False\n", - "321 -0.013953 -0.002592 True\n", - "73 -0.031840 
-0.039493 False" + "287 0.024574 -0.039493 False\n", + "211 0.030078 -0.039493 False\n", + "72 0.038334 -0.039493 True\n", + "321 -0.013953 -0.002592 False\n", + "73 -0.031840 -0.039493 True" ] }, - "execution_count": 13, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -808,7 +808,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -826,22 +826,22 @@ "x": [ 0.12501870313429186, 0.03430885887772673, - false, + true, -0.0249601584096303, -0.03949338287409329, - false, + true, 0.10300345740307394, -0.002592261998183278, true, 0.05484510736603471, 0.14132210941786577, - true, + false, 0.03833367306762126, 0.03430885887772673, - false, + true, 0.09887559882847057, -0.002592261998183278, - true, + false, 0.030077955918414535, 0.03430885887772673, true, @@ -859,7 +859,7 @@ false, -0.06761469701386505, -0.002592261998183278, - true, + false, -0.05523112129005496, -0.0763945037500033, true, @@ -871,40 +871,40 @@ true, -0.07311850844666953, -0.03949338287409329, - false, + true, 0.03833367306762126, 0.03430885887772673, false, 0.020446285911006685, 0.07120997975363674, - false, + true, -0.12678066991651324, -0.047980640675552584, true, 0.013566521620001083, 0.03430885887772673, - false, + true, -0.012576582685820214, -0.002592261998183278, - false, + true, 0.045213437358626866, -0.002592261998183278, false, -0.007072771253015731, -0.03949338287409329, - true, + false, 0.016318427336403322, -0.002592261998183278, false, -0.009824676969417972, -0.03949338287409329, - true, + false, -0.030463969842434782, -0.002592261998183278, false, -0.02220825269322806, -0.002592261998183278, - false, + true, -0.051103262715451604, 0.03430885887772673, false, @@ -913,52 +913,52 @@ false, 0.0342058144930179, -0.002592261998183278, - true, + false, 0.0080627101871966, 0.021024455362399115, - true, + false, 0.0025588987543921156, -0.002592261998183278, false, 0.0025588987543921156, -0.002592261998183278, 
- false, + true, -0.016704441260423575, 0.03430885887772673, true, 0.045213437358626866, 0.03615391492152222, - false, + true, 0.07823630595545376, -0.002592261998183278, - false, + true, -0.011200629827619093, -0.002592261998183278, false, 0.03145390877661565, 0.019917421736121838, - false, + true, 0.024574144485610048, 0.03430885887772673, false, -0.001568959820211247, -0.03949338287409329, - false, + true, -0.00019300696201012598, -0.03949338287409329, - false, + true, -0.00019300696201012598, -0.05056371913686628, - true, + false, -0.06623874415566393, -0.002592261998183278, - false, + true, -0.004320865536613489, 0.07120997975363674, false, 0.04383748450042574, -0.014400620678474476, - true, + false, 0.03282986163481677, -0.03949338287409329, false, @@ -967,61 +967,61 @@ false, -0.04422349842444599, -0.0763945037500033, - true, + false, -0.035967781275239266, -0.05167075276314359, - false, + true, -0.007072771253015731, -0.002592261998183278, - false, + true, -0.07311850844666953, -0.0763945037500033, - false, + true, -0.019456346976825818, -0.03949338287409329, true, -0.007072771253015731, 0.07120997975363674, - true, + false, -0.008448724111216851, -0.03949338287409329, - false, + true, 0.08924392882106273, 0.10811110062954676, - true, + false, -0.0249601584096303, -0.03949338287409329, true, 0.03282986163481677, -0.002592261998183278, - true, + false, -0.04422349842444599, -0.002592261998183278, - false, + true, -0.0029449126784123676, -0.03949338287409329, false, -0.033215875558837024, -0.0763945037500033, - true, + false, 0.08236416453005713, 0.07120997975363674, - false, + true, -0.0318399227006359, 0.0029429061332032365, false, -0.04972730985725048, -0.03949338287409329, - false, + true, 0.010814615903598841, -0.03949338287409329, - false, + true, -0.005696818394814609, 0.03430885887772673, false, 0.06172487165704031, -0.002592261998183278, - false, + true, 0.05622106022423583, 0.07120997975363674, true, @@ -1039,13 +1039,13 @@ true, 
0.039709625925822375, 0.07120997975363674, - true, + false, 0.045213437358626866, 0.07120997975363674, - true, + false, -0.04972730985725048, 0.01585829843977173, - true, + false, -0.026336111267831423, -0.03949338287409329, false, @@ -1054,10 +1054,10 @@ false, 0.08511607024645937, 0.03430885887772673, - false, + true, 0.016318427336403322, 0.02655962349378563, - true, + false, 0.020446285911006685, -0.002592261998183278, true, @@ -1069,16 +1069,16 @@ true, -0.046975404140848234, -0.03949338287409329, - true, + false, -0.0029449126784123676, -0.047242618258034386, - false, + true, 0.04658939021682799, -0.03949338287409329, - true, + false, -0.007072771253015731, -0.03949338287409329, - true, + false, -0.030463969842434782, -0.0763945037500033, true, @@ -1095,19 +1095,19 @@ "y": [ 0.024574144485610048, -0.03949338287409329, - true, + false, 0.030077955918414535, -0.03949338287409329, - true, + false, 0.03833367306762126, -0.03949338287409329, - false, + true, -0.013952535544021335, -0.002592261998183278, - true, + false, -0.0318399227006359, -0.03949338287409329, - false, + true, 0.04246153164222462, -0.0763945037500033, true, @@ -1116,25 +1116,25 @@ true, -0.062110885581060565, 0.026928634702544724, - false, + true, -0.04284754556624487, -0.002592261998183278, false, -0.005696818394814609, -0.03949338287409329, - false, + true, -0.015328488402222454, -0.002592261998183278, false, -0.001568959820211247, -0.03949338287409329, - true, + false, -0.07587041416307178, -0.0763945037500033, - true, + false, -0.005696818394814609, -0.002592261998183278, - true, + false, -0.007072771253015731, -0.03949338287409329, false, @@ -1149,13 +1149,13 @@ false, 0.039709625925822375, 0.10811110062954676, - true, + false, -0.009824676969417972, 0.03430885887772673, false, 0.01219056876179996, -0.03949338287409329, - false, + true, -0.034591828417038145, -0.0763945037500033, true, @@ -1173,22 +1173,22 @@ true, 0.024574144485610048, 0.15534453535071155, - true, + false, 
-0.035967781275239266, 0.07120997975363674, - true, + false, -0.037343734133440394, -0.03949338287409329, false, -0.0318399227006359, -0.03949338287409329, - false, + true, -0.004320865536613489, -0.0011162171631468765, - true, + false, -0.07587041416307178, -0.0763945037500033, - true, + false, 0.06998058880624704, 0.07120997975363674, true, @@ -1197,7 +1197,7 @@ false, -0.02358420555142918, -0.03949338287409329, - false, + true, 0.06034891879883919, 0.10811110062954676, false, @@ -1209,22 +1209,22 @@ false, 0.0025588987543921156, -0.03949338287409329, - true, + false, -0.0579830270064572, -0.03949338287409329, - true, + false, -0.001568959820211247, -0.03949338287409329, - true, + false, -0.05935897986465832, 0.012906208769698923, - false, + true, -0.037343734133440394, -0.011079519799642579, - false, + true, -0.07036660273026729, -0.002592261998183278, - false, + true, 0.027326050202012293, -0.03949338287409329, false, @@ -1233,37 +1233,37 @@ false, -0.0579830270064572, -0.03949338287409329, - false, + true, 0.001182945896190995, 0.03430885887772673, - false, + true, -0.001568959820211247, -0.03949338287409329, - true, + false, 0.010814615903598841, -0.03949338287409329, - false, + true, 0.024574144485610048, 0.05091436327188625, false, -0.060734932722859444, -0.0763945037500033, - true, + false, -0.007072771253015731, 0.03430885887772673, - false, + true, 0.00943866304539772, -0.002592261998183278, false, 0.039709625925822375, 0.10811110062954676, - true, + false, 0.11951489170148738, 0.08670845052151895, - false, + true, -0.04422349842444599, -0.03949338287409329, - true, + false, 0.001182945896190995, 0.03430885887772673, false, @@ -1272,7 +1272,7 @@ false, 0.08374011738825825, -0.03949338287409329, - true, + false, -0.009824676969417972, -0.002592261998183278, false, @@ -1284,31 +1284,31 @@ true, -0.004320865536613489, -0.002592261998183278, - true, + false, -0.015328488402222454, -0.002592261998183278, false, 0.01219056876179996, 0.07120997975363674, - 
false, + true, 0.039709625925822375, 0.10811110062954676, - true, + false, 0.05071724879143135, 0.03430885887772673, true, -0.0579830270064572, -0.03949338287409329, - false, + true, -0.02083229983502694, 0.07120997975363674, false, -0.037343734133440394, -0.002592261998183278, - true, + false, -0.0029449126784123676, 0.07120997975363674, - true, + false, 0.03558176735121902, -0.0763945037500033, true, @@ -1317,16 +1317,16 @@ false, -0.0318399227006359, -0.03949338287409329, - false, + true, -0.033215875558837024, -0.002592261998183278, false, 0.017694380194604446, 0.03430885887772673, - false, + true, -0.016704441260423575, -0.002592261998183278, - true, + false, -0.04284754556624487, -0.002592261998183278, true, @@ -1338,28 +1338,28 @@ false, 0.010814615903598841, -0.03949338287409329, - false, + true, 0.08374011738825825, -0.03949338287409329, - true, + false, -0.04422349842444599, -0.03949338287409329, - true, + false, -0.0029449126784123676, -0.002592261998183278, false, 0.03145390877661565, -0.03949338287409329, - false, + true, -0.06623874415566393, -0.03949338287409329, - true, + false, 0.001182945896190995, -0.007020396503292483, - true, + false, 0.01219056876179996, -0.03949338287409329, - false + true ] }, { @@ -2363,7 +2363,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -2487,7 +2487,7 @@ "[5 rows x 20 columns]" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -2518,7 +2518,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -3976,7 +3976,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -3997,15 +3997,15 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[Parallel(n_jobs=1)]: 
Done 1 tasks | elapsed: 3.3s\n", - "[Parallel(n_jobs=1)]: Done 4 tasks | elapsed: 12.8s\n" + "[Parallel(n_jobs=1)]: Done 1 tasks | elapsed: 0.1s\n", + "[Parallel(n_jobs=1)]: Done 4 tasks | elapsed: 0.2s\n" ] }, { @@ -4028,7 +4028,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 16.0s finished\n" + "[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.3s finished\n" ] } ], @@ -4051,7 +4051,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -4062,8 +4062,32 @@ }, "data": [ { + "error_y": { + "array": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "type": "bar", @@ -4111,8 +4135,32 @@ ] }, { + "error_y": { + "array": [ + 0.0023955551235670214, + 0.002189517306342527, + 0.001998045402184313, + 0.0018257279461347798, + 0.0016784747085066194, + 0.0015633841084603756, + 0.0014879382056280817, + 0.001458303139487877, + 0.0014772385383559053, + 0.0015429572455057856, + 0.0016498782177496946, + 0.0017906358968372744, + 0.0019579462215216135, + 0.0021456065826219173, + 0.0023487442428032587, + 0.002563682729688605, + 0.0027876937544495414, + 0.003018758272598448, + 0.003255374672423644 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "type": "bar", @@ -4175,8 +4223,8 @@ "r": 50, "t": 80 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "template": { "data": { "bar": [ @@ -4998,14 +5046,22 @@ }, "width": 750, "xaxis": { + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "text": "Quantile" }, "zeroline": false }, "yaxis": { - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + 
"showgrid": true, + "showline": true, "title": { "text": "Average Quantile Loss" }, @@ -5022,7 +5078,7 @@ "# Plot the results for numerical variables\n", "if \"quantile_loss\" in matching_results:\n", " perf_results_viz = model_performance_results(\n", - " results=matching_results[\"quantile_loss\"][\"results\"],\n", + " results=matching_results,\n", " model_name=\"Matching\",\n", " method_name=\"Cross-validation quantile loss average\",\n", " )\n", @@ -5050,7 +5106,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -5071,7 +5127,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -5086,7 +5142,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -5107,14 +5163,14 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'dist_fun': 'Gower', 'constrained': True, 'constr_alg': 'hungarian', 'k': 5}\n" + "{'dist_fun': 'Manhattan', 'k': 6}\n" ] } ], @@ -5133,7 +5189,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/models/mdn/index.md b/docs/models/mdn/index.md index c48797c..4f83612 100644 --- a/docs/models/mdn/index.md +++ b/docs/models/mdn/index.md @@ -1,6 +1,6 @@ # Mixture Density Network -The `MDN` model uses deep neural networks with mixture density outputs to predict missing values by learning complex, potentially multi-modal conditional distributions. Built on PyTorch Tabular, this approach combines the flexibility of neural networks with the probabilistic richness of Gaussian mixture models. +The `MDN` model uses neural networks with mixture density outputs to predict missing values. 
Built on PyTorch Tabular, it learns conditional distributions as mixtures of Gaussians, which lets it capture multi-modal relationships. ## Variable type support @@ -16,7 +16,7 @@ The model supports automatic caching based on data hashes, avoiding redundant re ## Key features -MDN offers several advantages for complex imputation tasks. The mixture density approach can model multi-modal distributions that simpler methods cannot capture, making it suitable for variables with complex conditional distributions. The neural network backbone can learn non-linear relationships without requiring explicit feature engineering. +MDN can model multi-modal distributions that simpler methods cannot capture, making it suited for variables with complex conditional distributions. The neural network backbone learns non-linear relationships without explicit feature engineering. Training leverages GPU acceleration when available and includes early stopping to prevent overfitting. The automatic model caching system speeds up repeated analyses on the same dataset. 
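+To make the mixture-density idea concrete, here is a small numpy sketch of what an MDN's output head represents for a single record — mixture weights, component means, and component standard deviations — and how imputations could be drawn from it. This is an illustrative sketch under assumed parameter values, not PyTorch Tabular's API:

```python
import numpy as np

def sample_mixture(weights, means, sigmas, n_samples, seed=0):
    """Draw imputation samples from a 1-D Gaussian mixture.

    For each input record, an MDN head emits mixture weights, component
    means, and component standard deviations describing the conditional
    distribution of the target; imputations are draws from that mixture.
    """
    rng = np.random.default_rng(seed)
    # Choose a component for each sample according to the mixture weights
    comp = rng.choice(len(weights), size=n_samples, p=weights)
    # Sample from the selected Gaussian component
    return rng.normal(means[comp], sigmas[comp])

# Hypothetical head output for one record: a bimodal conditional distribution
weights = np.array([0.6, 0.4])
means = np.array([-1.0, 2.0])
sigmas = np.array([0.3, 0.5])
draws = sample_mixture(weights, means, sigmas, n_samples=10_000)
```

+A unimodal regression model would center its predictions near the mixture mean (0.6 × −1.0 + 0.4 × 2.0 = 0.2), a value that falls between the two modes and is itself unlikely; sampling the mixture keeps both modes.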
diff --git a/docs/models/mdn/mdn-imputation.ipynb b/docs/models/mdn/mdn-imputation.ipynb index 674f716..824d78e 100644 --- a/docs/models/mdn/mdn-imputation.ipynb +++ b/docs/models/mdn/mdn-imputation.ipynb @@ -115,16 +115,7 @@ "execution_count": 1, "id": "cell-2", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Error importing in API mode: ImportError(\"dlopen(/Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so, 0x0002): Library not loaded: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib\\n Referenced from: <21BE8260-F4D5-3597-9DD0-6953BC4DDF3D> /Users/movil1/envs/pe3.13/lib/python3.13/site-packages/_rinterface_cffi_api.abi3.so\\n Reason: tried: '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file), '/Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.dylib' (no such file)\")\n", - "Trying to import in ABI mode.\n" - ] - } - ], + "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", @@ -502,28 +493,28 @@ " \n", " \n", " 287\n", - " -0.239280\n", - " -0.307822\n", + " 0.577170\n", + " 0.438805\n", " \n", " \n", " 211\n", - " 0.339311\n", - " 0.358265\n", + " 0.481174\n", + " 0.488350\n", " \n", " \n", " 72\n", - " -0.276636\n", - " -0.277714\n", + " -0.721495\n", + " -0.631198\n", " \n", " \n", " 321\n", - " 0.540901\n", - " 0.551054\n", + " -1.211108\n", + " -1.460021\n", " \n", " \n", " 73\n", - " 0.460601\n", - " 0.446176\n", + " -1.162974\n", + " -1.194782\n", " \n", " \n", "\n", @@ -531,11 +522,11 @@ ], "text/plain": [ " s1 s4\n", - "287 -0.239280 -0.307822\n", - "211 0.339311 0.358265\n", - "72 -0.276636 -0.277714\n", - "321 0.540901 0.551054\n", - "73 0.460601 0.446176" + "287 0.577170 0.438805\n", + "211 0.481174 
0.488350\n", + "72 -0.721495 -0.631198\n", + "321 -1.211108 -1.460021\n", + "73 -1.162974 -1.194782" ] }, "execution_count": 6, @@ -769,184 +760,184 @@ ], "xaxis": "x", "y": [ - 0.15996141731739044, - 0.032140739262104034, - 0.023270580917596817, - -0.07243411242961884, - 0.06383831799030304, - 0.0471138134598732, - 0.08792641758918762, - 0.183455690741539, - 0.034883130341768265, - -0.020984657108783722, - 0.1328243613243103, - 0.002175786066800356, - 0.8124674558639526, - 0.9326153993606567, - 0.1008240208029747, - 0.06365294754505157, - 0.3842753469944, - 0.46691128611564636, - 0.017988024279475212, - -0.016335798427462578, - 0.06642946600914001, - -0.008764325641095638, - -0.0118333101272583, - 0.010635918006300926, - -0.012584746815264225, - -0.0853499099612236, - -0.002956878859549761, - -0.024943694472312927, - -0.033975280821323395, - -0.07795815169811249, - -0.07504265010356903, - -0.03126005083322525, - 0.03940160199999809, - 0.053814880549907684, - -0.04383951798081398, - 0.019542427733540535, - -0.09670525044202805, - -0.0822581872344017, - 0.06511250138282776, - 0.09403425455093384, - 0.00344745977781713, - -0.00915258377790451, - 0.0950278714299202, - -0.048670198768377304, - 0.039458371698856354, - -0.07556705176830292, - -0.006443846970796585, - 0.020815376192331314, - -0.036735083907842636, - -0.044321704655885696, - -0.08580169081687927, - -0.059006646275520325, - -0.019963562488555908, - 0.009087007492780685, - -0.130485400557518, - -0.0006790757179260254, - -0.037002887576818466, - -0.09501223266124725, - 0.012499166652560234, - 0.02030245214700699, - 0.014375532045960426, - -0.02604287676513195, - 0.057647883892059326, - 0.05878327041864395, - -0.04572222754359245, - -0.03820172697305679, - -0.06333865225315094, - 0.04185344651341438, - 0.05389529839158058, - 0.01455413457006216, - 0.09297429770231247, - -0.0037352601066231728, - -0.014067748561501503, - 0.026063010096549988, - 0.007550985552370548, - -0.025209808722138405, - 
-0.04939274489879608, - -0.030361073091626167, - 0.03330788016319275, - -0.014227045699954033, - 0.0016633199993520975, - -0.08225220441818237, - 0.010647851973772049, - -0.08228402584791183, - -0.07383066415786743, - 0.061680346727371216, - -0.5339921116828918, - -0.6321333646774292, - 0.025442680343985558, - -0.022219836711883545, - -0.026420146226882935, - -0.10970157384872437, - -0.041401173919439316, - -0.0358879491686821, - -0.003495529294013977, - -0.04557208716869354, - 0.013635829091072083, - -0.043398745357990265, - 0.026523813605308533, - 0.00031686751754023135, - -0.09354585409164429, - -0.05077875778079033, - -0.03905845806002617, - -0.02609078772366047, - -0.006901499815285206, - 0.05072978138923645, - -0.0048773703165352345, - -0.01315099187195301, - 0.083067886531353, - 0.11018361896276474, - -0.042783450335264206, - -0.042475901544094086, - 0.007468030788004398, - 0.03531891107559204, - -0.09278491139411926, - -0.008651846088469028, - 0.03183288499712944, - -0.08680566400289536, - -0.01793188974261284, - -0.08094043284654617, - 0.08602085709571838, - 0.06315568834543228, - -0.020230960100889206, - 0.003904549404978752, - -0.03103073686361313, - -0.06506631523370743, - 0.036732763051986694, - -0.07117997109889984, - -0.0006138663738965988, - 0.03311414271593094, - 0.06245873123407364, - 0.002597265876829624, - 0.01770462840795517, - 0.058917708694934845, - -0.06402486562728882, - 0.048038020730018616, - -0.03813735395669937, - -0.028543436899781227, - -0.0075743068009614944, - 0.03675296530127525, - -0.014315500855445862, - 0.040746308863162994, - 0.07579956203699112, - 0.1126079186797142, - 0.08204692602157593, - 0.07413098216056824, - -0.0642552450299263, - 0.002614614088088274, - -0.033612947911024094, - -0.015575357712805271, - 0.08901567757129669, - 0.12694990634918213, - 0.12089148908853531, - 0.0254050400108099, - -0.005140587687492371, - 0.06322228908538818, - 0.017512474209070206, - 0.013693263754248619, - 0.059920668601989746, - 
0.04606965184211731, - -0.03190714493393898, - -0.01763644628226757, - -0.03793055936694145, - -0.0012233610032126307, - 0.012116450816392899, - -0.026453981176018715, - 0.0048493873327970505, - -0.008992618881165981, - 0.05157541483640671, - -0.028200700879096985, - -0.05151493847370148, - -0.05033167079091072, - -0.11441226303577423, - -0.11985128372907639, - -0.07205710560083389, - 0.0047546266578137875, - -0.018449068069458008, - -0.006497040390968323 + 0.07409832626581192, + -0.0053354548290371895, + -0.12105628103017807, + -0.014382392168045044, + 0.04452350735664368, + -0.07941135764122009, + 0.04921489953994751, + 0.07562752068042755, + 0.08775658905506134, + -0.003818345023319125, + 0.12863661348819733, + 0.05178040266036987, + 0.04000586271286011, + -0.16531065106391907, + 0.06332286447286606, + 0.08498921245336533, + -0.010948514565825462, + -0.14337725937366486, + 0.17172245681285858, + -0.07638034224510193, + 0.033117081969976425, + -0.0034198835492134094, + -0.007393532898277044, + -0.026365965604782104, + -0.12046261131763458, + -0.027624964714050293, + 0.029128486290574074, + -0.06916160881519318, + -0.05742168426513672, + -0.10975533723831177, + 0.020898770540952682, + -0.1172960177063942, + -0.04515509307384491, + 0.026763781905174255, + 0.14253804087638855, + 0.2255028486251831, + -0.24605007469654083, + -0.03075171262025833, + -0.09652869403362274, + 0.20666047930717468, + -0.014879127964377403, + 0.024144992232322693, + -0.09012438356876373, + 0.022548293694853783, + -0.05704033374786377, + -0.11739443987607956, + -0.06362107396125793, + -0.02206820249557495, + -0.03970325365662575, + -0.08100323379039764, + -0.09121052920818329, + -0.01153072714805603, + 0.09131903946399689, + -0.05644859001040459, + -0.1356145739555359, + 0.13108958303928375, + -0.08045151084661484, + -0.0677538514137268, + -0.013112485408782959, + 0.03593648970127106, + -0.09408782422542572, + 0.10447344928979874, + 0.026929795742034912, + 0.004074967000633478, + 
-0.06840845197439194, + -0.020132899284362793, + -0.03740965574979782, + 0.01231334824115038, + 0.13786044716835022, + 0.061320580542087555, + 0.03548247739672661, + 0.15061387419700623, + 0.0036609352100640535, + -0.02198277786374092, + 0.13655224442481995, + -0.08470708131790161, + 0.0033969569485634565, + 0.11218808591365814, + 0.15245385468006134, + -0.05473918840289116, + -0.01783975213766098, + -0.10014862567186356, + 0.02230358123779297, + -0.1551317274570465, + -0.021243643015623093, + 0.024414116516709328, + -0.020500551909208298, + 0.14334836602210999, + 0.09190063178539276, + -0.06325235962867737, + 0.034143995493650436, + -0.010623842477798462, + 0.017537854611873627, + -0.021074065938591957, + -0.11523549258708954, + -0.025434572249650955, + 0.018908560276031494, + 0.012480437755584717, + 0.02391844429075718, + 0.04059349000453949, + -0.04446566849946976, + -0.07962190359830856, + -0.01866019144654274, + -0.09491998702287674, + -0.10590796172618866, + 0.11002372950315475, + -0.045452192425727844, + 0.048709649592638016, + 0.16254602372646332, + -0.23081794381141663, + 0.033685214817523956, + -0.0006430193898268044, + -0.015713322907686234, + 0.026455357670783997, + 0.06460753083229065, + 0.09441250562667847, + -0.008670782670378685, + -0.08653200417757034, + -0.07236170768737793, + -0.09426410496234894, + 0.14369378983974457, + -0.03581035137176514, + 0.008754465728998184, + 0.046415455639362335, + 0.03268638253211975, + 0.06044789403676987, + 0.08978655189275742, + -0.04920205473899841, + -0.1123436689376831, + 0.0034600854851305485, + 0.02713024616241455, + 0.003619933035224676, + -0.02590913698077202, + -0.09626477211713791, + -0.026048922911286354, + 0.01121927797794342, + -0.0919136255979538, + -0.04032476991415024, + -0.04339555650949478, + 0.041690241545438766, + 0.09366445243358612, + 0.16090711951255798, + 0.12886887788772583, + 0.024945732206106186, + 0.10310709476470947, + 0.09357581287622452, + -0.024438243359327316, + -0.02837105467915535, 
+ -0.10025747120380402, + -0.15249663591384888, + 0.045413028448820114, + 0.16909754276275635, + 0.15786541998386383, + 0.13495682179927826, + -0.02144046686589718, + 0.08984969556331635, + -0.04134964570403099, + -0.051257528364658356, + 0.045112937688827515, + 0.047064878046512604, + 0.034162651747465134, + 0.021895015612244606, + -0.05455716326832771, + -0.0889664888381958, + 0.011050431989133358, + -0.014817613177001476, + 0.163625568151474, + 0.03653216361999512, + -0.06846585869789124, + -0.08761408925056458, + -0.08431434631347656, + -0.06531454622745514, + -0.01888906955718994, + -0.09313887357711792, + -0.10920677334070206, + -0.05248451232910156, + 0.001476499019190669, + -0.011065206490457058 ], "yaxis": "y" }, @@ -959,12 +950,12 @@ "name": "Perfect prediction", "type": "scatter", "x": [ - -1.969916820526123, - 1.6161797046661377 + -2.2862510681152344, + 2.5850915908813477 ], "y": [ - -1.969916820526123, - 1.6161797046661377 + -2.2862510681152344, + 2.5850915908813477 ] } ], @@ -1942,47 +1933,47 @@ " \n", " 0\n", " 0.125019\n", - " -1.208731\n", - " -0.793742\n", - " -0.239280\n", - " 0.260678\n", - " 0.750109\n", + " -1.071396\n", + " -0.291549\n", + " 0.577170\n", + " 1.379413\n", + " 2.154069\n", " \n", " \n", " 1\n", " -0.024960\n", - " -0.430616\n", - " -0.080356\n", - " 0.339311\n", - " 0.740931\n", - " 1.117335\n", + " -1.594807\n", + " -0.647442\n", + " 0.481174\n", + " 1.623415\n", + " 2.562861\n", " \n", " \n", " 2\n", " 0.103003\n", - " -1.449445\n", - " -0.907212\n", - " -0.276636\n", - " 0.273373\n", - " 0.844860\n", + " -2.563489\n", + " -1.645817\n", + " -0.721495\n", + " 0.225897\n", + " 1.118832\n", " \n", " \n", " 3\n", " 0.054845\n", - " -0.039490\n", - " 0.249726\n", - " 0.540901\n", - " 0.896468\n", - " 1.195148\n", + " -3.448399\n", + " -2.342187\n", + " -1.211108\n", + " 0.049215\n", + " 1.042724\n", " \n", " \n", " 4\n", " 0.038334\n", - " -0.671013\n", - " -0.092434\n", - " 0.460601\n", - " 1.070943\n", - " 1.577428\n", + " 
-3.044891\n", + " -2.123165\n", + " -1.162974\n", + " -0.149736\n", + " 0.683350\n", " \n", " \n", "\n", @@ -1990,11 +1981,11 @@ ], "text/plain": [ " Actual Q10 Q25 Q50 Q75 Q90\n", - "0 0.125019 -1.208731 -0.793742 -0.239280 0.260678 0.750109\n", - "1 -0.024960 -0.430616 -0.080356 0.339311 0.740931 1.117335\n", - "2 0.103003 -1.449445 -0.907212 -0.276636 0.273373 0.844860\n", - "3 0.054845 -0.039490 0.249726 0.540901 0.896468 1.195148\n", - "4 0.038334 -0.671013 -0.092434 0.460601 1.070943 1.577428" + "0 0.125019 -1.071396 -0.291549 0.577170 1.379413 2.154069\n", + "1 -0.024960 -1.594807 -0.647442 0.481174 1.623415 2.562861\n", + "2 0.103003 -2.563489 -1.645817 -0.721495 0.225897 1.118832\n", + "3 0.054845 -3.448399 -2.342187 -1.211108 0.049215 1.042724\n", + "4 0.038334 -3.044891 -2.123165 -1.162974 -0.149736 0.683350" ] }, "execution_count": 8, @@ -2054,8 +2045,8 @@ 0 ], "y": [ - -1.208730936050415, - 0.7501086592674255 + -1.0713963508605957, + 2.154068946838379 ] }, { @@ -2072,8 +2063,8 @@ 1 ], "y": [ - -0.4306159019470215, - 1.1173346042633057 + -1.5948072671890259, + 2.5628609657287598 ] }, { @@ -2090,8 +2081,8 @@ 2 ], "y": [ - -1.4494445323944092, - 0.8448596596717834 + -2.5634894371032715, + 1.118831753730774 ] }, { @@ -2108,8 +2099,8 @@ 3 ], "y": [ - -0.03949018567800522, - 1.1951476335525513 + -3.448399305343628, + 1.0427237749099731 ] }, { @@ -2126,8 +2117,8 @@ 4 ], "y": [ - -0.6710134148597717, - 1.577427625656128 + -3.044891357421875, + 0.6833502650260925 ] }, { @@ -2144,8 +2135,8 @@ 5 ], "y": [ - -1.388054609298706, - 0.7144660353660583 + -1.1571630239486694, + 1.9580988883972168 ] }, { @@ -2162,8 +2153,8 @@ 6 ], "y": [ - 0.9335368275642395, - 1.994225263595581 + -0.448160856962204, + 5.394647121429443 ] }, { @@ -2180,8 +2171,8 @@ 7 ], "y": [ - -1.2012914419174194, - 0.6220993399620056 + -3.3663272857666016, + 1.5296334028244019 ] }, { @@ -2198,8 +2189,8 @@ 8 ], "y": [ - 0.4870220720767975, - 1.1506199836730957 + -0.9083917140960693, + 
4.910729885101318 ] }, { @@ -2216,8 +2207,8 @@ 9 ], "y": [ - -1.1105694770812988, - 0.017988024279475212 + 0.17172245681285858, + 3.2345242500305176 ] }, { @@ -2234,8 +2225,8 @@ 0 ], "y": [ - -0.6731794476509094, - 0.15996141731739044 + -0.11874838173389435, + 1.195131540298462 ] }, { @@ -2252,8 +2243,8 @@ 1 ], "y": [ - 0.023270580917596817, - 0.658467710018158 + -0.40085601806640625, + 1.4143391847610474 ] }, { @@ -2270,8 +2261,8 @@ 2 ], "y": [ - -0.769422709941864, - 0.17357832193374634 + -1.4394127130508423, + 0.04452350735664368 ] }, { @@ -2288,8 +2279,8 @@ 3 ], "y": [ - 0.3068358898162842, - 0.8011546730995178 + -2.0836896896362305, + -0.2426186203956604 ] }, { @@ -2306,8 +2297,8 @@ 4 ], "y": [ - 0.034883130341768265, - 0.9577965140342712 + -1.900681734085083, + -0.4198998510837555 ] }, { @@ -2324,8 +2315,8 @@ 5 ], "y": [ - -0.7857669591903687, - 0.1328243613243103 + -0.14701643586158752, + 1.0303322076797485 ] }, { @@ -2342,8 +2333,8 @@ 6 ], "y": [ - 1.253185749053955, - 1.6908645629882812 + 1.3525761365890503, + 3.828951597213745 ] }, { @@ -2360,8 +2351,8 @@ 7 ], "y": [ - -0.629646360874176, - 0.1008240208029747 + -1.7197819948196411, + 0.2868454158306122 ] }, { @@ -2378,8 +2369,8 @@ 8 ], "y": [ - 0.6771109104156494, - 0.9602589011192322 + 0.6424643397331238, + 3.1236233711242676 ] }, { @@ -2396,8 +2387,8 @@ 9 ], "y": [ - -0.7771245241165161, - -0.29959893226623535 + 1.113405704498291, + 2.3700172901153564 ] }, { @@ -2454,16 +2445,16 @@ 9 ], "y": [ - -0.23928046226501465, - 0.339310884475708, - -0.276636004447937, - 0.540900707244873, - 0.4606007933616638, - -0.34471726417541504, - 1.4770541191101074, - -0.30302998423576355, - 0.8187563419342041, - -0.5293239951133728 + 0.5771697759628296, + 0.4811735153198242, + -0.7214951515197754, + -1.211107611656189, + -1.1629741191864014, + 0.41260623931884766, + 2.5850915908813477, + -0.6760748624801636, + 1.8518712520599365, + 1.7761045694351196 ] }, { @@ -3472,9 +3463,9 @@ "output_type": "stream", "text": [ 
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 7.6s remaining: 11.5s\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 7.6s remaining: 5.1s\n", - "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 7.7s finished\n" + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 9.5s remaining: 14.3s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 9.5s remaining: 6.4s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 9.5s finished\n" ] }, { @@ -3483,8 +3474,8 @@ "text": [ "Quantile loss results:\n", " 0.05 0.10 0.15 0.20 0.25 ... 0.75 0.80 0.85 0.90 0.95\n", - "train 0.052342 0.080491 0.100151 0.114462 0.124877 ... 0.116854 0.105562 0.090379 0.070140 0.043119\n", - "test 0.060333 0.088338 0.109102 0.123207 0.134303 ... 0.127896 0.114887 0.099006 0.077486 0.047284\n", + "train 0.046876 0.074835 0.094329 0.108058 0.117289 ... 0.125522 0.116335 0.102234 0.081915 0.052980\n", + "test 0.048934 0.078406 0.099249 0.117245 0.130767 ... 
0.133579 0.123578 0.109620 0.088337 0.058313\n", "\n", "[2 rows x 19 columns]\n" ] @@ -3516,8 +3507,32 @@ }, "data": [ { + "error_y": { + "array": [ + 0.03837709044008565, + 0.05971215726699471, + 0.07364494245759541, + 0.08337091425123092, + 0.08970334204623862, + 0.09257452417327834, + 0.09425214250614794, + 0.09727578509414461, + 0.1015340536222146, + 0.10484112107099822, + 0.10665196062518409, + 0.10767398977417002, + 0.10657423431327813, + 0.10291839541792067, + 0.09739380134017506, + 0.08998945316371061, + 0.079000703908655, + 0.0623493126962366, + 0.03897342645724234 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "type": "bar", @@ -3543,30 +3558,54 @@ "0.95" ], "y": [ - 0.05234219330354317, - 0.08049099191203442, - 0.10015070706030556, - 0.11446162331927275, - 0.12487722321825685, - 0.1324132843109902, - 0.13768463404454853, - 0.14056006658781792, - 0.14185292803744393, - 0.1415992977069213, - 0.14005457428331253, - 0.13682163676619796, - 0.13220755848574667, - 0.12584912155335878, - 0.11685420327991669, - 0.10556158423365573, - 0.09037880944328178, - 0.07014017488138428, - 0.04311920349617806 + 0.04687617706207272, + 0.07483489712462413, + 0.0943290184639551, + 0.10805756595645155, + 0.11728923842580923, + 0.12286641601968225, + 0.12616856929600748, + 0.1292466478910964, + 0.13271784812663345, + 0.1354354106464196, + 0.13713837435319742, + 0.13782672818961056, + 0.13631602177577212, + 0.13210069705430905, + 0.12552150971602297, + 0.11633498833082316, + 0.10223421559735249, + 0.08191549168418728, + 0.0529798258038449 ] }, { + "error_y": { + "array": [ + 0.03618523330977212, + 0.05800221625311354, + 0.07291441840284316, + 0.08871777392907335, + 0.10179376511491461, + 0.11098038464484912, + 0.11340640962306367, + 0.11611839981806513, + 0.11748612497231077, + 0.1174333555879695, + 0.11549620009870612, + 0.1127200411691131, + 0.10807914038546579, + 0.10210520685018526, + 0.0948305718757984, + 
0.08725433626495321, + 0.0765462308033389, + 0.060192339856620924, + 0.03773031440262491 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "type": "bar", @@ -3592,25 +3631,25 @@ "0.95" ], "y": [ - 0.060333009487974085, - 0.08833805422856528, - 0.10910236343419596, - 0.12320670350109023, - 0.13430269442274195, - 0.14292452496806163, - 0.14938276914149362, - 0.15449110540720765, - 0.1571806630200133, - 0.1578150543631282, - 0.15689612315245363, - 0.15328965501428699, - 0.14702280821997657, - 0.13845547628705684, - 0.12789622672114154, - 0.11488700831441748, - 0.099006403363506, - 0.07748605801852056, - 0.047284249954029064 + 0.04893408004321757, + 0.07840558199482109, + 0.09924928077798252, + 0.11724509694752339, + 0.13076715505019737, + 0.14024932427368472, + 0.14459225917030372, + 0.1486228918064114, + 0.15166062433499797, + 0.1528524649168001, + 0.15231682792337978, + 0.15052964914121986, + 0.14697318308426555, + 0.14127787715452564, + 0.1335791298846442, + 0.1235777656562175, + 0.10962020934072403, + 0.08833690511355873, + 0.05831330288784895 ] } ], @@ -3629,8 +3668,8 @@ "r": 50, "t": 80 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "template": { "data": { "bar": [ @@ -4452,14 +4491,22 @@ }, "width": 750, "xaxis": { + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "text": "Quantile" }, "zeroline": false }, "yaxis": { - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "text": "Average Quantile Loss" }, @@ -4476,7 +4523,7 @@ "# Plot the results\n", "if \"quantile_loss\" in mdn_results:\n", " perf_results_viz = model_performance_results(\n", - " results=mdn_results[\"quantile_loss\"][\"results\"],\n", + " results=mdn_results,\n", " model_name=\"MDN\",\n", " 
method_name=\"Cross-validation quantile loss average\",\n", " )\n", @@ -4531,9 +4578,9 @@ "text": [ "Categorical variable distribution:\n", "risk_level\n", + "medium 150\n", "low 148\n", - "high 148\n", - "medium 146\n", + "high 144\n", "Name: count, dtype: int64\n", "\n", "Training set size: 353 records\n", @@ -4607,13 +4654,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Categorical imputation accuracy: 31.46%\n", + "Categorical imputation accuracy: 30.34%\n", "\n", "Confusion matrix:\n", " Predicted: low Predicted: medium Predicted: high\n", - "Actual: low 10 8 12\n", - "Actual: medium 14 11 4\n", - "Actual: high 7 16 7\n" + "Actual: low 8 5 17\n", + "Actual: medium 11 8 10\n", + "Actual: high 7 12 11\n" ] } ], @@ -4647,9 +4694,9 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 0.8s remaining: 1.2s\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 0.9s remaining: 0.6s\n", - "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 0.9s finished\n" + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 1.2s remaining: 1.9s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 1.3s remaining: 0.9s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 1.4s finished\n" ] }, { @@ -4657,8 +4704,8 @@ "output_type": "stream", "text": [ "Categorical imputation cross-validation results (log loss):\n", - "Mean train log loss: 0.9622\n", - "Mean test log loss: 1.0936\n" + "Mean train log loss: 0.9076\n", + "Mean test log loss: 1.0931\n" ] } ], @@ -4688,8 +4735,14 @@ }, "data": [ { + "error_y": { + "array": [ + 0.027550481167352696 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "showlegend": true, @@ -4699,13 +4752,19 @@ ], "xaxis": "x", "y": [ - 0.9622434565816542 + 0.9075664863898126 ], "yaxis": "y" }, { + "error_y": { + "array": [ + 0.0913599760690238 + ], + "type": "data" + }, 
"marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "showlegend": true, @@ -4715,7 +4774,7 @@ ], "xaxis": "x", "y": [ - 1.093641015261423 + 1.093074681095708 ], "yaxis": "y" } @@ -4737,8 +4796,8 @@ } ], "height": 420, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "showlegend": true, "template": { "data": { @@ -5565,14 +5624,24 @@ "domain": [ 0, 1 - ] + ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": false, + "showline": true }, "yaxis": { "anchor": "x", "domain": [ 0, 1 - ] + ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true } } } @@ -5608,7 +5677,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/models/ols/index.md b/docs/models/ols/index.md index 5312939..13ef8bd 100644 --- a/docs/models/ols/index.md +++ b/docs/models/ols/index.md @@ -1,17 +1,17 @@ # Ordinary Least Squares -The `OLS` model employs linear regression techniques to predict missing values by leveraging the relationships between predictor and target variables. This classic statistical approach provides a computationally efficient method for imputation while offering theoretical guarantees under certain assumptions. +The `OLS` model uses linear regression to predict missing values from the relationships between predictor and target variables. It is computationally fast and provides a useful baseline for comparison with more complex methods. ## Variable type support -OLS automatically adapts to your target variable types. For numerical variables, it uses standard linear regression. For categorical variables (including strings, booleans, or numerically-encoded categorical variables), it automatically switches to logistic regression classification. 
This automatic detection means you don't need to specify variable types—simply pass your predictors and targets, and the model handles the rest internally. +OLS adapts to target variable types automatically. For numerical variables, it uses standard linear regression. For categorical variables (including strings, booleans, or numerically-encoded categorical variables), it switches to logistic regression. You don't need to specify variable types manually. ## How it works -The OLS imputer works by fitting a linear regression model using the statsmodels implementation of Ordinary Least Squares. During the training phase, it identifies the coefficients that minimize the sum of squared residuals between the predicted and actual values in the training data. This creates a model that captures the linear relationship between the predictors and target variables. +The OLS imputer fits a linear regression model using the statsmodels implementation. During training, it finds the coefficients that minimize the sum of squared residuals between predicted and actual values. -For prediction at different quantiles, the model makes an important assumption that the residuals (the differences between predicted and actual values) follow a normal distribution. This assumption allows the model to generate predictions at various quantiles by starting with the mean prediction and adding a quantile-specific offset derived from the normal distribution. Specifically, it computes the standard error of the predictions and applies the inverse normal cumulative distribution function to generate predictions at the requested quantiles. +For prediction at different quantiles, the model assumes normally distributed residuals. It starts with the mean prediction and adds a quantile-specific offset computed from the normal distribution's inverse CDF and the standard error of the predictions. 
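The quantile construction described above can be sketched in a few lines. This is a minimal illustration of the idea (mean prediction plus a normal-quantile offset scaled by the residual standard error), not microimpute's actual implementation; the arrays and the `resid_se` value are hypothetical.

```python
import numpy as np
from scipy import stats

# Hypothetical fitted values: mean predictions from an OLS fit,
# and a residual standard error estimated from the training data.
mean_pred = np.array([0.1, -0.3, 0.5])
resid_se = 0.8

quantiles = [0.1, 0.5, 0.9]
# Under the normality assumption, the q-th quantile prediction is
# mean + z_q * SE, where z_q is the normal inverse CDF at q.
preds = {q: mean_pred + stats.norm.ppf(q) * resid_se for q in quantiles}
```

Note that the median prediction coincides with the mean prediction, since the normal distribution is symmetric and `norm.ppf(0.5)` is zero.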
## Key features -The OLS imputer offers a simple yet powerful parametric approach with fast training and prediction times compared to more complex models. It relies on the assumption of linear relationships between variables, making it particularly suitable for datasets where such relationships hold or as a baseline comparison for more complex approaches. +OLS is fast to train and predict. It works well when the relationship between predictors and targets is approximately linear. Because it assumes constant variance and normally distributed errors, it tends to compress imputed values toward the mean, producing a narrower distribution than the true one. This makes it a good baseline but a poor choice when the data has heavy tails or heteroscedastic errors. diff --git a/docs/models/ols/ols-imputation.ipynb b/docs/models/ols/ols-imputation.ipynb index 0502f68..f09b1d9 100644 --- a/docs/models/ols/ols-imputation.ipynb +++ b/docs/models/ols/ols-imputation.ipynb @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -217,7 +217,7 @@ "4 0.005383 -0.044642 -0.036385 0.021872 0.003935 0.015596 0.008142 -0.002592 -0.031988 -0.046641" ] }, - "execution_count": 56, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -233,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -354,7 +354,7 @@ "max 1.107267e-01 5.068012e-02 1.705552e-01 1.320436e-01 1.539137e-01 1.852344e-01" ] }, - "execution_count": 57, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -376,7 +376,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -408,7 +408,7 @@ }, { "cell_type": "code", - "execution_count": 59, + 
"execution_count": 5, "metadata": {}, "outputs": [ { @@ -499,7 +499,7 @@ "73 0.012648 0.050680 -0.020218 -0.002228 NaN NaN" ] }, - "execution_count": 59, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -528,7 +528,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -547,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -561,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -628,7 +628,7 @@ "73 -0.004692 0.007055" ] }, - "execution_count": 62, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -653,7 +653,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -2004,7 +2004,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -2128,7 +2128,7 @@ "[5 rows x 20 columns]" ] }, - "execution_count": 64, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -2159,7 +2159,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -3617,14 +3617,17 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n" + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 13.5s remaining: 20.2s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 13.5s remaining: 9.0s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 13.6s finished\n" ] }, { @@ -3657,7 +3660,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 13, "metadata": 
{}, "outputs": [ { @@ -3668,8 +3671,32 @@ }, "data": [ { + "error_y": { + "array": [ + 0.00007176011467255402, + 0.0000913355164601245, + 0.00011083575969884443, + 0.00011210242100357653, + 0.00011563162775412741, + 0.00013000531898907708, + 0.000143271836739722, + 0.0001506544338309766, + 0.00015261535043567369, + 0.00014081751874840215, + 0.00013834799510673368, + 0.0001357194581903946, + 0.00012375064613863142, + 0.00011764060922499956, + 0.00010758816302847906, + 0.00009853561930363294, + 0.00010941603098900106, + 0.00011884481777470878, + 0.00009323804980730522 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "type": "bar", @@ -3717,8 +3744,32 @@ ] }, { + "error_y": { + "array": [ + 0.0002511960288130229, + 0.0003728781059406671, + 0.0005215834564121717, + 0.000522694690812438, + 0.0005201751472757763, + 0.0005498446940786355, + 0.000608514199932397, + 0.0006423333112739606, + 0.0006362169672265104, + 0.0005805283237282669, + 0.0005503271496842148, + 0.0005306193310896558, + 0.0005077865393829969, + 0.000485382018208908, + 0.00043665631509002455, + 0.00036712860842352926, + 0.0004221172731648885, + 0.0005054827478209771, + 0.00043716321985420154 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "type": "bar", @@ -3781,8 +3832,8 @@ "r": 50, "t": 80 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "template": { "data": { "bar": [ @@ -4604,14 +4655,22 @@ }, "width": 750, "xaxis": { + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "text": "Quantile" }, "zeroline": false }, "yaxis": { - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "text": "Average Quantile Loss" }, @@ -4628,7 +4687,7 @@ "# Plot the results for numerical 
variables\n", "if \"quantile_loss\" in ols_results:\n", " perf_results_viz = model_performance_results(\n", - " results=ols_results[\"quantile_loss\"][\"results\"],\n", + " results=ols_results,\n", " model_name=\"OLS\",\n", " method_name=\"Cross-validation quantile loss average\",\n", " )\n", @@ -4649,7 +4708,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -4658,19 +4717,19 @@ "text": [ "Categorical variable distribution:\n", "risk_level\n", + "medium 150\n", "low 148\n", - "high 148\n", - "medium 146\n", + "high 144\n", "Name: count, dtype: int64\n", "\n", "Percentage distribution:\n", "risk_level\n", + "medium 0.339367\n", "low 0.334842\n", - "high 0.334842\n", - "medium 0.330317\n", + "high 0.325792\n", "Name: proportion, dtype: float64\n", "\n", - "Data types: {'age': dtype('float64'), 'sex': dtype('float64'), 'bmi': dtype('float64'), 'bp': dtype('float64'), 'risk_level': dtype('O')}\n", + "Data types: {'age': dtype('float64'), 'sex': dtype('float64'), 'bmi': dtype('float64'), 'bp': dtype('float64'), 'risk_level': }\n", "\n", "Training set size: 353 records\n", "Testing set size: 89 records\n" @@ -4707,7 +4766,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -4750,20 +4809,20 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Categorical imputation accuracy: 40.45%\n", + "Categorical imputation accuracy: 37.08%\n", "\n", "Confusion matrix:\n", " Predicted: low Predicted: medium Predicted: high\n", - "Actual: low 15 11 4\n", - "Actual: medium 9 20 0\n", - "Actual: high 10 19 1\n" + "Actual: low 9 10 11\n", + "Actual: medium 4 20 5\n", + "Actual: high 6 20 4\n" ] } ], @@ -4791,14 +4850,17 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", 
"output_type": "stream", "text": [ - "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n" + "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 1.1s remaining: 1.6s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 1.2s remaining: 0.8s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 1.6s finished\n" ] }, { @@ -4806,8 +4868,8 @@ "output_type": "stream", "text": [ "Categorical imputation cross-validation results (log loss):\n", - "Mean train log loss: 1.0673\n", - "Mean test log loss: 1.0776\n" + "Mean train log loss: 1.0681\n", + "Mean test log loss: 1.0794\n" ] } ], @@ -4828,7 +4890,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -4839,8 +4901,14 @@ }, "data": [ { + "error_y": { + "array": [ + 0.0019313286420746783 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "showlegend": true, @@ -4850,13 +4918,19 @@ ], "xaxis": "x", "y": [ - 1.0673353634599498 + 1.0681169294698512 ], "yaxis": "y" }, { + "error_y": { + "array": [ + 0.011486633730960285 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "showlegend": true, @@ -4866,7 +4940,7 @@ ], "xaxis": "x", "y": [ - 1.077618288538511 + 1.0794183318669401 ], "yaxis": "y" } @@ -4878,7 +4952,7 @@ "size": 16 }, "showarrow": false, - "text": "Log Loss Performance", + "text": "Log loss performance", "x": 0.5, "xanchor": "center", "xref": "paper", @@ -4888,8 +4962,8 @@ } ], "height": 420, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "showlegend": true, "template": { "data": { @@ -5716,14 +5790,24 @@ "domain": [ 0, 1 - ] + ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": false, + "showline": true }, "yaxis": { "anchor": "x", "domain": 
[ 0, 1 - ] + ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true } } } @@ -5749,7 +5833,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/models/qrf/index.md b/docs/models/qrf/index.md index 7e15838..170ec9a 100644 --- a/docs/models/qrf/index.md +++ b/docs/models/qrf/index.md @@ -1,21 +1,19 @@ # Quantile Random Forests -The `QRF` model harnesses the power of ensemble learning by utilizing multiple decision trees to predict different quantiles of the target variable distribution. This sophisticated approach allows for flexible modeling of complex relationships while providing robust estimates of uncertainty. +The `QRF` model uses an ensemble of decision trees to predict different quantiles of the target variable distribution. This allows it to model non-linear relationships while estimating uncertainty across the conditional distribution. ## Variable type support -QRF handles both numerical and categorical variables. For numerical targets, it uses quantile regression forests. For categorical targets (strings, booleans, numerically-encoded categorical variables), it automatically employs a Random Forest Classifier. The model detects variable types automatically and applies the appropriate method internally, requiring no manual specification from users. +QRF handles both numerical and categorical variables. For numerical targets, it uses quantile regression forests. For categorical targets (strings, booleans, numerically-encoded categorical variables), it automatically uses a Random Forest Classifier. The model detects variable types internally and requires no manual specification. ## How it works -Quantile Random Forests build upon the foundation of random forests by implementing a specialized algorithm from the quantile_forest package. 
The method begins by constructing an ensemble of decision trees, each trained on different bootstrapped samples of the original data. This process, known as bagging, introduces diversity among the individual trees and helps reduce overfitting. +Quantile Random Forests build on standard random forests using the `quantile_forest` package. The method constructs an ensemble of decision trees, each trained on a bootstrapped sample of the data (bagging). At each split, only a random subset of features is considered, which introduces diversity among trees and reduces overfitting. -During training, each tree in the forest predicts the target variable using only a random subset of the available features at each split point. This feature randomization further enhances diversity within the ensemble and improves its ability to capture various aspects of the underlying data relationships. +Unlike standard random forests that aggregate predictions into averages, QRF retains the full predictive distribution from each tree and estimates quantiles directly from this empirical distribution. ## Key features -The Quantile Random Forest (QRF) imputer provides a robust non-parametric method particularly effective for datasets exhibiting complex, non-linear relationships and heteroscedasticity. Unlike linear models, which rely on strong distributional assumptions, QRF makes minimal assumptions about the underlying data structure, adapting its uncertainty measures to reflect varying levels of variability within different regions of the input data. +QRF is non-parametric and makes minimal assumptions about the data structure. It adapts its uncertainty estimates to different regions of the input space, producing wider prediction intervals where the data is more variable and tighter intervals where it is less so. -QRF's primary strength lies in its predictive approach. 
While traditional random forests aggregate predictions into averages, QRF maintains the entire predictive distribution from each tree, directly estimating quantiles based on this empirical distribution. It also quantifies uncertainty through robust prediction intervals derived directly from its quantile estimates. These intervals dynamically adjust across the feature space, effectively signaling areas with varying levels of predictive certainty. - -Although QRF typically involves higher computational demands compared to simpler linear models, its enhanced accuracy on datasets with complex, non-linear relationships frequently justifies this trade-off. For applications where accurate predictive performance and meaningful uncertainty quantification are critical, QRF emerges as an especially valuable approach. +The method is computationally heavier than linear models, but often more accurate on datasets with non-linear relationships or heteroscedasticity (where variance depends on predictor values). Hyperparameter tuning via Optuna is available for optimizing the number of trees, minimum samples per leaf, split thresholds, and feature sampling. 
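The distinction between averaging and retaining the full distribution can be shown with a toy sketch. This is an illustration of the QRF idea in plain NumPy, not the `quantile_forest` package's API; the per-tree leaf samples are randomly generated stand-ins for the training targets that fall in each tree's leaf for one query point.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical leaf samples from 5 trees for a single query point:
# the training targets that landed in the leaf the point falls into.
leaf_samples = [
    rng.normal(loc=2.0, scale=1.0, size=rng.integers(3, 8)) for _ in range(5)
]
pooled = np.concatenate(leaf_samples)

mean_pred = pooled.mean()  # what a standard random forest would return
# QRF instead reads quantiles off the pooled empirical distribution.
q10, q50, q90 = np.quantile(pooled, [0.1, 0.5, 0.9])
```

Because the quantiles come from the pooled empirical distribution rather than a parametric assumption, the resulting prediction intervals widen or narrow with the local spread of the data.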
diff --git a/docs/models/qrf/qrf-imputation.ipynb b/docs/models/qrf/qrf-imputation.ipynb index c59a999..ffba33f 100644 --- a/docs/models/qrf/qrf-imputation.ipynb +++ b/docs/models/qrf/qrf-imputation.ipynb @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -3658,9 +3658,9 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 6.8s remaining: 10.2s\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 6.8s remaining: 4.5s\n", - "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 6.8s finished\n" + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 19.2s remaining: 28.7s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 19.2s remaining: 12.8s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 19.3s finished\n" ] }, { @@ -3704,8 +3704,32 @@ }, "data": [ { + "error_y": { + "array": [ + 0.00008551920671991658, + 0.0004107245170541804, + 0.00018326432680445731, + 0.00044296698746322717, + 0.00024075509310405528, + 0.00029083292370858524, + 0.00043090603620434025, + 0.0004606508356991907, + 0.00034842456444518586, + 0.0004626832277981355, + 0.000816089425956785, + 0.00038467021226627177, + 0.0003398728669511024, + 0.000309964218752641, + 0.0002952487775964039, + 0.0002930002413745392, + 0.00022755896739762955, + 0.00017022925074651195, + 0.00005512096011768218 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "type": "bar", @@ -3753,8 +3777,32 @@ ] }, { + "error_y": { + "array": [ + 0.000623604695409784, + 0.0009828712930208695, + 0.0012293155627994283, + 0.001167454357338009, + 0.000969638311140881, + 0.0007127881853378359, + 0.002636442666073569, + 0.001259531904761815, + 0.0010923805491178616, + 0.001634028502692155, + 0.0006800631282736149, + 0.0009467176016052857, + 0.0010995284899625922, + 
0.0008688007594503854, + 0.0010158843160004165, + 0.0011140556744687275, + 0.0006082527278363497, + 0.00045185788868329796, + 0.0004799569375129921 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "type": "bar", @@ -3817,8 +3865,8 @@ "r": 50, "t": 80 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "template": { "data": { "bar": [ @@ -4640,14 +4688,22 @@ }, "width": 750, "xaxis": { + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "text": "Quantile" }, "zeroline": false }, "yaxis": { - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "text": "Average Quantile Loss" }, @@ -4664,7 +4720,7 @@ "# Plot the results for numerical variables\n", "if \"quantile_loss\" in qrf_results:\n", " perf_results_viz = model_performance_results(\n", - " results=qrf_results[\"quantile_loss\"][\"results\"],\n", + " results=qrf_results,\n", " model_name=\"QRF\",\n", " method_name=\"Cross-validation quantile loss average\",\n", " )\n", @@ -4739,7 +4795,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -4748,19 +4804,19 @@ "text": [ "Categorical variable distribution:\n", "risk_level\n", + "medium 150\n", "low 148\n", - "high 148\n", - "medium 146\n", + "high 144\n", "Name: count, dtype: int64\n", "\n", "Percentage distribution:\n", "risk_level\n", + "medium 0.339367\n", "low 0.334842\n", - "high 0.334842\n", - "medium 0.330317\n", + "high 0.325792\n", "Name: proportion, dtype: float64\n", "\n", - "Data types: {'age': dtype('float64'), 'sex': dtype('float64'), 'bmi': dtype('float64'), 'bp': dtype('float64'), 'risk_level': dtype('O')}\n", + "Data types: {'age': dtype('float64'), 'sex': dtype('float64'), 'bmi': dtype('float64'), 'bp': 
dtype('float64'), 'risk_level': }\n", "\n", "Training set size: 353 records\n", "Testing set size: 89 records\n" @@ -4797,7 +4853,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -4840,20 +4896,20 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Categorical imputation accuracy: 42.70%\n", + "Categorical imputation accuracy: 44.94%\n", "\n", "Confusion matrix:\n", " Predicted: low Predicted: medium Predicted: high\n", - "Actual: low 13 5 12\n", - "Actual: medium 7 11 11\n", - "Actual: high 3 13 14\n" + "Actual: low 14 4 12\n", + "Actual: medium 8 13 8\n", + "Actual: high 4 13 13\n" ] } ], @@ -4881,7 +4937,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -4889,9 +4945,9 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 0.7s remaining: 1.1s\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 0.7s remaining: 0.5s\n", - "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 0.7s finished\n" + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 3.9s remaining: 5.8s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 4.1s remaining: 2.7s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 4.4s finished\n" ] }, { @@ -4899,8 +4955,8 @@ "output_type": "stream", "text": [ "Categorical imputation cross-validation results (log loss):\n", - "Mean train log loss: 0.2656\n", - "Mean test log loss: 1.2133\n" + "Mean train log loss: 0.2666\n", + "Mean test log loss: 1.2308\n" ] } ], @@ -4921,7 +4977,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -4932,8 +4988,14 @@ }, "data": [ { + "error_y": { + "array": [ + 0.004587677764544018 + ], + 
"type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "showlegend": true, @@ -4943,13 +5005,19 @@ ], "xaxis": "x", "y": [ - 0.2656462787483515 + 0.26664213025700334 ], "yaxis": "y" }, { + "error_y": { + "array": [ + 0.21683973320460714 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "showlegend": true, @@ -4959,7 +5027,7 @@ ], "xaxis": "x", "y": [ - 1.2133400644133157 + 1.2308133756846842 ], "yaxis": "y" } @@ -4981,8 +5049,8 @@ } ], "height": 420, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "showlegend": true, "template": { "data": { @@ -5809,14 +5877,24 @@ "domain": [ 0, 1 - ] + ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": false, + "showline": true }, "yaxis": { "anchor": "x", "domain": [ 0, 1 - ] + ], + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true } } } @@ -5842,7 +5920,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/models/quantreg/index.md b/docs/models/quantreg/index.md index 2e31d4d..38637a2 100644 --- a/docs/models/quantreg/index.md +++ b/docs/models/quantreg/index.md @@ -1,23 +1,21 @@ # Quantile Regression -The `QuantReg` model takes a direct approach to modeling specific quantiles of the target variable distribution. Unlike methods that model the mean and then derive quantiles from distributional assumptions, quantile regression addresses each conditional quantile explicitly, providing greater flexibility and robustness in heterogeneous data settings. +The `QuantReg` model directly models specific quantiles of the target variable distribution. 
Unlike methods that model the conditional mean and derive quantiles from distributional assumptions, quantile regression addresses each conditional quantile separately. ## Variable type support -QuantReg is designed specifically for numerical variables and does not support categorical variable imputation. If your imputation targets include categorical variables (string, or numerically-encoded categorical variables), consider using OLS or QRF models instead, which automatically handle both numerical and categorical targets through internal classification methods. +QuantReg is designed for numerical variables and does not support categorical imputation. If your targets include categorical variables, use OLS or QRF instead, which handle both numerical and categorical targets through internal classification methods. ## How it works -Quantile regression in Microimpute leverages the statsmodels' QuantReg implementation to create precise models of conditional quantiles. During the training phase, the approach fits separate regression models for each requested quantile level, creating a focused model for each part of the conditional distribution you wish to estimate. +The implementation uses statsmodels' `QuantReg`. During training, a separate regression model is fitted for each requested quantile level. -The mathematical foundation of the method lies in its objective function, which minimizes asymmetrically weighted absolute residuals rather than squared residuals as in ordinary least squares. This asymmetric weighting system penalizes under-predictions more heavily when estimating higher quantiles and over-predictions more heavily when estimating lower quantiles. This clever formulation allows the model to converge toward solutions that represent true conditional quantiles. +The objective function minimizes asymmetrically weighted absolute residuals rather than squared residuals (as in OLS). 
For higher quantiles, under-predictions are penalized more heavily; for lower quantiles, over-predictions are penalized more heavily. This asymmetry causes each model to converge toward the true conditional quantile. -When making predictions, the system applies the appropriate quantile-specific model for each requested quantile level. This direct approach means predictions at different quantiles come from distinct models optimized for those specific portions of the distribution, rather than from a single model with assumptions about the error distribution. +When predicting, the system applies the quantile-specific model for each requested level. Predictions at different quantiles come from distinct models, each optimized for that part of the distribution. ## Key features -Quantile regression offers several compelling advantages for imputation tasks. It allows direct modeling of conditional quantiles without making restrictive assumptions about the underlying distribution of the data. This distribution-free approach makes the method robust to outliers and applicable in a wide range of scenarios where normal distribution assumptions might be violated. +Quantile regression models conditional quantiles without assuming a particular error distribution, making it robust to outliers. It naturally captures heteroscedasticity, adapting to changing variance patterns across the feature space (unlike OLS, which assumes constant variance). -The method excels at capturing heteroscedasticity—situations where the variability of the target depends on the predictor values. While methods like OLS assume constant variance throughout the feature space, quantile regression naturally adapts to changing variance patterns, providing more accurate predictions in regions with different error characteristics. - -By fitting multiple quantile levels, the approach provides a comprehensive picture of the conditional distribution of the target variable. 
This detailed view enables more nuanced imputation where understanding the full range of possible values is important. For instance, it can reveal asymmetries in the conditional distribution that other methods might miss, offering valuable insights into the uncertainty structure of the imputed values. +By fitting multiple quantile levels, the method gives a picture of the full conditional distribution. This can reveal asymmetries that other methods would miss. However, QuantReg is limited to linear relationships between predictors and the target. diff --git a/docs/models/quantreg/quantreg-imputation.ipynb b/docs/models/quantreg/quantreg-imputation.ipynb index 79308a9..2a93247 100644 --- a/docs/models/quantreg/quantreg-imputation.ipynb +++ b/docs/models/quantreg/quantreg-imputation.ipynb @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -3633,9 +3633,9 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", - "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 5.7s remaining: 8.5s\n", - "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 5.7s remaining: 3.8s\n", - "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 5.8s finished\n" + "[Parallel(n_jobs=-1)]: Done 2 out of 5 | elapsed: 9.7s remaining: 14.5s\n", + "[Parallel(n_jobs=-1)]: Done 3 out of 5 | elapsed: 9.7s remaining: 6.5s\n", + "[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed: 10.0s finished\n" ] }, { @@ -3679,8 +3679,32 @@ }, "data": [ { + "error_y": { + "array": [ + 0.00007068607962860608, + 0.00008566173778664199, + 0.00010354861421659208, + 0.00010446681832917282, + 0.00009958982144334462, + 0.00010899487172081684, + 0.00012524333466419057, + 0.0001401559350155429, + 0.0001419856727246125, + 0.00014148804255191002, + 0.00013964325829394816, + 0.0001352059984264296, + 0.0001266211747154425, + 0.00010806937892337481, + 0.0000955126673952158, + 
0.00009572945677819657, + 0.00010825711304818197, + 0.00009924455051088341, + 0.00008058254845100971 + ], + "type": "data" + }, "marker": { - "color": "#00CC96" + "color": "#DDCC77" }, "name": "Train", "type": "bar", @@ -3728,8 +3752,32 @@ ] }, { + "error_y": { + "array": [ + 0.0002923819760799368, + 0.00036360537094371777, + 0.0004548700007812027, + 0.0004754851566973023, + 0.0004310373487827412, + 0.0004963764336370348, + 0.0004984041180187477, + 0.0005592567272666623, + 0.0006264645396437899, + 0.0006000923054654875, + 0.0006170472013545412, + 0.0005280976014239853, + 0.0004788123955147881, + 0.00041499916196403646, + 0.0003537849752500568, + 0.0003903169140390032, + 0.0004234486790403488, + 0.0003675153867336991, + 0.00032353724735584744 + ], + "type": "data" + }, "marker": { - "color": "#AB63FA" + "color": "#117733" }, "name": "Test", "type": "bar", @@ -3792,8 +3840,8 @@ "r": 50, "t": 80 }, - "paper_bgcolor": "#F0F0F0", - "plot_bgcolor": "#F0F0F0", + "paper_bgcolor": "#FAFAFA", + "plot_bgcolor": "#FAFAFA", "template": { "data": { "bar": [ @@ -4615,14 +4663,22 @@ }, "width": 750, "xaxis": { + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", "showgrid": false, + "showline": true, "title": { "text": "Quantile" }, "zeroline": false }, "yaxis": { - "showgrid": false, + "gridcolor": "#E5E5E5", + "gridwidth": 1, + "linecolor": "#CCCCCC", + "showgrid": true, + "showline": true, "title": { "text": "Average Quantile Loss" }, @@ -4639,7 +4695,7 @@ "# Plot the results\n", "if \"quantile_loss\" in quantreg_results:\n", " perf_results_viz = model_performance_results(\n", - " results=quantreg_results[\"quantile_loss\"][\"results\"],\n", + " results=quantreg_results,\n", " model_name=\"QuantReg\",\n", " method_name=\"Cross-validation quantile loss average\",\n", " )\n", @@ -4652,7 +4708,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pe3.13", + "display_name": "pe3.13 (3.13.0)", "language": "python", "name": "python3" }, diff --git a/docs/myst.yml 
b/docs/myst.yml index e7f484c..a078097 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -48,8 +48,6 @@ project: - title: Use cases children: - file: use_cases/index - children: - - file: use_cases/scf_to_cps/imputing-from-scf-to-cps site: options: logo: logo.png diff --git a/docs/use_cases/index.md b/docs/use_cases/index.md index b7ea6d4..32c7af9 100644 --- a/docs/use_cases/index.md +++ b/docs/use_cases/index.md @@ -1,14 +1,14 @@ -# Imputing full variables across surveys +# Imputing variables across surveys -This document explains what the workflow for imputing variables from one survey to another using Microimpute may look like. We'll use the example of imputing wealth data from the Survey of Consumer Finances (SCF) into the Current Population Survey (CPS). +This document walks through a typical workflow for imputing variables from one survey to another using microimpute, using wealth imputation from the Survey of Consumer Finances (SCF) into the Current Population Survey (CPS) as the running example. ## Identifying receiver and donor datasets -The first step is to identify your donor and receiver datasets. The Donor dataset is that containing the variable you want to impute (e.g., SCF contains wealth data). The Receiver dataset will receive the imputed variable (e.g., CPS which originally did not contain wealth data but will after our imputation is completed). It is important for these two datasets to have predictor variables in common for the imputation to be succcesful. For example, both the SCF and CPS surveys contain demographic and financial data that may help us understand how wealth values may be distributed. +Start by identifying your donor and receiver datasets. The donor dataset contains the variable you want to impute (here, the SCF contains wealth data). The receiver dataset will receive the imputed variable (here, the CPS, which lacks wealth data). Both datasets need predictor variables in common for the imputation to work. 
For example, both the SCF and CPS contain demographic and financial data that can help predict wealth. ```python import pandas as pd -from microimpute.models import OLSImputer, MatchingImputer, QRFImputer +from microimpute.models import OLS, Matching, QRF # Load donor dataset (SCF with wealth data) scf_data = pd.read_csv("scf_data.csv") @@ -19,97 +19,107 @@ cps_data = pd.read_csv("cps_data.csv") ## Cleaning and aligning variables -Before imputation, you need to ensure both datasets have compatible variables. Identify common variables present in both datasets -and standardize their variable formats, units, and categories so that Python can recognize they indeed represent the same the same data. Remember to also handle missing values in common variables. Lastly, identify the target variables in the donor dataset which will directly inform the values of the imputed variables in the receiver dataset. For details on data preprocessing options refer to the [Data preprocessing page](./preprocessing.md). +Before imputation, make sure both datasets have compatible variables. Identify common variables present in both datasets and standardize their formats, units, and categories so that Python can match them correctly. Handle missing values in common variables, and identify the target variables in the donor dataset that will inform the imputed values in the receiver. For details on preprocessing options, see the [Data preprocessing page](../imputation-benchmarking/preprocessing.md). 
```python -# Identify common variables +# Identify common variables common_variables = ['age', 'income', 'education', 'marital_status', 'region'] # Ensure variable formats match (example: education coding) education_mapping = { - 1: "less_than_hs", - 2: "high_school", - 3: "some_college", - 4: "bachelor", + 1: "less_than_hs", + 2: "high_school", + 3: "some_college", + 4: "bachelor", 5: "graduate" } # Apply standardization to both datasets for dataset in [scf_data, cps_data]: dataset['education'] = dataset['education'].map(education_mapping) - + # Convert income to same units (thousands) if 'income' in dataset.columns: dataset['income'] = dataset['income'] / 1000 - + # Identify target variable in donor dataset target_variable = ['networth'] ``` ## Performing imputation -Microimpute offers several methods for imputation across surveys, which are described in the [Models Chapter](../models). The approach under the hood will differ based on the method chosen, although the workflow will remain constant. Let us see this for two different example methods. +Microimpute offers several methods for imputation across surveys, described in the [Models chapter](../models). The underlying approach differs by method, but the workflow stays the same. Here are two examples. ### Matching imputation -Matching finds similar observations in the donor dataset for each observation in the receiver dataset and imputes the values for those receiver observations based on the values of the target value in the donor dataset. To do so it should be fitted on the donor dataset and predict using the receiver dataset. This will ensure the correct mapping of variables from one survey to the other. +Matching finds similar observations in the donor dataset for each receiver observation and transfers the donor's target values. Fit on the donor dataset, then predict using the receiver dataset. 
```python # Set up matching imputer -matching_imputer = MatchingImputer( - predictors=common_variables, - imputed_variables=target_variable -) +matching = Matching() # Train on donor dataset -matching_imputer.fit(scf_data) +fitted_matching = matching.fit( + X_train=scf_data, + predictors=common_variables, + imputed_variables=target_variable, +) # Impute target variable into receiver dataset -cps_data_with_wealth_matching = matching_imputer.predict(cps_data) +cps_with_wealth_matching = fitted_matching.predict(X_test=cps_data) ``` ### Regression imputation (OLS) -OLS imputation builds a linear regression model using the donor dataset and applies it to the receiver dataset, predicting what wealth values may be for a specific combination of predictor variable values. To do so, again we need to make sure that we first fit the model on the donor dataset, while calling predict on the receiver dataset. +OLS builds a linear regression model on the donor dataset and predicts wealth values for each combination of predictor values in the receiver. Again, fit on the donor and predict on the receiver. - ```python +```python # Set up OLS imputer -ols_imputer = OLSImputer( - explanatory_variables=common_variables, - target_variable=target_variable -) +ols = OLS() # Train on donor dataset -ols_imputer.fit(scf_data) +fitted_ols = ols.fit( + X_train=scf_data, + predictors=common_variables, + imputed_variables=target_variable, +) # Impute target variable into receiver dataset -cps_data_with_wealth_ols = ols_imputer.impute(cps_data) +cps_with_wealth_ols = fitted_ols.predict(X_test=cps_data) ``` ## Evaluating imputation quality -Evaluating imputation quality across surveys can be challenging since the true values aren't known in the receiver dataset. 
Comparing the distribution of the target variables in the donor dataset to the distribution of the variables we imputed in the receiver dataset may give us an understanding of the imputation quality for different sections of the distribution. We may want to pay particular attention to obtaining accurate prediction not only for mean or median values but also look at the performance at the distribution tails. This can be achieved computing the quantile loss supported by Microimpute. Additionally, if we have performed imputation accross multiple methods we may want to compare across them. Microimpute supports this through multiple easy-to-use metrics described in the [Metrics page](./metrics.md) file. +Evaluating imputation quality across surveys is challenging since the true values are unknown in the receiver dataset. Comparing the target variable's distribution in the donor to the imputed distribution in the receiver can reveal how well the imputation captures different parts of the distribution. Beyond mean or median accuracy, check performance at the tails using quantile loss. When comparing multiple methods, microimpute provides several metrics described in the [Metrics page](../imputation-benchmarking/metrics.md). ```python -# Ensure all imputations are in a dictionary mapping quantiles to dataframes containing imputed values -method_imputations = { - [0.1]: pd.DataFrame - [0.5]: pd.DataFrame - ... 
-} +from microimpute.comparisons import get_imputations +from microimpute.comparisons.metrics import compare_metrics + +# Generate imputations from multiple models using cross-validation +method_imputations = get_imputations( + model_classes=[QRF, OLS, Matching], + X_train=train_data, + X_test=test_data, + predictors=common_variables, + imputed_variables=target_variable, +) -# Compare original wealth distribution with imputed wealth across methods -loss_comparison_df = compare_quantile_loss(Y_test, method_imputations) +# Compare quantile loss across methods +loss_comparison_df = compare_metrics( + test_y=test_data[target_variable], + method_imputations=method_imputations, + imputed_variables=target_variable, +) ``` ## Incorporating the imputed variable -Once you've chosen the best imputation method, you may want to incorporate the imputed variable into your receiver dataset for future analysis. +Once you've chosen the best imputation method, incorporate the imputed variable into your receiver dataset for downstream analysis. ```python # Choose the best imputation method (e.g., QRF) -final_imputed_dataset = cps_data_with_wealth_qrf +final_imputed_dataset = cps_with_wealth_qrf # Save the augmented dataset final_imputed_dataset.to_csv("cps_with_imputed_wealth.csv", index=False) @@ -117,4 +127,8 @@ final_imputed_dataset.to_csv("cps_with_imputed_wealth.csv", index=False) ## Key considerations -Model selection plays a critical role in this workflow because different imputation methods have unique strengths. For example, a Quantile Regression Forest (QRF) often performs better when capturing complex relationships between variables, while a Matching approach may be more effective at preserving the original distributional properties of the data. Additionally, not all models can impute categorical data. For example, atching is able to match any value regardless of its data type, but QuantReg does not support categorical imputation. 
OLS and QRF will use logistic regression and random forest classification methods under the hood, respectively. Variable selection is equally important, since the common predictors across datasets should have strong power for explaining the target variable to ensure a reliable imputation. Because the ground truth is typically unknown in the receiver dataset, validation can involve simulation studies or comparing imputed values against known aggregate statistics. Finally, it is crucial to maintain documentation of the imputation process, from the choice of model to any adjustments made, so that the analysis remains transparent and reproducible for others. For the full pipeline details on SCF-to-CPS net worth imputation refer to the following [notebook](./scf_to_cps/imputing-from-scf-to-cps.md). +Model selection matters because different imputation methods have different strengths. QRF often performs better at capturing non-linear relationships, while Matching tends to preserve the original distributional properties of the data. Not all models handle categorical data: Matching can match any value regardless of type, but QuantReg does not support categorical imputation. OLS and QRF use logistic regression and random forest classification internally for categorical targets. + +Variable selection is equally important. The common predictors should have strong explanatory power for the target variable. Because the ground truth is unknown in the receiver dataset, validation can involve simulation studies or comparison against known aggregate statistics. + +For a complete worked example of the SCF-to-CPS net worth imputation pipeline, see the [autoimpute notebook](../autoimpute/autoimpute.ipynb). The [microimpute paper](https://github.com/PolicyEngine/microimpute/blob/main/paper/main.pdf) presents the full methodology and reports results from this imputation. 
diff --git a/docs/use_cases/scf_to_cps/autoimpute_best_model_imputations.png b/docs/use_cases/scf_to_cps/autoimpute_best_model_imputations.png deleted file mode 100644 index 477f98d..0000000 Binary files a/docs/use_cases/scf_to_cps/autoimpute_best_model_imputations.png and /dev/null differ diff --git a/docs/use_cases/scf_to_cps/autoimpute_model_comparison.png b/docs/use_cases/scf_to_cps/autoimpute_model_comparison.png deleted file mode 100644 index 93c6bd6..0000000 Binary files a/docs/use_cases/scf_to_cps/autoimpute_model_comparison.png and /dev/null differ diff --git a/docs/use_cases/scf_to_cps/by_income_decile_comparisons.png b/docs/use_cases/scf_to_cps/by_income_decile_comparisons.png deleted file mode 100644 index 9adb8d1..0000000 Binary files a/docs/use_cases/scf_to_cps/by_income_decile_comparisons.png and /dev/null differ diff --git a/docs/use_cases/scf_to_cps/imputations_model_comparison.png b/docs/use_cases/scf_to_cps/imputations_model_comparison.png deleted file mode 100644 index e4675de..0000000 Binary files a/docs/use_cases/scf_to_cps/imputations_model_comparison.png and /dev/null differ diff --git a/docs/use_cases/scf_to_cps/imputing-from-scf-to-cps.md b/docs/use_cases/scf_to_cps/imputing-from-scf-to-cps.md deleted file mode 100644 index 18e09c7..0000000 --- a/docs/use_cases/scf_to_cps/imputing-from-scf-to-cps.md +++ /dev/null @@ -1,924 +0,0 @@ -# Example: imputing wealth from the SCF to the CPS - -This notebook demonstrates a full pipeline powered by the `microimpute` package and specifically the `autoimpute` function used to impute wealth variables from the Survey of Consumer Finances to the Current Population Survey. - -The Survey of Consumer Finances (SCF) is a triennial survey conducted by the Federal Reserve that collects detailed information on U.S. families' balance sheets, income, and demographic characteristics, with a special focus on wealth measures. 
The Current Population Survey (CPS) is a monthly survey conducted by the Census Bureau that provides comprehensive data on the labor force, employment, unemployment, and demographic characteristics, but lacks detailed wealth information. - -By using `microimpute`, wealth information can be transfered from the SCF to the CPS, enabling economic analyses that require both detailed labor market and wealth data. - -```python -import io -import logging -import zipfile -from typing import List, Optional, Union - -import numpy as np -import pandas as pd -import requests -import plotly.graph_objects as go -from plotly.subplots import make_subplots -from pydantic import validate_call -from tqdm import tqdm -import warnings - -from microimpute.config import ( - VALIDATE_CONFIG, VALID_YEARS, PLOT_CONFIG -) -from microimpute.comparisons import * -from microimpute.visualizations import * -from microimpute.utils.data import preprocess_data - -logger = logging.getLogger(__name__) -``` - -## Loading and preparing the SCF and CPS datasets - -The first step in the imputation process involves acquiring and harmonizing the two datasets. Extracting data from the SCF and the CPS, and then processing it to ensure the variables are compatible for imputation are crucial pre-processing steps for successful imputation. This involves identifying predictor variables that exist in both data sets and can meaningfully predict wealth, as well as ensuring they are named and encoded identically. - -```python -@validate_call(config=VALIDATE_CONFIG) -def scf_url(year: int, VALID_YEARS: List[int] = VALID_YEARS) -> str: - """Return the URL of the SCF summary microdata zip file for a year. - - Args: - year: Year of SCF summary microdata to retrieve. - - Returns: - URL of summary microdata zip file for the given year. - - Raises: - ValueError: If the year is not in VALID_YEARS. 
- """ - logger.debug(f"Generating SCF URL for year {year}") - - if year not in VALID_YEARS: - logger.error( - f"Invalid SCF year: {year}. Valid years are {VALID_YEARS}" - ) - raise ValueError( - f"The SCF is not available for {year}. Valid years are {VALID_YEARS}" - ) - - url = f"https://www.federalreserve.gov/econres/files/scfp{year}s.zip" - logger.debug(f"Generated URL: {url}") - return url - - -@validate_call(config=VALIDATE_CONFIG) -def load_scf( - years: Optional[Union[int, List[int]]] = VALID_YEARS, - columns: Optional[List[str]] = None, -) -> pd.DataFrame: - """Load Survey of Consumer Finances data for specified years and columns. - - Args: - years: Year or list of years to load data for. - columns: List of column names to load. - - Returns: - DataFrame containing the requested data. - - Raises: - ValueError: If no Stata files are found in the downloaded zip - or invalid parameters - RuntimeError: If there's a network error or a problem processing - the downloaded data - """ - - logger.info(f"Loading SCF data with years={years}") - - try: - # Identify years for download - if years is None: - years = VALID_YEARS - logger.warning(f"Using default years: {years}") - - if isinstance(years, int): - years = [years] - - # Validate all years are valid - invalid_years = [year for year in years if year not in VALID_YEARS] - if invalid_years: - logger.error(f"Invalid years specified: {invalid_years}") - raise ValueError( - f"Invalid years: {invalid_years}. 
Valid years are {VALID_YEARS}" - ) - - all_data: List[pd.DataFrame] = [] - - for year in tqdm(years): - logger.info(f"Processing data for year {year}") - try: - # Download zip file - logger.debug(f"Downloading SCF data for year {year}") - url = scf_url(year) - try: - response = requests.get(url, timeout=60) - response.raise_for_status() # Raise an error for bad responses - except requests.exceptions.RequestException as e: - logger.error( - f"Network error downloading SCF data for year {year}: {str(e)}" - ) - raise RuntimeError( - f"Failed to download SCF data for year {year}" - ) from e - - # Process zip file - try: - logger.debug("Creating zipfile from downloaded content") - z = zipfile.ZipFile(io.BytesIO(response.content)) - - # Find the .dta file in the zip - dta_files: List[str] = [ - f for f in z.namelist() if f.endswith(".dta") - ] - if not dta_files: - logger.error( - f"No Stata files found in zip for year {year}" - ) - raise ValueError( - f"No Stata files found in zip for year {year}" - ) - - logger.debug(f"Found Stata files: {dta_files}") - - # Read the Stata file - try: - logger.debug(f"Reading Stata file: {dta_files[0]}") - with z.open(dta_files[0]) as f: - df = pd.read_stata( - io.BytesIO(f.read()), columns=columns - ) - logger.debug( - f"Read DataFrame with shape {df.shape}" - ) - - # Ensure 'wgt' is included - if ( - columns is not None - and "wgt" not in df.columns - and "wgt" not in columns - ): - logger.debug("Re-reading with 'wgt' column added") - # Re-read to include weights - with z.open(dta_files[0]) as f: - cols_with_weight: List[str] = list( - set(columns) | {"wgt"} - ) - df = pd.read_stata( - io.BytesIO(f.read()), - columns=cols_with_weight, - ) - logger.debug( - f"Re-read DataFrame with shape {df.shape}" - ) - except Exception as e: - logger.error( - f"Error reading Stata file for year {year}: {str(e)}" - ) - raise RuntimeError( - f"Failed to process Stata file for year {year}" - ) from e - - except zipfile.BadZipFile as e: - 
logger.error(f"Bad zip file for year {year}: {str(e)}")
-                    raise RuntimeError(
-                        f"Downloaded zip file is corrupt for year {year}"
-                    ) from e
-
-            # Add year column
-            df["year"] = year
-            logger.info(
-                f"Successfully processed data for year {year}, shape: {df.shape}"
-            )
-            all_data.append(df)
-
-        except Exception as e:
-            logger.error(f"Error processing year {year}: {str(e)}")
-            raise
-
-        # Combine all years
-        logger.debug(f"Combining data from {len(all_data)} years")
-        if len(all_data) > 1:
-            result = pd.concat(all_data)
-            logger.info(
-                f"Combined data from {len(years)} years, final shape: {result.shape}"
-            )
-            return result
-        else:
-            logger.info(
-                f"Returning data for single year, shape: {all_data[0].shape}"
-            )
-            return all_data[0]
-
-    except Exception as e:
-        logger.error(f"Error in load_scf: {str(e)}")
-        raise
-
-scf = load_scf(2022)
-
-# Create mapping from SCF column names to desired variable names
-scf_variable_mapping = {
-    "hhsex": "is_female",  # sex (is female yes/no) (hhsex)
-    "age": "age",  # age of respondent (age)
-    "race": "race",  # race of respondent (race)
-    "kids": "own_children_in_household",  # number of children in household (kids)
-    "wageinc": "employment_income",  # income from wages and salaries (wageinc)
-    "bussefarminc": "farm_self_employment_income",  # income from business, self-employment or farm (bussefarminc)
-    "intdivinc": "interest_dividend_income",  # income from interest and dividends (intdivinc)
-    "ssretinc": "pension_income",  # income from social security and retirement accounts (ssretinc)
-}
-
-original_columns = list(scf_variable_mapping.keys()) + ["networth", "wgt"]
-scf_df = pd.DataFrame({col: scf[col] for col in original_columns})
-scf_data = scf_df.rename(columns=scf_variable_mapping)
-
-# Convert hhsex to is_female (hhsex: 1=male, 2=female -> is_female: 0=male, 1=female)
-scf_data["is_female"] = (scf_data["is_female"] == 2).astype(int)
-
-predictors = [
-    "is_female",  # sex of head of household
-    "age",  # age of head of household
- "own_children_in_household", # number of children in household - "race", # race of the head of household - "employment_income", # income from wages and salaries - "interest_dividend_income", # income from interest and dividends - "pension_income", # income from social security and retirement accounts -] - -imputed_variables = ["networth"] - -weights = ["wgt"] - -scf_data = scf_data[predictors + imputed_variables + weights] - -weights_col = scf_data["wgt"].values -weights_normalized = weights_col / weights_col.sum() -scf_data_weighted = scf_data.sample( - n=len(scf_data), - replace=True, - weights=weights_normalized, -).reset_index(drop=True) -``` - -```python -import ssl -import requests - -# Disable SSL verification warnings (only use in development environments) -import urllib3 -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - -# Create unverified context for SSL connections -ssl._create_default_https_context = ssl._create_unverified_context - -# Monkey patch the requests library to use the unverified context -old_get = requests.get -requests.get = lambda *args, **kwargs: old_get(*args, **{**kwargs, 'verify': False}) - -from policyengine_us_data import CPS_2024 -cps_data = CPS_2024() -#cps_data.generate() -cps = cps_data.load_dataset() - -cps_race_mapping = { - 1: 1, # White only -> WHITE - 2: 2, # Black only -> BLACK/AFRICAN-AMERICAN - 3: 5, # American Indian, Alaskan Native only -> AMERICAN INDIAN/ALASKA NATIVE - 4: 4, # Asian only -> ASIAN - 5: 6, # Hawaiian/Pacific Islander only -> NATIVE HAWAIIAN/PACIFIC ISLANDER - 6: 7, # White-Black -> OTHER - 7: 7, # White-AI -> OTHER - 8: 7, # White-Asian -> OTHER - 9: 7, # White-HP -> OTHER - 10: 7, # Black-AI -> OTHER - 11: 7, # Black-Asian -> OTHER - 12: 7, # Black-HP -> OTHER - 13: 7, # AI-Asian -> OTHER - 14: 7, # AI-HP -> OTHER - 15: 7, # Asian-HP -> OTHER - 16: 7, # White-Black-AI -> OTHER - 17: 7, # White-Black-Asian -> OTHER - 18: 7, # White-Black-HP -> OTHER - 19: 7, # White-AI-Asian -> 
OTHER - 20: 7, # White-AI-HP -> OTHER - 21: 7, # White-Asian-HP -> OTHER - 22: 7, # Black-AI-Asian -> OTHER - 23: 7, # White-Black-AI-Asian -> OTHER - 24: 7, # White-AI-Asian-HP -> OTHER - 25: 7, # Other 3 race comb. -> OTHER - 26: 7, # Other 4 or 5 race comb. -> OTHER -} - -# Apply the mapping to recode the race values -cps["race"] = np.vectorize(cps_race_mapping.get)(cps["cps_race"]) -cps["farm_self_employment_income"] = cps["self_employment_income"] + cps["farm_income"] -cps["interest_dividend_income"] = cps["taxable_interest_income"] + cps["tax_exempt_interest_income"] + cps["qualified_dividend_income"] + cps["non_qualified_dividend_income"] -cps["pension_income"] = cps["tax_exempt_private_pension_income"] + cps["taxable_private_pension_income"] + cps["social_security_retirement"] - -mask_head = cps["is_household_head"] -income_df = pd.DataFrame({ - "household_id": cps["person_household_id"], - "employment_income": cps["employment_income"], - "farm_self_employment_income": cps["farm_self_employment_income"], - "interest_dividend_income": cps["interest_dividend_income"], - "pension_income": cps["pension_income"], -}) -household_sums = ( - income_df - .groupby("household_id") - .sum() - .reset_index() -) -heads = pd.DataFrame({ - "household_id": cps["person_household_id"][mask_head], - "is_female": cps["is_female"][mask_head], - "age": cps["age"][mask_head], - "race": cps["race"][mask_head], - "own_children_in_household": cps["own_children_in_household"][mask_head], -}) -hh_level = heads.merge(household_sums, on="household_id", how="left") - -for name, series in cps.items(): - if isinstance(series, pd.Series) and len(series) == len(hh_level): - if name not in hh_level.columns: - hh_level[name] = series.values - - -cols = ( - ["household_id"] - + [ - "farm_self_employment_income", - "interest_dividend_income", - "pension_income", - "employment_income", - ] - + ["own_children_in_household", "is_female", "age", "race"] -) -cps_data = hh_level[cols] 
-cps_data["household_weight"] = cps["household_weight"]
-
-household_weights = ["household_weight"]
-
-from policyengine_us import Microsimulation
-sim = Microsimulation(dataset=CPS_2024)
-net_disposable_income = sim.calculate("household_net_income")
-
-cps_data["household_net_income"] = net_disposable_income
-```
-
-## Running wealth imputation with autoimpute
-
-After harmonizing the two datasets, the `autoimpute` function from `microimpute` can be used to transfer wealth information from the SCF to the CPS. It streamlines the imputation process by automating hyperparameter tuning, method selection, validation, and application.
-
-Behind the scenes, `autoimpute` evaluates multiple statistical approaches, including Quantile Random Forests, Ordinary Least Squares, Quantile Regression, and Statistical Matching. It performs cross-validation to determine which method most accurately captures the relationship between the predictor variables and wealth measures in the SCF data, then applies the best-performing method to generate synthetic wealth values for CPS households.
-
-With hyperparameter tuning enabled, the function optimizes each method's parameters, further improving imputation accuracy. This saves considerable time compared to manually testing different imputation strategies, while ensuring the most appropriate method is selected for this specific imputation task.
-
-```python
-warnings.filterwarnings("ignore")
-
-# Run the autoimpute process
-imputations, imputed_data, fitted_model, method_results_df = autoimpute(
-    donor_data=scf_data,
-    receiver_data=cps_data,
-    predictors=predictors,
-    imputed_variables=imputed_variables,
-    weight_col=weights[0],
-    tune_hyperparameters=True,  # enable automated hyperparameter tuning
-    normalize_data=True,  # normalize data before fitting
-)
-```
-
-## Comparing method performance
-
-The method comparison plot below shows how different imputation methods performed across various quantiles. Lower quantile loss values indicate better performance.
-
-```python
-# Extract the quantiles used in the evaluation
-quantiles = [q for q in method_results_df.columns if isinstance(q, float)]
-
-comparison_viz = method_comparison_results(
-    data=method_results_df,
-    metric="quantile_loss",
-    data_format="wide",
-)
-fig = comparison_viz.plot(
-    title="Autoimpute method comparison",
-    show_mean=True,
-)
-```
-
-![png](./autoimpute_model_comparison.png)
-
-## Evaluating wealth imputations
-
-To assess the imputation results, comparing the distribution of wealth in the original SCF data with the imputed values in the CPS allows us to examine how well the imputation preserves important characteristics of the wealth distribution, such as its shape, central tendency, and dispersion.
-
-Wealth distributions are typically highly skewed, with a long right tail representing a small number of households with very high net worth. A successful imputation should preserve this characteristic skewness while maintaining realistic values across the entire distribution. Examining both the raw and log-transformed wealth distributions captures this information more completely.
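For intuition, the quantile (pinball) loss driving the comparison above can be sketched in a few lines. This is the standard textbook definition, written here as a hypothetical helper for illustration; it is assumed to mirror what microimpute computes and is not a call into its API:

```python
import numpy as np

def quantile_loss(y_true, y_pred, q=0.5):
    """Pinball loss: under-predictions are weighted by q, over-predictions by 1 - q."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))

# The same under-prediction of 2 costs 0.2 at the 10th percentile
# but 1.8 at the 90th, so per-quantile losses expose tail behavior.
low = quantile_loss([10.0], [8.0], q=0.1)   # 0.2
high = quantile_loss([10.0], [8.0], q=0.9)  # 1.8
```

Averaging this loss over held-out folds and a grid of quantiles is what the comparison plot summarizes.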
- -```python -def plot_log_transformed_net_worth_distributions( - scf_data: pd.DataFrame, - imputed_data: pd.DataFrame, - title: Optional[str] = None, -) -> go.Figure: - """Plot the log-transformed distribution of net worth in SCF and imputed CPS data.""" - - palette = px.colors.qualitative.Plotly - scf_color = '#1f77b4' # palette[0] - cps_color = '#9467bd' # palette[1] - scf_median_color = scf_color - cps_median_color = cps_color - scf_mean_color = scf_color - cps_mean_color = cps_color - - def safe_log(x): - sign = np.sign(x) - log_x = np.log10(np.maximum(np.abs(x), 1e-10)) - return sign * log_x - - scf_log = safe_log(scf_data["networth"]) - cps_log = safe_log(imputed_data["networth"]) - - scf_log_median, cps_log_median = np.median(scf_log), np.median(cps_log) - scf_log_mean, cps_log_mean = np.mean(scf_log), np.mean(cps_log) - - fig = go.Figure() - - # histograms - fig.add_trace(go.Histogram( - x=scf_log, - nbinsx=150, - opacity=0.7, - name="SCF (weighted) log net worth", - marker_color=scf_color, - )) - fig.add_trace(go.Histogram( - x=cps_log, - nbinsx=150, - opacity=0.7, - name="CPS imputed log net worth", - marker_color=cps_color, - )) - - # medians (dashed) - fig.add_trace(go.Scatter( - x=[scf_log_median, scf_log_median], - y=[0, 20], - mode="lines", - line=dict(color=scf_median_color, width=2, dash="dash"), - name=f"SCF median: ${10**scf_log_median:,.0f}", - )) - fig.add_trace(go.Scatter( - x=[cps_log_median, cps_log_median], - y=[0, 20], - mode="lines", - line=dict(color=cps_median_color, width=2, dash="dash"), - name=f"CPS median: ${10**cps_log_median:,.0f}", - )) - - # means (dotted) - fig.add_trace(go.Scatter( - x=[scf_log_mean, scf_log_mean], - y=[0, 20], - mode="lines", - line=dict(color=scf_mean_color, width=2, dash="dot"), - name=f"SCF mean: ${10**scf_log_mean:,.0f}", - )) - fig.add_trace(go.Scatter( - x=[cps_log_mean, cps_log_mean], - y=[0, 20], - mode="lines", - line=dict(color=cps_mean_color, width=2, dash="dot"), - name=f"CPS mean: 
${10**cps_log_mean:,.0f}",
-    ))
-
-    # layout
-    fig.update_layout(
-        title=title if title else "Log-transformed net worth distribution comparison",
-        xaxis_title="Net worth",
-        yaxis_title="Frequency",
-        height=PLOT_CONFIG["height"],
-        width=PLOT_CONFIG["width"],
-        barmode="overlay",
-        bargap=0.1,
-        legend=dict(
-            x=0.01, y=0.99,
-            bgcolor="rgba(255,255,255,0.8)",
-            bordercolor="rgba(0,0,0,0.3)",
-            borderwidth=1,
-            orientation="v",
-            xanchor="left",
-            yanchor="top",
-        ),
-    )
-
-    # custom ticks
-    tick_values = [-6, -4, -2, 0, 2, 4, 6, 8]
-    tick_labels = [
-        "$" + format(10**x if x >= 0 else -(10**abs(x)), ",.0f")
-        for x in tick_values
-    ]
-    fig.update_xaxes(tickvals=tick_values, ticktext=tick_labels)
-
-    return fig
-
-weights_col = cps_data["household_weight"].values
-weights_normalized = weights_col / weights_col.sum()
-imputed_data_weighted = imputed_data.sample(
-    n=len(imputed_data),
-    replace=True,
-    weights=weights_normalized,
-).reset_index(drop=True)
-
-
-plot_log_transformed_net_worth_distributions(scf_data_weighted, imputed_data_weighted).show()
-```
-
-![png](./autoimpute_best_model_imputations.png)
-
-The logarithmic transformation provides a clearer view of the wealth distribution across its entire range: it compresses the extreme values while making differences in the lower and middle portions of the distribution easier to see.
-
-This transformation is particularly valuable for wealth data, where values can span many orders of magnitude. The plot above shows how closely the imputed CPS wealth distribution matches the original SCF distribution in shape and central tendency after imputation with the QRF model. The vertical lines marking the mean and median values help gauge how well these statistical properties have been preserved through the imputation process.
- -## Comparing with other methods - -```python -donor_data = scf_data[predictors + imputed_variables + weights] -receiver_data = cps_data[predictors + household_weights] - -donor_data, normalizing_params = preprocess_data(donor_data[predictors + imputed_variables], normalize=True, full_data=True) -donor_data[weights[0]] = scf_data[weights[0]] -receiver_data, _ = preprocess_data(receiver_data[predictors], normalize=True, full_data=True) -receiver_data["household_weight"] = cps_data["household_weight"] -receiver_data["household_net_income"] = cps_data["household_net_income"] - -mean = pd.Series( - {col: p["mean"] for col, p in normalizing_params.items()} -) -std = pd.Series( - {col: p["std"] for col, p in normalizing_params.items()} -) - -from microimpute.models import * - -def impute_scf_to_cps(model: Type[Imputer], - donor_data: pd.DataFrame, - receiver_data: pd.DataFrame, - cps_data: pd.DataFrame, - predictors: List[str], - imputed_variables: List[str], - weights: List[str]) -> pd.DataFrame: - """Impute SCF data into CPS data using the specified model.""" - model = model() - fitted_model = model.fit( - X_train=donor_data, - predictors=predictors, - imputed_variables=imputed_variables, - weight_col=weights[0], - ) - imputations = fitted_model.predict(X_test=receiver_data) - - cps_imputed = cps_data.copy() - cps_imputed["networth"] = imputations[0.5]["networth"] - - return cps_imputed - -quantreg_cps_imputed = impute_scf_to_cps( - model=QuantReg, - donor_data=donor_data, - receiver_data=receiver_data, - cps_data=cps_data, - predictors=predictors, - imputed_variables=imputed_variables, - weights=weights, -) - -quantreg_cps_imputed["networth"] = quantreg_cps_imputed["networth"].mul(std["networth"]) -quantreg_cps_imputed["networth"] = quantreg_cps_imputed["networth"].add(mean["networth"]) - -weights_col = receiver_data["household_weight"].values -weights_normalized = weights_col / weights_col.sum() -quantreg_cps_imputed_weighted = quantreg_cps_imputed.sample( - 
n=len(quantreg_cps_imputed), - replace=True, - weights=weights_normalized, -).reset_index(drop=True) - - -ols_cps_imputed = impute_scf_to_cps( - model=OLS, - donor_data=donor_data, - receiver_data=receiver_data, - cps_data=cps_data, - predictors=predictors, - imputed_variables=imputed_variables, - weights=weights, -) - -ols_cps_imputed["networth"] = ols_cps_imputed["networth"].mul(std["networth"]) -ols_cps_imputed["networth"] = ols_cps_imputed["networth"].add(mean["networth"]) - -ols_cps_imputed_weighted = ols_cps_imputed.sample( - n=len(ols_cps_imputed), - replace=True, - weights=weights_normalized, -).reset_index(drop=True) - -matching_cps_imputed = impute_scf_to_cps( - model=Matching, - donor_data=donor_data, - receiver_data=receiver_data, - cps_data=cps_data, - predictors=predictors, - imputed_variables=imputed_variables, - weights=weights, -) - -matching_cps_imputed["networth"] = matching_cps_imputed["networth"].mul(std["networth"]) -matching_cps_imputed["networth"] = matching_cps_imputed["networth"].add(mean["networth"]) - -matching_cps_imputed_weighted = matching_cps_imputed.sample( - n=len(matching_cps_imputed), - replace=True, - weights=weights_normalized, -).reset_index(drop=True) - - -qrf_model = QRF() -fitted_model, best_params = qrf_model.fit( - X_train=donor_data, - predictors=predictors, - imputed_variables=imputed_variables, - weight_col=weights[0], - tune_hyperparameters=True, -) -imputations = fitted_model.predict(X_test=receiver_data) - -qrf_cps_imputed = cps_data.copy() -qrf_cps_imputed["networth"] = imputations[0.5]["networth"] - -qrf_cps_imputed["networth"] = qrf_cps_imputed["networth"].mul(std["networth"]) -qrf_cps_imputed["networth"] = qrf_cps_imputed["networth"].add(mean["networth"]) - -qrf_cps_imputed_weighted = qrf_cps_imputed.sample( - n=len(qrf_cps_imputed), - replace=True, - weights=weights_normalized, -).reset_index(drop=True) - - -def plot_all_models_net_worth_log_distributions( - scf_data: pd.DataFrame, - model_results: dict, - 
title: Optional[str] = None, -) -> go.Figure: - """Plot log-transformed net worth distributions for all models in a 2x2 grid. - - Args: - scf_data: Original SCF data with networth column - model_results: Dictionary mapping model names to their imputed dataframes - title: Optional title for the entire figure - - Returns: - Plotly figure with 4 subplots - """ - from plotly.subplots import make_subplots - import plotly.graph_objects as go - - # Create subplots - fig = make_subplots( - rows=2, cols=2, - subplot_titles=list(model_results.keys()), - vertical_spacing=0.15, - horizontal_spacing=0.12, - ) - - # Define safe log transformation - def safe_log(x): - sign = np.sign(x) - log_x = np.log10(np.maximum(np.abs(x), 1e-10)) - return sign * log_x - - # Calculate SCF log values once - scf_log = safe_log(scf_data["networth"]) - scf_log_median = np.median(scf_log) - scf_log_mean = np.mean(scf_log) - - # Define colors - scf_color = '#1f77b4' - palette = px.colors.qualitative.Plotly - model_colors = palette[:4] - - # Plot positions - positions = [(1, 1), (1, 2), (2, 1), (2, 2)] - - for idx, (model_name, imputed_data) in enumerate(model_results.items()): - row, col = positions[idx] - model_color = model_colors[idx] - - # Calculate model log values - model_log = safe_log(imputed_data["networth"]) - model_log_median = np.median(model_log) - model_log_mean = np.mean(model_log) - - # Add SCF histogram (grey/transparent) - fig.add_trace( - go.Histogram( - x=scf_log, - nbinsx=150, - opacity=0.3, - name=f"SCF (weighted by sampling)", - marker_color='grey', - showlegend=(idx == 0), # Only show in legend once - ), - row=row, col=col - ) - - # Add model histogram - fig.add_trace( - go.Histogram( - x=model_log, - nbinsx=150, - opacity=0.7, - name=f"{model_name.replace(' imputations', '')}", - marker_color=model_color, - showlegend=True, - ), - row=row, col=col - ) - - # Get y-axis range for vertical lines - fig.update_yaxes(range=[0, 2000], row=row, col=col) - - # Add median lines - 
fig.add_trace( - go.Scatter( - x=[scf_log_median, scf_log_median], - y=[0, 2000], - mode="lines", - line=dict(color='grey', width=2, dash="dash"), - name=f"SCF Median", - showlegend=False, - ), - row=row, col=col - ) - - fig.add_trace( - go.Scatter( - x=[model_log_median, model_log_median], - y=[0, 2000], - mode="lines", - line=dict(color=model_color, width=2, dash="dash"), - name=f"{model_name} Median", - showlegend=False, - ), - row=row, col=col - ) - - # Determine correct axis references for annotations - if idx == 0: - xref, yref = "x", "y" - elif idx == 1: - xref, yref = "x2", "y2" - elif idx == 2: - xref, yref = "x3", "y3" - else: - xref, yref = "x4", "y4" - - # Add text annotations for statistics - fig.add_annotation( - x=0, # Position on the x-axis (log scale) - y=5000, # Position on the y-axis - xref=xref, - yref=yref, - text=f"Median: ${10**model_log_median:,.0f}
Mean: ${10**model_log_mean:,.0f}", - showarrow=False, - bgcolor="rgba(255,255,255,0.8)", - bordercolor="rgba(0,0,0,0.3)", - borderwidth=1, - font=dict(size=10), - xanchor="right", - yanchor="top", - ) - - # Update layout - fig.update_layout( - title=title if title else "Log-transformed net worth distributions by model", - height=800, - width=1000, - showlegend=True, - legend=dict( - x=0.5, - y=-0.2, - xanchor="center", - yanchor="top", - orientation="h", - ), - barmode="overlay", - ) - - # Update x-axes with custom tick labels - tick_values = [-6, -4, -2, 0, 2, 4, 6, 8] - tick_labels = [ - "$" + format(10**x if x >= 0 else -(10**abs(x)), ",.0f") - for x in tick_values - ] - - for i in range(1, 5): - fig.update_xaxes( - tickvals=tick_values, - ticktext=tick_labels, - title_text="Net Worth (log scale)" if i > 2 else "", - row=(i-1)//2 + 1, - col=(i-1)%2 + 1 - ) - fig.update_yaxes( - title_text="Frequency" if i % 2 == 1 else "", - row=(i-1)//2 + 1, - col=(i-1)%2 + 1 - ) - - return fig - - -# Create dictionary of model results -model_results = { - "QRF imputations": qrf_cps_imputed_weighted, - "OLS imputations": ols_cps_imputed_weighted, - "QuantReg imputations": quantreg_cps_imputed_weighted, - "Matching imputations": matching_cps_imputed_weighted, -} - -# Create and show the combined plot -combined_fig = plot_all_models_net_worth_log_distributions(scf_data_weighted, model_results) -combined_fig.show() -``` -![png](./imputations_model_comparison.png) - -Comparing the wealth distributions that result from imputing from the SCF on to the CPS with four different models, we can visually recognize the different strengths and limitations of each of them. The implications of using one model instead of another for imputation will be further explored by evaluating the impact wealth imputed data has on microsimulation results. 
-
-## Wealth distributions by disposable income deciles
-
-Lastly, to check that our wealth imputations are consistent with other household characteristics, we can compare the average net worth per disposable income decile for each of the four methods used.
-
-```python
-income_col = "household_net_income"
-wealth_col = "networth"  # imputed wealth column
-
-decile_means = []
-
-for model, df in model_results.items():
-    tmp = df.copy()
-
-    # Create 1–10 decile indicator (unweighted)
-    tmp["income_decile"] = (
-        pd.qcut(tmp[income_col], 10, labels=False) + 1
-    )
-
-    # Mean wealth in each decile
-    out = (
-        tmp.groupby("income_decile")[wealth_col]
-        .mean()
-        .reset_index(name="mean_wealth")
-    )
-    out["Model"] = model
-    decile_means.append(out)
-
-avg_df = pd.concat(decile_means, ignore_index=True)
-avg_df["income_decile"] = avg_df["income_decile"].astype(int)
-
-fig = px.bar(
-    avg_df,
-    x="income_decile",
-    y="mean_wealth",
-    color="Model",
-    barmode="group",
-    labels={
-        "income_decile": "Net-income decile (1 = lowest, 10 = highest)",
-        "mean_wealth": "Average household net worth ($)",
-    },
-    title=(
-        "Average household net worth by net-income decile<br>"
-        "Comparison of imputation models"
-    ),
-)
-
-fig.update_layout(
-    xaxis=dict(dtick=1, tick0=1),
-    paper_bgcolor="#F0F0F0",
-    plot_bgcolor="#F0F0F0",
-    yaxis_tickformat="$,.0f",
-    hovermode="x unified",
-)
-
-fig.update_xaxes(showgrid=False)
-fig.update_yaxes(showgrid=False)
-
-fig.show()
-```
-
-![png](./by_income_decile_comparisons.png)
-
-QRF clearly presents the most consistent and plausible relationship to disposable income, with average net worth rising steadily across the deciles. This plot also corroborates the behavior observed above: the extreme negative and positive predictions that OLS and QuantReg produce at the left and right tails, respectively, and Matching's underprediction at every decile.
diff --git a/microimpute/comparisons/metrics.py b/microimpute/comparisons/metrics.py
index c963962..29336ec 100644
--- a/microimpute/comparisons/metrics.py
+++ b/microimpute/comparisons/metrics.py
@@ -175,19 +175,25 @@ def order_probabilities_alphabetically(
     return reordered_probabilities, alphabetical_classes
 
 
+# Accept numpy arrays, pandas Series, and pandas ExtensionArrays
+# (e.g. ArrowExtensionArray) which newer pandas versions produce for
+# string columns.
+ArrayLike = Union[np.ndarray, pd.Series, pd.api.extensions.ExtensionArray]
+
+
 @validate_call(config=VALIDATE_CONFIG)
 def compute_loss(
-    test_y: np.ndarray,
-    imputations: np.ndarray,
+    test_y: ArrayLike,
+    imputations: ArrayLike,
     metric: MetricType,
     q: float = 0.5,
-    labels: Optional[np.ndarray] = None,
+    labels: Optional[ArrayLike] = None,
 ) -> Tuple[np.ndarray, float]:
     """Compute loss for given true values and imputations using specified metric.
 
     Args:
-        test_y: Array of true values.
-        imputations: Array of predicted/imputed values.
+        test_y: Array-like of true values.
+        imputations: Array-like of predicted/imputed values.
         metric: Type of metric to use ('quantile_loss' or 'log_loss').
         q: Quantile value (only used for quantile_loss).
         labels: Possible label values (only used for log_loss).
@@ -198,6 +204,13 @@ def compute_loss( Raises: ValueError: If inputs have different shapes or invalid metric type. """ + # Convert array-like inputs (e.g. ArrowStringArray, pd.Series) + # to numpy before use. + test_y = np.asarray(test_y) + imputations = np.asarray(imputations) + if labels is not None: + labels = np.asarray(labels) + try: # Validate input dimensions if len(test_y) != len(imputations): @@ -291,9 +304,9 @@ def _compute_method_losses( log.error(error_msg) raise ValueError(error_msg) - # Get values - test_values = test_y[variable].values - pred_values = imputation[quantile][variable].values + # Get values as numpy arrays (handles Arrow-backed dtypes) + test_values = np.asarray(test_y[variable]) + pred_values = np.asarray(imputation[quantile][variable]) # Compute loss _, mean_loss = compute_loss( @@ -327,9 +340,9 @@ def _compute_method_losses( log.error(error_msg) raise ValueError(error_msg) - # Get values - test_values = test_y[variable].values - pred_values = imputation[quantile][variable].values + # Get values as numpy arrays (handles Arrow-backed dtypes) + test_values = np.asarray(test_y[variable]) + pred_values = np.asarray(imputation[quantile][variable]) # Get unique labels from test data labels = np.unique(test_values) diff --git a/microimpute/config.py b/microimpute/config.py index c2b0eb7..ba5896a 100644 --- a/microimpute/config.py +++ b/microimpute/config.py @@ -55,7 +55,7 @@ "ols": { # statsmodels OLS uses default parameters # LogisticRegression params for categorical targets: - "penalty": "l2", + "l1_ratio": 0, "C": 1.0, "max_iter": 1000, }, diff --git a/microimpute/evaluations/cross_validation.py b/microimpute/evaluations/cross_validation.py index 664adb8..05f7a19 100644 --- a/microimpute/evaluations/cross_validation.py +++ b/microimpute/evaluations/cross_validation.py @@ -265,11 +265,13 @@ def _compute_fold_loss_by_metric( for var in imputed_variables: metric_type = variable_metrics[var] - # Get data for this variable - test_y_var = 
test_y_values[var][fold_idx] - train_y_var = train_y_values[var][fold_idx] - test_pred_var = test_results[quantile][fold_idx][var].values - train_pred_var = train_results[quantile][fold_idx][var].values + # Get data for this variable, converting to numpy to handle + # Arrow-backed dtypes (e.g. ArrowStringArray) that pydantic + # won't accept as np.ndarray. + test_y_var = np.asarray(test_y_values[var][fold_idx]) + train_y_var = np.asarray(train_y_values[var][fold_idx]) + test_pred_var = np.asarray(test_results[quantile][fold_idx][var]) + train_pred_var = np.asarray(train_results[quantile][fold_idx][var]) # Compute loss based on metric type if metric_type == "quantile_loss": diff --git a/microimpute/models/mdn.py b/microimpute/models/mdn.py index fe78e68..1649ecb 100644 --- a/microimpute/models/mdn.py +++ b/microimpute/models/mdn.py @@ -64,6 +64,19 @@ message=".*have no logger configured.*", module="pytorch_lightning.core.module", ) + warnings.filterwarnings( + "ignore", + message=".*isinstance.*LeafSpec.*is deprecated.*", + ) + warnings.filterwarnings( + "ignore", + message=".*given NumPy array is not writable.*", + ) + warnings.filterwarnings( + "ignore", + message=".*Trying to unpickle estimator.*", + module="sklearn.*", + ) # After import, also update the rank_zero_module logger from lightning_fabric.utilities.rank_zero import rank_zero_module @@ -148,6 +161,25 @@ def _generate_data_hash(X: pd.DataFrame, y: pd.Series) -> str: return hashlib.md5(combined.encode()).hexdigest()[:12] +def _get_package_versions_hash() -> str: + """Get a short hash of sklearn and pytorch_tabular versions. + + Including package versions in the cache key ensures that caches + saved with older versions are automatically invalidated when + packages are upgraded, avoiding deserialization warnings. 
+ """ + import sklearn + + versions = [sklearn.__version__] + try: + import pytorch_tabular + + versions.append(pytorch_tabular.__version__) + except (ImportError, AttributeError): + versions.append("unknown") + return hashlib.md5("_".join(versions).encode()).hexdigest()[:6] + + def _generate_cache_key( predictors: List[str], target: str, data_hash: str ) -> str: @@ -168,7 +200,9 @@ def _generate_cache_key( predictors_hash = hashlib.md5(predictors_str.encode()).hexdigest()[:8] # Sanitize target name for filesystem safe_target = target.replace("/", "_").replace("\\", "_") - return f"{predictors_hash}_{safe_target}_{data_hash}" + # Include package versions to invalidate cache on upgrades + ver_hash = _get_package_versions_hash() + return f"{predictors_hash}_{safe_target}_{data_hash}_{ver_hash}" class _MDNModel: diff --git a/microimpute/models/ols.py b/microimpute/models/ols.py index e0ed4ee..f2fb6e8 100644 --- a/microimpute/models/ols.py +++ b/microimpute/models/ols.py @@ -59,8 +59,9 @@ def fit( y_encoded = y_encoded.fillna(0) # Default to first category # Extract relevant LR parameters from kwargs + # Use l1_ratio instead of penalty (deprecated in sklearn 1.8) classifier_params = { - "penalty": lr_kwargs.get("penalty", "l2"), + "l1_ratio": lr_kwargs.get("l1_ratio", 0), "C": lr_kwargs.get("C", 1.0), "max_iter": lr_kwargs.get("max_iter", 1000), "solver": lr_kwargs.get( diff --git a/microimpute/utils/dashboard_formatter.py b/microimpute/utils/dashboard_formatter.py index fe8dbb6..5d56309 100644 --- a/microimpute/utils/dashboard_formatter.py +++ b/microimpute/utils/dashboard_formatter.py @@ -700,10 +700,8 @@ def format_csv( # Generate histogram data for each imputed variable for var in imputed_variables: # Check if variable is categorical or numerical - if donor_data[ - var - ].dtype == "object" or pd.api.types.is_categorical_dtype( - donor_data[var] + if pd.api.types.is_string_dtype(donor_data[var]) or isinstance( + donor_data[var].dtype, pd.CategoricalDtype ): # 
Categorical variable hist_data = _compute_categorical_distribution( diff --git a/microimpute/visualizations/comparison_plots.py b/microimpute/visualizations/comparison_plots.py index 4e1b85a..53248cf 100644 --- a/microimpute/visualizations/comparison_plots.py +++ b/microimpute/visualizations/comparison_plots.py @@ -200,11 +200,19 @@ def _process_dual_metrics_input( ): std_results = ql_data["results_std"].loc["test"] + # Loss values in results are already averaged + # across variables, so create ONE row per + # (method, quantile) to avoid duplicate bars. + variables = ql_data.get("variables", ["y"]) + var_label = ( + variables[0] if len(variables) == 1 else "average" + ) + for quantile in test_results.index: - for var in ql_data.get("variables", ["y"]): - row = { + long_format_data.append( + { "Method": method_name, - "Imputed Variable": var, + "Imputed Variable": var_label, "Percentile": quantile, "Loss": test_results[quantile], "Metric": "quantile_loss", @@ -214,21 +222,20 @@ def _process_dual_metrics_input( else np.nan ), } - long_format_data.append(row) + ) # Add mean loss if "mean_test" in ql_data: - for var in ql_data.get("variables", ["y"]): - long_format_data.append( - { - "Method": method_name, - "Imputed Variable": var, - "Percentile": "mean_quantile_loss", - "Loss": ql_data["mean_test"], - "Metric": "quantile_loss", - "Std": ql_data.get("std_test", np.nan), - } - ) + long_format_data.append( + { + "Method": method_name, + "Imputed Variable": var_label, + "Percentile": "mean_quantile_loss", + "Loss": ql_data["mean_test"], + "Metric": "quantile_loss", + "Std": ql_data.get("std_test", np.nan), + } + ) # Process log loss if available if ( @@ -240,35 +247,41 @@ def _process_dual_metrics_input( ll_data.get("results") is not None and not ll_data["results"].empty ): - # Log loss is constant across quantiles + # Log loss is constant across quantiles. + # Same as quantile_loss: one row per method, + # not per variable, since values are aggregated. 
if "test" in ll_data["results"].index: test_loss = ll_data["results"].loc["test"].mean() test_std = ll_data.get("std_test", np.nan) - for var in ll_data.get("variables", []): - long_format_data.append( - { - "Method": method_name, - "Imputed Variable": var, - "Percentile": "log_loss", - "Loss": test_loss, - "Metric": "log_loss", - "Std": test_std, - } - ) + ll_variables = ll_data.get("variables", []) + ll_var_label = ( + ll_variables[0] + if len(ll_variables) == 1 + else "average" + ) + long_format_data.append( + { + "Method": method_name, + "Imputed Variable": ll_var_label, + "Percentile": "log_loss", + "Loss": test_loss, + "Metric": "log_loss", + "Std": test_std, + } + ) # Add mean loss if "mean_test" in ll_data: - for var in ll_data.get("variables", []): - long_format_data.append( - { - "Method": method_name, - "Imputed Variable": var, - "Percentile": "mean_log_loss", - "Loss": ll_data["mean_test"], - "Metric": "log_loss", - "Std": ll_data.get("std_test", np.nan), - } - ) + long_format_data.append( + { + "Method": method_name, + "Imputed Variable": ll_var_label, + "Percentile": "mean_log_loss", + "Loss": ll_data["mean_test"], + "Metric": "log_loss", + "Std": ll_data.get("std_test", np.nan), + } + ) self.comparison_data = pd.DataFrame(long_format_data) @@ -805,7 +818,7 @@ def _plot_stacked_contribution( ) if title is None: - title = "Rank-based mmodel performance by variable (lower is better)" + title = "Rank-based model performance by variable (lower is better)" fig.update_layout( title=title, diff --git a/microimpute/visualizations/performance_plots.py b/microimpute/visualizations/performance_plots.py index 8e9de4a..6c88f70 100644 --- a/microimpute/visualizations/performance_plots.py +++ b/microimpute/visualizations/performance_plots.py @@ -67,6 +67,14 @@ def __init__( # Handle different input formats if isinstance(results, pd.DataFrame): # Backward compatibility: single metric DataFrame + # Note: passing a bare DataFrame loses results_std, so error + # bars 
will not be shown. Pass the full cross_validate_model + # output dict to preserve error bar data. + logger.warning( + "Received a bare DataFrame instead of the full results dict. " + "Error bars will not be available. Pass the full output of " + "cross_validate_model() to enable error bars." + ) self.results = {"quantile_loss": {"results": results.copy()}} self.has_quantile_loss = True self.has_log_loss = False diff --git a/paper/main.pdf b/paper/main.pdf index e4042bc..6eb6047 100644 Binary files a/paper/main.pdf and b/paper/main.pdf differ diff --git a/paper/sections/abstract.tex b/paper/sections/abstract.tex index dd447ee..4dcb08e 100644 --- a/paper/sections/abstract.tex +++ b/paper/sections/abstract.tex @@ -1,3 +1,3 @@ \section*{Abstract} -Microdata surveys often lack variables critical for policy analysis, requiring imputation from richer donor surveys, yet researchers have lacked standardized tools for systematically comparing imputation methods and selecting the best approach for a given dataset. We introduce $\texttt{microimpute}$, an open-source Python package that provides a unified framework for benchmarking, tuning, and automatically selecting among multiple imputation methods. We apply it to impute U.S. household wealth from the Survey of Consumer Finances (SCF) onto the Current Population Survey (CPS). Evaluating five methods---Quantile Random Forests (QRF), Ordinary Least Square (OLS), Quantile Regression, Hot Deck Matching, and Mixture Density Networks, we find that QRF achieves the lowest quantile loss and Wasserstein distance, reflecting its capacity to model the non-linear relationships between demographic predictors and wealth. A simulation of the Supplemental Security Income Savings Penalty Elimination Act demonstrates the practical stakes of method choice for downstream analysis. Without imputed wealth, the microsimulation overestimates SSI recipients by 167\%, and the SSI policy reform would be entirely un-simulable. 
Cross-dataset benchmarking across six additional domains reveals, however, that no single method dominates universally. Although QRF proves best at capturing complex, non-linear relationships, Hot Deck Matching better preserves marginal distributions by drawing directly from the donor pool, while OLS remains competitive for approximately linear relationships. These findings underscore that imputation method selection should be empirically driven rather than prescribed, motivating standardized benchmarking tools, like $\texttt{microimpute}$, that enable researchers to identify the best approach for their specific data characteristics and research objectives. \ No newline at end of file +Microdata surveys often lack variables needed for policy analysis, requiring imputation from richer donor surveys, yet no standardized tools exist for comparing imputation methods and selecting the best approach for a given dataset. We introduce $\texttt{microimpute}$, an open-source Python package for benchmarking, tuning, and automatically selecting among multiple imputation methods. We apply it to impute U.S. household wealth from the Survey of Consumer Finances (SCF) onto the Current Population Survey (CPS). Evaluating five methods (Quantile Random Forests, Ordinary Least Squares, Quantile Regression, Hot Deck Matching, and Mixture Density Networks) through 5-fold cross-validation, we find that QRF reduces average quantile loss by 76\% relative to the worst-performing method and by 7\% relative to the next-best method (Hot Deck Matching), owing to its ability to model non-linear relationships between demographic predictors and wealth. Its Wasserstein distance to the donor distribution is 24\% lower than Matching's. A simulation of the Supplemental Security Income Savings Penalty Elimination Act shows what is at stake: without imputed wealth, the microsimulation overestimates SSI recipients by 167\%, and the reform is entirely un-simulable. 
Cross-dataset benchmarking across six additional domains, however, reveals that no single method dominates universally. QRF performs best when relationships are non-linear, Hot Deck Matching better preserves marginal distributions by drawing directly from the donor pool, and OLS remains competitive when relationships are approximately linear. Imputation method selection should therefore be empirically driven rather than prescribed, which motivates tools like $\texttt{microimpute}$ for systematic comparison. \ No newline at end of file diff --git a/paper/sections/appendix_benchmarking.tex b/paper/sections/appendix_benchmarking.tex index 3f23738..d546fae 100644 --- a/paper/sections/appendix_benchmarking.tex +++ b/paper/sections/appendix_benchmarking.tex @@ -73,4 +73,4 @@ \subsection{Results} Matching achieves the best Wasserstein distance on four of six datasets and the lowest mean rank overall (1.67), followed closely by QRF (1.83). QRF performs best on space\_ga and brazilian\_houses, the latter of which exhibits complex nonlinear relationships between predictors and house prices. OLS and Quantile Regression occupy middle ranks, while MDN consistently ranks last, likely reflecting a combination of the relatively small sample sizes in these datasets and the hyperparameter sensitivity of neural network-based approaches. -With only six datasets, the rank differences between methods---particularly the narrow gap between Matching (1.67) and QRF (1.83)---should be interpreted cautiously, as they are not statistically robust to the inclusion or exclusion of individual datasets. Notably, the relative rankings here differ from the main SCF-CPS analysis, where QRF achieves the lowest quantile loss. 
This is consistent with the expectation that method performance depends on dataset characteristics, and reinforces the value of \texttt{microimpute}'s method comparison framework, which allows researchers to empirically determine the best approach for their specific data rather than relying on a single method. +With only six datasets, the rank differences between methods, particularly the narrow gap between Matching (1.67) and QRF (1.83), should be interpreted cautiously; they are not statistically robust to the inclusion or exclusion of individual datasets. The relative rankings here also differ from the main SCF-CPS analysis, where QRF achieves the lowest quantile loss. This is consistent with the expectation that method performance depends on dataset characteristics, and it supports the case for empirically comparing methods on each new dataset rather than relying on a single approach. diff --git a/paper/sections/appendix_robustness.tex b/paper/sections/appendix_robustness.tex index f143f0e..fb840e0 100644 --- a/paper/sections/appendix_robustness.tex +++ b/paper/sections/appendix_robustness.tex @@ -54,4 +54,4 @@ \subsection{Progressive inclusion analysis} \subsection{Implications for the Conditional Independence Assumption} -The strong predictive power of the shared financial variables, particularly interest and dividend income, employment income, and pension income, is consistent with, though not sufficient evidence for, the plausibility of the CIA in the SCF-CPS matching context. These variables capture key dimensions of household financial position that are available in both surveys. The fact that they collectively explain a large share of the variation in net worth suggests that the common variables $X$ may be sufficient to render the target variable $Y$ (wealth) approximately conditionally independent of variables $Z$ unique to the CPS, though this assumption remains fundamentally untestable as discussed in Section~2.2. 
+The strong predictive power of the shared financial variables, particularly interest and dividend income, employment income, and pension income, is consistent with, though not sufficient evidence for, the plausibility of the CIA in the SCF-CPS matching context. These variables capture the main dimensions of household financial position available in both surveys. The fact that they collectively explain a large share of the variation in net worth suggests that the common variables $X$ may be sufficient to render the target variable $Y$ (wealth) approximately conditionally independent of variables $Z$ unique to the CPS, though this assumption remains fundamentally untestable as discussed in Section~2.2. diff --git a/paper/sections/background.tex b/paper/sections/background.tex index 73b77fd..ef7cbe0 100644 --- a/paper/sections/background.tex +++ b/paper/sections/background.tex @@ -1,6 +1,6 @@ \section{Background} -This section establishes the theoretical foundations for statistical matching and reviews the imputation methods implemented in \texttt{microimpute}. We begin with the formal problem definition and the key assumption underlying all statistical matching procedures, then discuss the relationship to missing data mechanisms, and finally describe each of the five imputation methods. +This section defines the statistical matching problem and describes the imputation methods in \texttt{microimpute}. We begin with the formal problem definition and the assumption underlying all statistical matching procedures, then discuss the relationship to missing data mechanisms, and finally describe each of the five imputation methods. \subsection{The Statistical Matching Problem} @@ -56,7 +56,7 @@ \subsection{Imputation Methods} \subsubsection{Hot Deck Matching} -Hot deck imputation replaces missing values in a receiver record with observed values from ``similar'' donor records \citep{andridge2010review}. 
The \texttt{microimpute} implementation uses the unconstrained distance hot deck approach from the StatMatch R package \citep{dorazio2021statistical}. To identify the best match, the method computes distances between each receiver observation and all donor observations based on the common variables $X$, selects donor records within a specified distance threshold, and randomly samples from the eligible donors, with optional weighting by survey weights. For continuous and mixed covariates, Mahalanobis or Gower distance metrics can capture similarity across different variable types. The donated value will thus be an actual observed value from the donor file, ensuring plausibility. +Hot deck imputation replaces missing values in a receiver record with observed values from ``similar'' donor records \citep{andridge2010review}. The \texttt{microimpute} implementation uses the unconstrained distance hot deck approach from the StatMatch R package \citep{dorazio2021statistical}. To identify the best match, the method computes distances between each receiver observation and all donor observations based on the common variables $X$, selects donor records within a specified distance threshold, and randomly samples from the eligible donors, with optional weighting by survey weights. For continuous and mixed covariates, Mahalanobis or Gower distance metrics can capture similarity across different variable types. The donated value is therefore always an actual observed value from the donor file. Hot deck methods are nonparametric and avoid distributional assumptions, making them robust when the true conditional distribution is unknown or complex \citep{dorazio2006statistical}. However, they face limitations: \begin{itemize} @@ -186,18 +186,18 @@ \subsection{Current Practice in Microsimulation} Statistical matching and data fusion are fundamental operations in microsimulation modeling, yet the methods employed in practice have remained relatively unchanged for decades. 
A review of major tax-benefit microsimulation models reveals a strong reliance on traditional imputation approaches, primarily hot deck matching and OLS-based regression. -Hot deck matching remains the dominant approach in European microsimulation. EUROMOD, the EU-wide tax-benefit model, employs a multi-stage imputation procedure combining predictive mean matching with distance-based hot deck methods to integrate consumption data from Household Budget Surveys into its EU-SILC input data \citep{sutherland2013euromod}. The appeal of hot deck methods lies in their simplicity and the guarantee that imputed values are observed values from the donor file, ensuring plausibility. However, as discussed above, these methods struggle with tail behavior and may not adequately capture the full conditional distribution of the target variable. +Hot deck matching remains the dominant approach in European microsimulation. EUROMOD, the EU-wide tax-benefit model, employs a multi-stage imputation procedure combining predictive mean matching with distance-based hot deck methods to integrate consumption data from Household Budget Surveys into its EU-SILC input data \citep{sutherland2013euromod}. Hot deck methods are simple to implement and guarantee that imputed values are actual observed values from the donor file. However, as discussed above, these methods struggle with tail behavior and may not adequately capture the full conditional distribution of the target variable. -In U.S. tax policy microsimulation, regression-based approaches are more common. The Tax Policy Center employs a two-stage probit and OLS procedure for wealth imputation, first predicting the probability of holding each asset type, then predicting amounts conditional on positive holdings \citep{nunns2012tax}. 
Similarly, the Institute on Taxation and Economic Policy (ITEP) model relies on statistical matching between tax return data and the American Community Survey, supplemented with regression-based imputations from the Survey of Consumer Finances \citep{itep2023model}. These regression approaches assume linear relationships and normally distributed errors; assumptions that are frequently violated by economic variables with heavy tails and heteroscedastic relationships. +In U.S. tax policy microsimulation, regression-based approaches are more common. The Tax Policy Center employs a two-stage probit and OLS procedure for wealth imputation, first predicting the probability of holding each asset type, then predicting amounts conditional on positive holdings \citep{nunns2012tax}. Similarly, the Institute on Taxation and Economic Policy (ITEP) model relies on statistical matching between tax return data and the American Community Survey, supplemented with regression-based imputations from the Survey of Consumer Finances \citep{itep2023model}. These regression approaches assume linear relationships and normally distributed errors, assumptions frequently violated by economic variables with heavy tails and heteroscedasticity. -More recent microsimulation efforts have begun incorporating administrative data through record linkage rather than statistical matching \citep{abowd2019census}, but this approach requires access to restricted data and raises privacy concerns. For researchers working with publicly available survey data, statistical matching remains essential, creating a need for methods that can better capture complex distributional features. +More recent microsimulation efforts have begun incorporating administrative data through record linkage rather than statistical matching \citep{abowd2019census}, but this approach requires access to restricted data and raises privacy concerns. 
For researchers working with publicly available survey data, statistical matching remains the primary option, and there is room for methods that capture distributional features more effectively. -The limitations of traditional approaches can be particularly acute for heavily-tailed variables where relationships with predictors vary across the distribution. Machine learning methods such as Quantile Random Forests offer a promising alternative, as they can capture nonlinear relationships, model the entire conditional distribution, and handle heavy-tailed data without restrictive parametric assumptions. Despite these advantages, QRF and similar methods have seen limited adoption in mainstream microsimulation practice. The \texttt{microimpute} package aims to lower barriers to adopting these more flexible methods by providing a unified framework for comparing traditional and machine learning approaches, enabling researchers to empirically evaluate which method best suits their specific data characteristics. +These limitations are especially pronounced for heavy-tailed variables where the relationship with predictors varies across the distribution. Machine learning methods such as Quantile Random Forests can capture nonlinear relationships, model the entire conditional distribution, and handle heavy-tailed data without restrictive parametric assumptions, yet they have seen limited adoption in microsimulation practice. The \texttt{microimpute} package provides a common framework for comparing traditional and machine learning approaches, so that researchers can empirically evaluate which method works best for their data. \subsection{Supplemental Security Income as a Validation Case} -The limitations of traditional imputation approaches in microsimulation have concrete consequences on downstream policy analysis. The impact of wealth imputation quality for U.S. households can be illustrated through the Supplemental Security Income (SSI) program. 
SSI is a federal means-tested program for aged, blind, and disabled individuals with strict resource limits (\$2,000 for individuals, \$3,000 for couples) that have remained unchanged since 1989 \citep{ssa2023ssi}. The resource test requires household wealth and asset data that the CPS does not collect, making SSI eligibility determination dependent on imputed or heuristic wealth values. +The limitations of traditional imputation approaches have concrete consequences for downstream policy analysis. The Supplemental Security Income (SSI) program illustrates this well. SSI is a federal means-tested program for aged, blind, and disabled individuals with strict resource limits (\$2,000 for individuals, \$3,000 for couples) that have remained unchanged since 1989 \citep{ssa2023ssi}. The resource test requires household wealth and asset data that the CPS does not collect, making SSI eligibility determination dependent on imputed or heuristic wealth values. In 2024, SSI served approximately 7.4 million recipients with \$59.6 billion in federal payments \citep{ssa2024ssi}. Because different wealth imputation methods produce different wealth distributions, they directly affect which simulated households pass the resource test and therefore the resulting estimates of SSI recipient counts and expenditure. This makes SSI a particularly informative case for comparing imputation approaches, as described in Section~\ref{sec:policy_validation}. -The frozen resource limits also make SSI a compelling case for policy reform analysis. In 2024 dollars, the original 1989 thresholds of \$2,000/\$3,000 are worth approximately \$4,800/\$7,200, meaning the resource test has become substantially more restrictive over time without any legislative change. The SSI Savings Penalty Elimination Act (S.~1234 / H.R.~2540), a bipartisan bill introduced in April 2025, proposes raising these limits to \$10,000 for individuals and \$20,000 for couples, with future CPI indexing \citep{ssi_spea_2025}. 
Simulating the impact of this reform, and determining which additional households would gain eligibility under the higher thresholds requires household-level wealth data, making it a natural application for cross-survey imputation. As we demonstrate in Section~\ref{sec:policy_validation}, the choice of imputation method produces meaningfully different estimates of this reform's impact, underscoring that imputation is not merely a data preprocessing step but a consequential modeling choice for policy analysis. +The frozen resource limits also make SSI a compelling case for policy reform analysis. In 2024 dollars, the original 1989 thresholds of \$2,000/\$3,000 are worth approximately \$4,800/\$7,200, meaning the resource test has become substantially more restrictive over time without any legislative change. The SSI Savings Penalty Elimination Act (S.~1234 / H.R.~2540), a bipartisan bill introduced in April 2025, proposes raising these limits to \$10,000 for individuals and \$20,000 for couples, with future CPI indexing \citep{ssi_spea_2025}. Simulating the impact of this reform, and determining which additional households would gain eligibility under the higher thresholds, requires household-level wealth data, making it a natural application for cross-survey imputation. As we show in Section~\ref{sec:policy_validation}, the choice of imputation method produces meaningfully different estimates of this reform's impact, meaning that imputation is not merely a data preprocessing step but a modeling choice with real consequences for policy analysis. diff --git a/paper/sections/conclusion.tex b/paper/sections/conclusion.tex index 3d9a412..362be7b 100644 --- a/paper/sections/conclusion.tex +++ b/paper/sections/conclusion.tex @@ -1,7 +1,7 @@ \section{Conclusion} -This paper has introduced $\texttt{microimpute}$, an open-source Python package for cross-survey imputation, and applied it to a substantive policy problem: imputing household wealth from the SCF onto the CPS. 
Through systematic comparison of five imputation methods using an inverse hyperbolic sine transformation to accommodate negative and zero net worth values, we find that Quantile Random Forests achieve the strongest conditional accuracy among the five methods evaluated for wealth imputation, driven by their capacity to model the complex non-linear relationships between demographic predictors and wealth. A downstream SSI simulation validates the practical importance of imputation method choice: without wealth data, the microsimulation overestimates SSI recipients by 167\%, while imputation-based approaches substantially narrow this gap. Simulating the SSI Savings Penalty Elimination Act, which would raise resource limits from \$2,000/\$3,000 to \$10,000/\$20,000, further demonstrates that method choice directly affects policy estimates. Reform impact estimates range from 0.1 to 4.2 million additional recipients and \$1.0 to \$24.1 billion in additional expenditure depending on the imputation method. Without imputed wealth, this reform is entirely un-simulable. +This paper has introduced $\texttt{microimpute}$, an open-source Python package for cross-survey imputation, and applied it to a substantive policy problem: imputing household wealth from the SCF onto the CPS. Comparing five imputation methods with an inverse hyperbolic sine transformation to accommodate negative and zero net worth values, we find that Quantile Random Forests achieve the strongest conditional accuracy for wealth imputation, owing to their ability to model the non-linear relationships between demographic predictors and wealth. A downstream SSI simulation confirms the practical importance of method choice. Without wealth data, the microsimulation overestimates SSI recipients by 167\%, while imputation-based approaches substantially narrow this gap. 
Simulating the SSI Savings Penalty Elimination Act, which would raise resource limits from \$2,000/\$3,000 to \$10,000/\$20,000, further shows that method choice directly affects policy estimates. Reform impact estimates range from 0.1 to 4.2 million additional recipients and \$1.0 to \$24.1 billion in additional expenditure depending on the imputation method. Without imputed wealth, this reform cannot be simulated at all. -Our findings also reveal important nuances about method selection. Cross-dataset benchmarking shows that QRF's advantage reflects the specific characteristics of the wealth imputation problem rather than a universal superiority, and predictor importance analysis confirms that financial variables drive the majority of imputation accuracy. +The findings also reveal nuances about method selection. Cross-dataset benchmarking shows that QRF's advantage reflects the specific characteristics of the wealth imputation problem rather than a universal superiority, and predictor importance analysis confirms that financial variables drive the majority of imputation accuracy. -These results carry two broader implications. First, imputation method selection should be empirically driven rather than prescribed; the optimal approach depends on dataset characteristics, and standardized comparison tools are essential for enabling this data-driven selection. Second, the gap between statistical and policy-relevant evaluation metrics---QRF achieves the best quantile loss yet not the closest SSI estimates to administrative totals---highlights the importance of validating imputation quality through downstream applications, not only through statistical benchmarks. 
Future work should expand $\texttt{microimpute}$'s method library to include gradient boosting and deep learning approaches \citep{alaa2024deep}, explore ensemble strategies that combine methods across different parts of the distribution, and address the terminal node sparsity challenge that limits tree-based methods at extreme quantiles. \ No newline at end of file +Two broader implications follow. First, imputation method selection should be empirically driven rather than prescribed; the optimal approach depends on dataset characteristics, and standardized comparison tools are needed to support this. Second, the gap between statistical and policy-relevant evaluation metrics (QRF achieves the best quantile loss yet not the closest SSI estimates to administrative totals) points to the importance of validating imputation quality through downstream applications, not only through statistical benchmarks. Future work should expand $\texttt{microimpute}$'s method library to include gradient boosting and deep learning approaches \citep{alaa2024deep}, explore ensemble strategies that combine methods across different parts of the distribution, and address the terminal node sparsity problem that limits tree-based methods at extreme quantiles. \ No newline at end of file diff --git a/paper/sections/data.tex b/paper/sections/data.tex index de0f564..d51902a 100644 --- a/paper/sections/data.tex +++ b/paper/sections/data.tex @@ -4,16 +4,16 @@ \section{Data}\label{sec:data} \subsection{Survey of Consumer Finances} -The Survey of Consumer Finances, sponsored by the Federal Reserve Board, is a triennial survey providing detailed information on U.S. households' assets, liabilities, income, and demographic characteristics. Its dual-frame sample design includes a standard national area-probability sample and a list sample deliberately oversampling wealthy households to better capture the skewed wealth distribution \citep{barcelo2006imputation}. 
The SCF is a benchmark for wealth imputation research due to its detailed financial data and the known complexities arising from its design and the nature of wealth. Item nonresponse in public-use SCF datasets is addressed by the Federal Reserve through a multiple imputation approach that generates five complete datasets with different imputed values, using sequential regression-based procedures that incorporate range constraints, logical data structures, and empirical residuals to preserve the complex multivariate relationships inherent in wealth data \citep{kennickell1998multiple}. +The Survey of Consumer Finances, sponsored by the Federal Reserve Board, is a triennial survey providing detailed information on U.S. households' assets, liabilities, income, and demographic characteristics. Its dual-frame sample design includes a standard national area-probability sample and a list sample that deliberately oversamples wealthy households to better capture the skewed wealth distribution \citep{barcelo2006imputation}. The SCF is a standard reference for wealth imputation research because of its detailed financial data and the known complexities arising from its design and the nature of wealth. Item nonresponse in public-use SCF datasets is addressed by the Federal Reserve through a multiple imputation approach that generates five complete datasets with different imputed values, using sequential regression-based procedures that incorporate range constraints, logical data structures, and empirical residuals to preserve the complex multivariate relationships inherent in wealth data \citep{kennickell1998multiple}. Specifically, we use the 2022 summarized SCF as our donor dataset. \subsection{Current Population Survey} -The Current Population Survey, conducted jointly by the U.S. Census Bureau and the Bureau of Labor Statistics, is a monthly survey of approximately 60,000 U.S. households that serves as the primary source of labor force statistics for the United States. 
The CPS uses a multistage probability-based sample designed to represent the civilian non-institutional population; on average, each sampled household represents approximately 2,500 households in the population. The Annual Social and Economic Supplement (ASEC) extends the core survey with detailed annual income data, including earnings, unemployment compensation, Social Security, pension income, interest, dividends, and other income sources. However, despite its comprehensive coverage of income and employment, the CPS does not collect information on household wealth, assets, or liabilities, which are key variables needed for wealth-based policy analysis. This omission motivates the need for statistical matching to transfer wealth information from the SCF onto the CPS. +The Current Population Survey, conducted jointly by the U.S. Census Bureau and the Bureau of Labor Statistics, is a monthly survey of approximately 60,000 U.S. households and is the primary source of labor force statistics for the United States. The CPS uses a multistage probability-based sample designed to represent the civilian non-institutional population; on average, each sampled household represents approximately 2,500 households in the population. The Annual Social and Economic Supplement (ASEC) extends the core survey with detailed annual income data, including earnings, unemployment compensation, Social Security, pension income, interest, dividends, and other income sources. However, despite its broad coverage of income and employment, the CPS does not collect information on household wealth, assets, or liabilities, which are needed for wealth-based policy analysis. This omission motivates the need for statistical matching to transfer wealth information from the SCF onto the CPS. Specifically, we use the Enhanced CPS 2024 produced by PolicyEngine \citep{policyengine_us} as our receiver dataset. 
The Enhanced CPS extends the base CPS ASEC microdata through survey reweighting and calibration to administrative totals, improving its representativeness for tax-benefit microsimulation. The SCF 2022 financial variables are uprated to 2024 dollars using CPI adjustment factors to ensure consistency with the receiver dataset's reference year. \subsection{Comparative analysis and characteristics for imputation} -The SCF and CPS differ fundamentally in sampling design: the SCF employs a dual-frame approach combining a standard area-probability sample with a list sample that deliberately oversamples wealthy households, while the CPS uses a multistage area-probability household design representative of the civilian non-institutional population. This difference means that the donor and receiver surveys weight different parts of the wealth distribution differently, making proper survey weight integration during model training essential for unbiased imputation. The two surveys share a core set of demographic and income variables---age, sex, race, number of children, employment income, interest and dividend income, and pension income---but differ in their detailed variable coverage: the SCF collects granular asset and liability data absent from the CPS, while the CPS captures labor force dynamics and program participation variables not available in the SCF. This partial overlap constrains the conditioning set available for imputation and underscores the importance of the conditional independence assumption discussed in Section~\ref{sec:cia}. Together, the contrasting designs, the heavy-tailed nature of wealth, and the variable overlap constraints make the SCF-to-CPS imputation a particularly informative test case for comparing methods, while also being directly relevant to U.S. policy microsimulation, where researchers require comprehensive microdata combining income, demographics, and wealth to analyze the distributional effects of tax and benefit reforms. 
\ No newline at end of file +The SCF and CPS differ fundamentally in sampling design: the SCF employs a dual-frame approach combining a standard area-probability sample with a list sample that deliberately oversamples wealthy households, while the CPS uses a multistage area-probability household design representative of the civilian non-institutional population. This difference means that the donor and receiver surveys weight different parts of the wealth distribution differently, making proper survey weight integration during model training essential for unbiased imputation. The two surveys share a core set of demographic and income variables (age, sex, race, number of children, employment income, interest and dividend income, and pension income) but differ in detailed coverage: the SCF collects granular asset and liability data absent from the CPS, while the CPS captures labor force dynamics and program participation variables not available in the SCF. This partial overlap constrains the conditioning set available for imputation and reinforces the importance of the conditional independence assumption discussed in Section~\ref{sec:cia}. The contrasting survey designs, the heavy-tailed nature of wealth, and the limited variable overlap make the SCF-to-CPS imputation an informative test case for comparing methods. It is also directly relevant to U.S. policy microsimulation, where analyzing the distributional effects of tax and benefit reforms requires microdata that combine income, demographics, and wealth. \ No newline at end of file diff --git a/paper/sections/discussion.tex b/paper/sections/discussion.tex index 6e9627a..3600d45 100644 --- a/paper/sections/discussion.tex +++ b/paper/sections/discussion.tex @@ -1,56 +1,42 @@ \section{Discussion} -This paper has evaluated five imputation methods within the \texttt{microimpute} framework, finding that Quantile Random Forests offer meaningful advantages for the SCF-to-CPS wealth imputation application. 
By preserving the full conditional distribution of wealth, QRF maintains the key statistical properties of wealth data that traditional methods struggle to capture. +This paper has evaluated five imputation methods within the \texttt{microimpute} framework, finding that Quantile Random Forests offer meaningful advantages for the SCF-to-CPS wealth imputation application. By preserving the full conditional distribution of wealth, QRF captures statistical properties of wealth data that traditional methods miss. -\subsection{Strengths powered by $\texttt{microimpute}$} +\subsection{The role of \texttt{microimpute} in the analysis} -The $\texttt{microimpute}$ package's design philosophy and implementation choices provide several key advantages that contributed to the success of our wealth imputation analysis, and will extend to future model comparisons across applications: +Several design choices in \texttt{microimpute} shaped the analysis and its results. Wrapping all five methods behind a consistent API meant that performance differences in the benchmarking reflect the methods themselves, not differences in how they were implemented or tuned \citep{policyengine2025microimpute}. The \texttt{autoimpute} function automated hyperparameter tuning and method selection based on quantile loss, removing a subjective step that could otherwise bias the comparison. Using quantile loss rather than RMSE as the evaluation metric was itself consequential: its asymmetric penalty structure prioritizes accuracy at distribution tails, where symmetric metrics are insensitive to the kinds of errors that matter most for skewed variables like wealth. -\begin{enumerate} - \item \textbf{Unified interface for method comparison}: $\texttt{microimpute}$'s consistent API across all imputation methods enabled systematic benchmarking without implementation-specific biases. 
This standardization ensures that performance differences reflect genuine methodological advantages rather than implementation artifacts \citep{policyengine2025microimpute}. - - \item \textbf{Automated method selection}: The package's \texttt{autoimpute} function streamlines the imputation workflow by automatically comparing methods and selecting the best performer based on quantile loss metrics. This feature proved particularly valuable given the wealth data's complexity, as it removed subjective method selection and ensured optimal performance. - - \item \textbf{Survey weight integration}: $\texttt{microimpute}$'s native support for survey weights through stratified sampling ensures that imputation models properly represent population distributions. This capability is crucial when transferring information between surveys with different sampling designs, such as the SCF's oversampling of wealthy households. - - \item \textbf{Quantile-aware evaluation}: By implementing quantile loss as the primary evaluation metric, $\texttt{microimpute}$ directly addresses the challenges of skewed distributions. This metric's asymmetric penalty structure naturally prioritizes accurate imputation at distribution tails, where traditional metrics like Root Mean Squared Error (RMSE) often fail. - - \item \textbf{Computational efficiency}: The package's optimized implementation enables processing of large microdata files while maintaining reasonable computation times. Cross-validation on the full SCF dataset, including QRF hyperparameter tuning, was completed in under 30 minutes on standard hardware, making iterative experimentation feasible. - - \item \textbf{Open-source transparency}: As an open-source tool, $\texttt{microimpute}$ allows full inspection and modification of imputation algorithms, promoting reproducibility and enabling custom extensions for specific research needs \citep{policyengine2025microimpute}. -\end{enumerate} +Two practical aspects also mattered. 
First, native survey weight support through stratified sampling allowed the models to train on data representative of the U.S. population despite the SCF's deliberate oversampling of wealthy households. Second, the full cross-validation pipeline, including hyperparameter tuning on the SCF for the QRF, MDN, and Matching models, ran in under 45 minutes on standard hardware, making iterative experimentation practical. The package is open-source, so all algorithms and evaluation procedures can be inspected and modified \citep{policyengine2025microimpute}. \subsection{Generalizability and predictor sensitivity} -The main SCF-CPS results are complemented by two supplementary analyses that inform the broader applicability of these findings. +Two supplementary analyses help contextualize the main SCF-CPS results.
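The quantile (pinball) loss that drives the evaluation above has a standard closed form; a minimal sketch follows (the function name and example values are illustrative, not the package's API):

```python
def pinball_loss(y_true, y_pred, tau):
    """Quantile (pinball) loss at quantile level tau.

    Under-predictions are weighted by tau and over-predictions by (1 - tau),
    which is the asymmetric penalty structure discussed above: at tau = 0.9,
    missing low by some amount costs nine times more than missing high by
    the same amount.
    """
    u = y_true - y_pred
    return tau * u if u >= 0 else (tau - 1) * u

# Asymmetry at the upper tail (tau = 0.9):
under = pinball_loss(100.0, 80.0, 0.9)   # predicted too low
over = pinball_loss(100.0, 120.0, 0.9)   # predicted too high; much cheaper
```

Averaging this loss over held-out observations and a grid of quantile levels yields a scalar score per method, which is how a cross-validated comparison by quantile loss can proceed.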
Households may receive plausible-looking wealth values that do not correctly reflect their individual characteristics. Thus, understanding the trade-offs between marginal and conditional accuracy, and evaluating how differences in survey design affect record matching becomes very important. +The cross-dataset benchmarking exercise (Appendix~\ref{app:benchmarking}) shows that method performance is context-dependent. QRF achieves the best Wasserstein distance on datasets with complex non-linear relationships, such as the Brazilian houses dataset and the SCF wealth imputation. Matching, by contrast, achieves the lowest Wasserstein distance on four of six benchmarking datasets and the best mean rank overall (1.67 vs.\ QRF's 1.83). This strong marginal performance follows from Matching's mechanism: drawing imputed values directly from the donor pool inherently preserves the shape and tail behavior of the original distribution, yielding low Wasserstein distances even when the conditioning on predictor values is suboptimal. But as the main SCF-CPS results show, good marginal distributions do not guarantee accurate conditional imputation. Households may receive plausible-looking wealth values that do not correctly reflect their individual characteristics. The trade-off between marginal and conditional accuracy, and the effect of differing survey designs on record matching, deserve careful attention in any application. -Meanwhile, OLS performs competitively on datasets where the predictor-target relationship is approximately linear (e.g., abalone, space\_ga), highlighting that simpler parametric methods remain effective when their structural assumptions are met. MDN consistently ranks last across the benchmarking datasets, likely reflecting a combination of the relatively small sample sizes in these datasets and the hyperparameter sensitivity of neural network-based approaches, which require larger training sets to realize their flexibility advantage. 
+OLS performs competitively on datasets where the predictor-target relationship is approximately linear (e.g., abalone, space\_ga), confirming that simpler parametric methods remain effective when their structural assumptions hold. MDN consistently ranks last across the benchmarking datasets, likely reflecting a combination of the relatively small sample sizes in these datasets and the hyperparameter sensitivity of neural network-based approaches, which require larger training sets to realize their flexibility advantage. -These findings reinforce the value of \texttt{microimpute}'s automated method comparison framework: no single method dominates across all settings, and the optimal choice depends on dataset size, the complexity of the predictor-target relationship, and whether the research objective prioritizes conditional accuracy or marginal distributional fidelity. Researchers should leverage the package's comparison tools to empirically determine the best approach for their specific data rather than defaulting to any single method. +No single method dominates across all settings. The best choice depends on dataset size, the complexity of the predictor-target relationship, and whether the research question calls for conditional accuracy or marginal distributional fidelity. This is precisely why automated comparison matters: researchers should test methods on their own data rather than defaulting to any one approach. -The predictor importance analysis (Appendix~\ref{app:robustness}) reveals a clear hierarchy among the variables shared between the SCF and CPS. Financial predictors, particularly interest and dividend income, employment income, and age, account for the vast majority of imputation accuracy, with interest and dividend income alone more than doubling the quantile loss when removed. Demographic variables such as race and gender contribute relatively little to imputation quality. 
This concentration of predictive power in financial variables has two practical implications. First, the quality of wealth imputation rests primarily on the availability and accuracy of income-related variables in both surveys; researchers lacking financial predictors would see substantially degraded results. Second, diminishing returns from additional demographic predictors suggest that a parsimonious set of well-chosen financial variables can achieve most of the attainable accuracy, as confirmed by the progressive inclusion analysis showing that the first three predictors capture the majority of achievable performance. Nonetheless, the inclusion of demographic variables still provides incremental improvements, and their contribution to capturing nuanced wealth variation should not be overlooked. +The predictor importance analysis (Appendix~\ref{app:robustness}) reveals a clear hierarchy among the variables shared between the SCF and CPS. Financial predictors, particularly interest and dividend income, employment income, and age, account for the vast majority of imputation accuracy, with interest and dividend income alone more than doubling the quantile loss when removed. Demographic variables such as race and gender contribute relatively little to imputation quality. This concentration of predictive power in financial variables has two practical implications. First, the quality of wealth imputation rests primarily on the availability and accuracy of income-related variables in both surveys; researchers lacking financial predictors would see substantially degraded results. Second, diminishing returns from additional demographic predictors suggest that a parsimonious set of well-chosen financial variables can achieve most of the attainable accuracy, as confirmed by the progressive inclusion analysis showing that the first three predictors capture the majority of achievable performance. 
That said, demographic variables still provide incremental improvements and should not be excluded without reason. -\texttt{microimpute}'s support for predictor importance analysis and progressive inclusion experiments allows researchers to systematically assess which variables are driving imputation performance in their specific context, and to make informed decisions about variable selection when faced with data limitations. +The predictor importance and progressive inclusion tools in \texttt{microimpute} let researchers assess which variables drive imputation performance in their specific context and make informed decisions about variable selection when data are limited. -Together, these supplementary analyses contextualize the main results: QRF's strong performance in the SCF-CPS setting reflects the particular characteristics of the wealth imputation problem---a large donor dataset, highly non-linear relationships, and informative financial predictors available in both surveys. Researchers applying \texttt{microimpute} to other settings should expect method rankings to vary with dataset characteristics and should leverage the package's comparison tools accordingly. +Together, these supplementary analyses put the main results in perspective. QRF's strong performance in the SCF-CPS setting reflects the particular characteristics of the wealth imputation problem, namely a large donor dataset, highly non-linear relationships, and informative financial predictors available in both surveys. Researchers applying \texttt{microimpute} to other settings should expect method rankings to vary with dataset characteristics. \subsection{Reform analysis and policy uncertainty} -The SSI Savings Penalty Elimination Act simulation demonstrates that imputation is not merely a data preprocessing step but a consequential modeling choice for policy analysis. 
The variation in estimated reform impact across the five methods represents practical policy uncertainty: policymakers relying on different imputation approaches would reach different conclusions about the budgetary and coverage effects of raising SSI resource limits. This finding extends beyond SSI to any means-tested program where eligibility depends on wealth or asset data not collected in the primary survey. -The difference-in-differences approach, by comparing current-law and reform outcomes within each method, partially mitigates the limitation that SCF net worth serves as a proxy for SSI countable resources. Since the systematic bias affects both scenarios similarly, the reform impact estimates are less sensitive to this proxy than the level estimates. Nonetheless, future work could improve these estimates by imputing asset-specific components (e.g., excluding primary residence equity) rather than total net worth, and by extending the analysis to demographic subgroups (aged vs.\ disabled recipients) to better characterize the reform's distributional effects. +By comparing current-law and reform outcomes within each method, the microsimulation approach, holding all else equal, partially mitigates the limitation that SCF net worth is only a proxy for SSI countable resources.
Since the systematic bias affects both scenarios similarly, the reform impact estimates are less sensitive to this proxy than the level estimates. Nonetheless, future work could improve these estimates by imputing asset-specific components (e.g., excluding primary residence equity) rather than total net worth, and by extending the analysis to demographic subgroups (aged vs.\ disabled recipients) to better characterize the reform's distributional effects. -A practical path to resolving the variation in aggregate estimates across imputation methods is survey weight calibration. Rather than choosing a method based on proximity to administrative totals, which conflates imputation accuracy with policy function artifacts, researchers can select the method with the best conditional accuracy (here, QRF) for record-level imputation and then calibrate survey weights to match known administrative aggregates such as SSI recipient counts and total expenditure to ensure national-level representativeness \citep{woodruff2024enhancing}. This two-stage approach achieves the best of both worlds: accurate household-level wealth distributions that preserve the conditional relationships needed for reform simulation, combined with aggregate estimates that align with administrative benchmarks. +A practical path to resolving the variation in aggregate estimates across imputation methods is survey weight calibration. Rather than choosing a method based on proximity to administrative totals, which conflates imputation accuracy with policy function artifacts, researchers can select the method with the best conditional accuracy (here, QRF) for record-level imputation and then calibrate survey weights to match known administrative aggregates such as SSI recipient counts and total expenditure to ensure national-level representativeness \citep{woodruff2024enhancing}. 
This two-stage approach combines accurate household-level wealth distributions that preserve the conditional relationships needed for reform simulation with aggregate estimates that align with administrative benchmarks. \subsection{Limitations and future improvements}\label{sec:limitations} -Despite the demonstrated advantages, there remain limitations worth considering for future development: - \subsubsection{Current package limitations} -While $\texttt{microimpute}$ currently supports five imputation methods, expanding to include additional modern machine learning approaches such as more complex neural networks, gradient boosting machines, or deep learning architectures could further improve performance, particularly for complex multivariate relationships \citep{alaa2024deep}. The package would benefit from implementing ensemble methods that combine multiple imputation approaches, potentially leveraging the strengths of different methods across different parts of the distribution. Moreover, model selection and assessment could be enhanced with evaluation metrics additional to quantile loss, ensuring a thorough understanding of each model's behavior at every step. +The current method library of five approaches could be expanded with gradient boosting machines, deeper neural architectures, or other modern machine learning methods, which may improve performance for complex multivariate relationships \citep{alaa2024deep}. Ensemble methods that combine multiple imputation approaches across different parts of the distribution are another natural extension. Model selection and assessment could also benefit from evaluation metrics beyond quantile loss, giving a more complete picture of each model's behavior. 
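The weight-calibration step described in the two-stage approach above can be sketched as a single-constraint ratio adjustment (a deliberate simplification of production raking, which iterates over several constraints; all names and totals here are illustrative):

```python
def calibrate_to_total(weights, is_recipient, admin_total):
    """Rescale the weights of flagged records (e.g., imputed SSI recipients)
    so that their weighted count matches an administrative total, leaving
    all other weights unchanged. Full calibration would satisfy several
    such constraints simultaneously rather than one.
    """
    flagged_sum = sum(w for w, f in zip(weights, is_recipient) if f)
    factor = admin_total / flagged_sum
    return [w * factor if f else w for w, f in zip(weights, is_recipient)]

# Hypothetical household weights; the flagged weighted count is 2700,
# so matching an administrative total of 5400 doubles those weights.
weights = [1200.0, 800.0, 1500.0, 500.0]
is_recipient = [True, False, True, False]
calibrated = calibrate_to_total(weights, is_recipient, admin_total=5400.0)
```

Record-level imputed values are untouched by this step; only the survey weights change, which is why conditional accuracy and aggregate alignment can be addressed separately.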
\subsubsection{QRF-specific challenges} @@ -58,6 +44,6 @@ \subsubsection{QRF-specific challenges} \subsubsection{Missing data assumptions} -All imputation methods in this study operate under the assumption that wealth is missing at random (MAR) conditional on the shared predictors $X$, as formalized in Section~\ref{sec:cia}. This assumption may be violated in practice: wealthy individuals may be less likely to participate in surveys or may underreport assets, introducing missing-not-at-random (MNAR) patterns that the conditioning variables cannot fully account for \citep{andridge2010review}. The predictor exclusion analysis in Appendix~\ref{app:robustness} provides a partial sensitivity check on the plausibility of the MAR assumption. The substantial degradation in imputation quality when financial predictors are removed, particularly the 116\% increase in quantile loss when interest and dividend income is excluded, demonstrates that the conditioning set is doing meaningful work in explaining wealth variation, which is consistent with MAR plausibility. However, this analysis cannot rule out residual dependence on unobserved factors. More formal MNAR sensitivity analyses, such as pattern-mixture models or selection models applied to the imputation framework, represent an important direction for future work. Researchers applying \texttt{microimpute} to settings where nonresponse or selection effects are suspected should consider supplementary diagnostics to assess the sensitivity of their results to departures from MAR. +All imputation methods in this study operate under the assumption that wealth is missing at random (MAR) conditional on the shared predictors $X$, as formalized in Section~\ref{sec:cia}. 
This assumption may be violated in practice: wealthy individuals may be less likely to participate in surveys or may underreport assets, introducing missing-not-at-random (MNAR) patterns that the conditioning variables cannot fully account for \citep{andridge2010review}. The predictor exclusion analysis in Appendix~\ref{app:robustness} provides a partial sensitivity check on the plausibility of the MAR assumption. The substantial degradation in imputation quality when financial predictors are removed, particularly the 116\% increase in quantile loss when interest and dividend income is excluded, indicates that the conditioning set is doing meaningful work in explaining wealth variation, consistent with MAR plausibility. However, this analysis cannot rule out residual dependence on unobserved factors. More formal MNAR sensitivity analyses, such as pattern-mixture models or selection models applied to the imputation framework, are a natural next step. Researchers applying \texttt{microimpute} to settings where nonresponse or selection effects are suspected should consider supplementary diagnostics to assess the sensitivity of their results to departures from MAR. -These enhancements would position $\texttt{microimpute}$ further as a comprehensive solution for survey data imputation challenges while maintaining its current strengths in ease of use and methodological rigor. \ No newline at end of file +These extensions would broaden the range of problems \texttt{microimpute} can address while preserving its current simplicity. 
\ No newline at end of file diff --git a/paper/sections/introduction.tex b/paper/sections/introduction.tex index 8228cbc..864fe41 100644 --- a/paper/sections/introduction.tex +++ b/paper/sections/introduction.tex @@ -1,19 +1,19 @@ \section{Introduction} -Statistical matching, also known as data fusion or full variable imputation, addresses the challenge of combining information from multiple data sources that share common variables but contain different samples of units \citep{dorazio2006statistical}, which is commonly found across fields in empirical research. In its simplest form, the problem involves a donor dataset containing both shared and unique variables and a receiver dataset containing only the shared variables, where the goal is to impute the missing variables for observations in the receiver file based on their shared characteristics \citep{dorazio2021statistical}. This framework generalizes traditional missing data imputation by recognizing that ``missingness'' can arise not only from nonresponse within a single survey but also from the structural absence of variables across distinct data sources. +Statistical matching, also known as data fusion or full variable imputation, combines information from multiple data sources that share common variables but contain different samples of units \citep{dorazio2006statistical}, a situation common across empirical research fields. In its simplest form, the problem involves a donor dataset containing both shared and unique variables and a receiver dataset containing only the shared variables, where the goal is to impute the missing variables for observations in the receiver file based on their shared characteristics \citep{dorazio2021statistical}. This framework generalizes traditional missing data imputation by recognizing that ``missingness'' can arise not only from nonresponse within a single survey but also from the structural absence of variables across distinct data sources. 
-The relevance of statistical matching extends across the social sciences, particularly in microsimulation modeling where policy analysis requires combining detailed demographic information from one survey with economic variables from another \citep{bourguignon2006microsimulation}. For instance, household surveys may capture income and consumption patterns but lack wealth data, while financial surveys provide wealth information for different samples. Constructing synthetic files that combine these variables enables richer policy analysis than either source alone permits. +Statistical matching is widely applicable in the social sciences, particularly in microsimulation modeling, where policy analysis often requires combining detailed demographic information from one survey with economic variables from another \citep{bourguignon2006microsimulation}. For instance, household surveys may capture income and consumption patterns but lack wealth data, while financial surveys provide wealth information for different samples. Constructing synthetic files that combine these variables allows more detailed policy analysis than either source alone. -The central methodological challenge is accurately modeling the conditional distribution of the target variable given shared predictors, as different approaches make different assumptions about this relationship. Parametric methods may be too restrictive for economic variables exhibiting heavy tails and heteroscedasticity, while nonparametric and machine learning approaches offer greater flexibility at the cost of additional complexity \citep{meinshausen2006quantile, andridge2010review}. Section~2 reviews these trade-offs in detail. +The main methodological difficulty is accurately modeling the conditional distribution of the target variable given shared predictors, as different approaches make different assumptions about this relationship. 
Parametric methods may be too restrictive for economic variables with heavy tails and heteroscedasticity, while nonparametric and machine learning approaches are more flexible but also more complex \citep{meinshausen2006quantile, andridge2010review}. Section~2 reviews these trade-offs in detail. -This paper presents the \texttt{microimpute} package, a Python library implementing five statistical matching methods: Hot Deck Matching, Ordinary Least Squares, Quantile Regression, Mixture Density Networks, and Quantile Random Forests. The package provides a comprehensive framework for comparing these methods through systematic benchmarking, with automated model selection based on quantile loss performance. We demonstrate the methodology through an application to wealth imputation, transferring net worth data from the Survey of Consumer Finances to the Current Population Survey. This task exemplifies the challenges of statistical matching for heavy-tailed distributions. +This paper presents the \texttt{microimpute} package, a Python library implementing five statistical matching methods: Hot Deck Matching, Ordinary Least Squares, Quantile Regression, Mixture Density Networks, and Quantile Random Forests. The package provides a framework for comparing these methods through systematic benchmarking, with automated model selection based on quantile loss performance. We demonstrate the methodology through an application to wealth imputation, transferring net worth data from the Survey of Consumer Finances to the Current Population Survey. This task exemplifies the challenges of statistical matching for heavy-tailed distributions. 
This work provides the following contributions: \begin{itemize} - \item An open-source implementation of a unified framework for five imputation methods, facilitating systematic comparison across methods and support statistical matching through accessible, user-friendly software - \item An automated benchmarking pipeline that evaluates distributional accuracy using quantile loss, enabling researchers to select the most appropriate method for their specific data characteristics - \item Empirical evidence on the relative performance of these methods for wealth imputation, suggesting that Quantile Random Forests are well-suited to heavy-tailed distributions while highlighting the context-dependence of method choice - \item A policy reform simulation of the SSI Savings Penalty Elimination Act demonstrating that imputation method choice produces meaningfully different estimates of reform impact, and that without imputed wealth data, such reforms are entirely un-simulable + \item An open-source implementation of five imputation methods under a common interface, allowing systematic comparison across methods + \item An automated benchmarking pipeline that evaluates distributional accuracy using quantile loss, so that researchers can select the most appropriate method for their data + \item Empirical evidence on the relative performance of these methods for wealth imputation, suggesting that Quantile Random Forests are well-suited to heavy-tailed distributions, though method choice remains context-dependent + \item A policy reform simulation of the SSI Savings Penalty Elimination Act, showing that imputation method choice produces meaningfully different estimates of reform impact, and that without imputed wealth data, such reforms cannot be simulated at all \end{itemize} The remainder of this paper is organized as follows. 
Section 2 reviews the statistical matching problem, including its formal definition, the conditional independence assumption that underlies all matching methods, and the five imputation approaches implemented in \texttt{microimpute}. Section 3 describes our data sources. Section 4 presents the package architecture and benchmarking methodology. Section 5 reports empirical results from the wealth imputation application. Section 6 discusses implications and limitations, and Section 7 concludes. \ No newline at end of file diff --git a/paper/sections/methodology.tex b/paper/sections/methodology.tex index d41a552..6623906 100644 --- a/paper/sections/methodology.tex +++ b/paper/sections/methodology.tex @@ -2,22 +2,22 @@ \section{Methodology}\label{sec:methodology} \subsection{$\texttt{microimpute}$ package implementation} -$\texttt{microimpute}$\footnote{Complete documentation, implementation details, and usage examples are available at https://policyengine.github.io/microimpute/.} is an open-source Python framework for statistical matching that provides a unified interface for imputing variables across survey datasets using multiple methods. The package addresses a gap in the microsimulation and microdata research community: while several imputation approaches exist, researchers have lacked a standardized tool for systematically comparing their distributional performance on a given dataset and automatically selecting the best approach. In this study, \texttt{microimpute} serves both as a methodological contribution to the broader research community and as the analytical engine powering the wealth imputation experiments described below. +$\texttt{microimpute}$\footnote{Documentation, implementation details, and usage examples are available at https://policyengine.github.io/microimpute/.} is an open-source Python framework for statistical matching that provides a common interface for imputing variables across survey datasets using multiple methods. 
While several imputation approaches exist individually, researchers have lacked a standardized tool for comparing their distributional performance on a given dataset and automatically selecting the best approach. In this study, \texttt{microimpute} is both a methodological contribution and the analytical engine behind the wealth imputation experiments described below. \subsubsection{Core capabilities} -The package currently supports five imputation methods: Hot Deck Matching, Ordinary Least Squares Linear Regression (OLS), Quantile Regression, Quantile Random Forests (QRF), and Mixture Density Networks (MDN). This approach allows researchers to systematically evaluate which technique provides the most accurate results for their specific dataset and research objectives. Additionally, the package is designed to be modular, allowing for easy extension with additional imputation methods in the future. +The package currently supports five imputation methods: Hot Deck Matching, Ordinary Least Squares Linear Regression (OLS), Quantile Regression, Quantile Random Forests (QRF), and Mixture Density Networks (MDN). The package is modular, so additional imputation methods can be added without modifying the existing codebase. \subsubsection{Key features} -The package provides several capabilities designed for the challenges of cross-survey imputation, each of which plays a specific role in the wealth imputation experiments presented in this paper: +Several features are relevant to the cross-survey imputation experiments presented in this paper: \begin{enumerate} - \item \textbf{Survey data weights integration}: Handles survey data weights through sampling to ensure that models are trained on a donor data set representative of the true distribution. - \item \textbf{Method comparison and benchmarking}: Allows researchers to easily compare different approaches and automatically determine the method providing the most accurate results. 
- \item \textbf{Flexible methodological set-up}: Enables advanced usage through specified hyperparameter setting and tuning. - \item \textbf{Quantile-based evaluation}: Uses quantile loss calculations to assess imputation quality across different parts of the distribution. - \item \textbf{Autoimputation}: Provides an integrated imputation pipeline that tunes method hyperparameters to the specific datasets, compares methods, and selects the best-performing to conduct the requested imputation in a single function call. + \item \textbf{Survey weight integration}: Incorporates survey weights through sampling so that models are trained on data representative of the true population distribution. + \item \textbf{Method comparison and benchmarking}: Compares different approaches and automatically identifies the method with the lowest quantile loss. + \item \textbf{Hyperparameter tuning}: Supports user-specified hyperparameters as well as automated tuning. + \item \textbf{Quantile-based evaluation}: Assesses imputation quality across different parts of the distribution using quantile loss. + \item \textbf{Autoimputation}: Tunes hyperparameters, compares methods, and selects the best-performing one to impute onto the receiver dataset, all in a single function call. \end{enumerate} \subsection{Evaluation framework} @@ -29,7 +29,7 @@ \subsection{Evaluation framework} \mathcal{L}_\tau = \rho_\tau(y - \hat{Q}_\tau(y|x)) = (y - \hat{Q}_\tau(y|x)) \cdot (\tau - \mathbf{1}_{y < \hat{Q}_\tau(y|x)}) \end{equation} -Under-prediction of an upper-tail quantile is penalized at rate $\tau$, whereas over-prediction is penalized at rate $1 - \tau$. This asymmetric structure captures directional bias---a crucial feature when imputing highly skewed variables like wealth, where large under-estimates in the right tail must be discouraged more strongly than equal-sized over-estimates \citep{koenker1978regression}. 
Compared with mean-squared or mean-absolute error, which penalize errors symmetrically, quantile loss directly targets the conditional distribution and remains robust to outliers \citep{ghenis2018quantile}. Averaging over multiple quantiles $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ provides a comprehensive measure of how well the method estimates the conditional distribution across its entire range. Lower average quantile loss indicates better distributional calibration. Because quantile loss is not scale-invariant, the absolute values reported in Section~5 reflect the scale transformation applied; we supplement raw values with relative improvement percentages to aid interpretability. +Under-prediction of an upper-tail quantile is penalized at rate $\tau$, whereas over-prediction is penalized at rate $1 - \tau$. This asymmetric structure captures directional bias, which matters when imputing highly skewed variables like wealth, where large under-estimates in the right tail must be discouraged more strongly than equal-sized over-estimates \citep{koenker1978regression}. Compared with mean-squared or mean-absolute error, which penalize errors symmetrically, quantile loss directly targets the conditional distribution and remains robust to outliers \citep{ghenis2018quantile}. Averaging over multiple quantiles $\tau \in \{0.1, 0.2, \ldots, 0.9\}$ summarizes how well the method estimates the conditional distribution across its entire range. Lower average quantile loss indicates better distributional calibration. Because quantile loss is not scale-invariant, the absolute values reported in Section~5 reflect the scale transformation applied; we supplement raw values with relative improvement percentages to aid interpretability. As a complementary metric, we use the Wasserstein distance (also known as the earth mover's distance) to measure overall distributional similarity between imputed and reference distributions. 
Formally, for one-dimensional distributions $P$ and $Q$: \begin{equation} @@ -42,11 +42,11 @@ \subsection{Evaluation framework} \subsection{Experimental setup}\label{sec:experimental_setup} -The \texttt{autoimpute} function orchestrates the full imputation pipeline in a single call: it tunes hyperparameters for applicable methods using the Optuna framework \citep{akiba2019optuna}, evaluates all methods through $k$-fold cross-validation on the donor data, and automatically selects the best-performing one to impute onto the receiver dataset. Using this pipeline, we evaluate all five imputation methods---QRF, OLS, Quantile Regression, Hot Deck Matching, and MDN---for net worth imputation from the SCF to the CPS. +The \texttt{autoimpute} function orchestrates the full imputation pipeline in a single call: it tunes hyperparameters for applicable methods using the Optuna framework \citep{akiba2019optuna}, evaluates all methods through $k$-fold cross-validation on the donor data, and automatically selects the best-performing one to impute onto the receiver dataset. Using this pipeline, we evaluate all five imputation methods (QRF, OLS, Quantile Regression, Hot Deck Matching, and MDN) for net worth imputation from the SCF to the CPS. Prior to model training, we apply an inverse hyperbolic sine (asinh) transformation to the net worth target variable. The asinh function, $\operatorname{asinh}(x) = \ln(x + \sqrt{x^2 + 1})$, approximates the natural logarithm for large positive values while remaining well-defined for zero and negative values. This is particularly important for wealth data, where a meaningful fraction of households report negative net worth (i.e., debts exceeding assets), making a standard logarithmic transformation infeasible. The transformation compresses the extreme right tail of the wealth distribution, stabilizing variance and improving model performance across all methods. 
After imputation, predictions are back-transformed via the inverse function $\sinh(\cdot)$ to recover values on the original dollar scale. -This experimental design implicitly assumes that wealth is missing at random (MAR) conditional on the shared predictors $X$---that is, the probability of observing wealth in the donor survey does not depend on wealth itself once we condition on the demographic and financial variables available in both datasets, as formalized in Section~\ref{sec:cia}. This assumption is standard in statistical matching but may be violated in practice: wealthy households may be less likely to participate in surveys or may underreport assets, creating MNAR-like patterns that the shared predictors cannot fully capture. We return to this limitation and discuss partial sensitivity checks in Section~\ref{sec:limitations}. +This experimental design implicitly assumes that wealth is missing at random (MAR) conditional on the shared predictors $X$, i.e., that the probability of observing wealth in the donor survey does not depend on wealth itself once we condition on the demographic and financial variables available in both datasets, as formalized in Section~\ref{sec:cia}. This assumption is standard in statistical matching but may be violated in practice: wealthy households may be less likely to participate in surveys or may underreport assets, creating MNAR-like patterns that the shared predictors cannot fully capture. We return to this limitation and discuss partial sensitivity checks in Section~\ref{sec:limitations}. To create a ground truth for evaluation, we employ cross-validation, splitting the SCF into 5 folds. 
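The transform-and-back-transform step described above can be sketched with NumPy (the variable names are illustrative, not the package's):

```python
import numpy as np

# Net worth can be zero or negative, so log(x) is infeasible;
# asinh(x) = ln(x + sqrt(x^2 + 1)) is defined everywhere and grows
# like the natural logarithm for large positive x.
net_worth = np.array([-5_000.0, 0.0, 2_500.0, 1_000_000.0])
target = np.arcsinh(net_worth)  # models are trained on this scale

# After imputation, back-transform predictions to the dollar scale.
recovered = np.sinh(target)
assert np.allclose(recovered, net_worth)
```

Because sinh is the exact inverse of asinh, the round trip is lossless up to floating-point precision, including for the negative-net-worth households that motivate the transform.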
For each method, \texttt{autoimpute} performs the following steps: @@ -82,7 +82,7 @@ \subsection{Downstream policy analysis}\label{sec:policy_validation} While distributional metrics such as quantile loss and Wasserstein distance assess how well an imputation method recovers the statistical properties of the target variable, they do not directly capture whether differences between methods matter for downstream policy analysis. To address this, we introduce a policy analysis layer that evaluates imputation quality through its effect on a concrete policy simulation, under both current law and a proposed reform. -We use the Supplemental Security Income (SSI) program as the policy case. SSI is a federal means-tested program that provides cash assistance to aged, blind, and disabled individuals with limited income and resources \citep{ssa2023ssi}. Critically, SSI eligibility includes a resource test with limits of \$2,000 for individuals and \$3,000 for couples, which directly depends on household wealth data that the CPS does not collect. +We use the Supplemental Security Income (SSI) program as the policy case. SSI is a federal means-tested program that provides cash assistance to aged, blind, and disabled individuals with limited income and resources \citep{ssa2023ssi}. SSI eligibility includes a resource test with limits of \$2,000 for individuals and \$3,000 for couples, which depends on household wealth data that the CPS does not collect. In current practice, PolicyEngine's microsimulation model \citep{policyengine_us} handles the missing wealth data by defaulting \texttt{ssi\_countable\_resources} to zero for all CPS households, meaning every categorically eligible individual passes the resource test. This produces a substantial overestimate of SSI participation. 
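The mechanics of the resource test described above can be sketched as follows; the limits come from the text, while the function and argument names are ours, not PolicyEngine's actual API:

```python
# Hypothetical sketch of the SSI resource test under current law.
def passes_resource_test(countable_resources: float, is_couple: bool) -> bool:
    limit = 3_000 if is_couple else 2_000  # current-law limits
    return countable_resources <= limit

# With resources defaulted to zero, every categorically eligible
# unit passes, which drives the overestimate described above.
assert passes_resource_test(0.0, is_couple=False)
assert passes_resource_test(0.0, is_couple=True)
# Imputed wealth above the limit flips the outcome.
assert not passes_resource_test(5_000.0, is_couple=False)
```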
For each imputation method, we feed the model's imputed net worth values into PolicyEngine as \texttt{ssi\_countable\_resources}, replacing the default zero-resource assumption with household-level wealth data. We then simulate SSI eligibility and benefits under each imputation scenario and compare the resulting recipient counts and total expenditure against SSA administrative totals (7.4 million recipients and \$59.6 billion in federal payments for 2024). @@ -90,4 +90,4 @@ \subsection{Downstream policy analysis}\label{sec:policy_validation} \subsection{Additional methodological analyses} -Two supplementary analyses complement the main experimental framework described above. Appendix~\ref{app:robustness} presents a predictor importance analysis using leave-one-out and progressive inclusion experiments on the QRF model, identifying the relative contribution of each predictor variable to wealth imputation quality and informing the plausibility of the conditional independence assumption. Meanwhile, to understand how the superiority of certain models over others generalize to other contexts and datasets, appendix~\ref{app:benchmarking} reports a cross-dataset benchmarking exercise that evaluates \texttt{microimpute}'s five methods across six publicly available datasets from diverse domains, assessing the generalizability of method performance beyond the SCF-CPS use case. +Two supplementary analyses complement the main experimental framework. Appendix~\ref{app:robustness} presents a predictor importance analysis using leave-one-out and progressive inclusion experiments on the QRF model, measuring the relative contribution of each predictor to wealth imputation quality and informing the plausibility of the conditional independence assumption. 
To test whether the relative performance of methods generalizes beyond this particular application, Appendix~\ref{app:benchmarking} reports a cross-dataset benchmarking exercise that evaluates \texttt{microimpute}'s five methods across six publicly available datasets from different domains. diff --git a/paper/sections/results.tex b/paper/sections/results.tex index 4396ce5..d6f81d5 100644 --- a/paper/sections/results.tex +++ b/paper/sections/results.tex @@ -2,9 +2,9 @@ \section{Results} \subsection{Imputation results} -The cross-validation results demonstrate QRF's strong performance across the wealth distribution. With the asinh transformation applied to the net worth target variable, QRF achieved the lowest average quantile loss across all quantiles, outperforming OLS, Hot Deck Matching, Quantile Regression, and MDN. QRF's overall consistency across the entire distribution made it the best-performing method for wealth imputation in this setting. +The cross-validation results show that QRF performs well across the wealth distribution. With the asinh transformation applied to the net worth target variable, QRF achieved the lowest average quantile loss across all quantiles, outperforming OLS, Hot Deck Matching, Quantile Regression, and MDN. QRF's overall consistency across the entire distribution made it the best-performing method for wealth imputation in this setting. -$\texttt{microimpute}$'s automated hyperparameter tuning via the Optuna framework contributed to these results. The optimal QRF configuration identified through cross-validation included 202 trees, a minimum of 1 sample per leaf, a split threshold of 5 samples, approximately 95\% of features considered at each split, and bootstrap sampling enabled. These parameters balanced model complexity with generalization capability, preventing overfitting while capturing the non-linear relationships between demographic predictors and wealth outcomes. 
+Automated hyperparameter tuning via the Optuna framework contributed to these results. The optimal QRF configuration identified through cross-validation included 202 trees, a minimum of 1 sample per leaf, a split threshold of 5 samples, approximately 95\% of features considered at each split, and bootstrap sampling enabled. These parameters balance model complexity with generalization, preventing overfitting while still capturing non-linear relationships between demographic predictors and wealth. \subsubsection{Quantile loss} @@ -42,12 +42,12 @@ \subsubsection{Quantile loss} \subsubsection{Imputed wealth distribution comparisons} -By visually comparing the wealth distributions resulting from imputing with each method, and comparing them to the weighted donor distribution, we gain a more comprehensive understanding of imputation performance, moving past the test quantile loss average measured on the SCF. The distribution of wealth values imputed by QRF closely resembles the original SCF distribution, suggesting that QRF is not only the best method in terms of quantile loss but also strong when achieving distributional estimates close to the true wealth distribution in the United States. Because Matching directly takes values from the donor distribution and imputes them onto the receiver, it is unsurprising that its distribution closely resembles the donor distribution. However, Matching may struggle to uncover the non-linear relationship between the different predictors and wealth values, resulting in a seemingly correct marginal distribution but inaccurate imputation given certain household characteristics. Meanwhile, OLS, QuantReg, and MDN fail to capture the variability across the distribution and impute most values at the lower or higher ends of the distribution. +Comparing the imputed wealth distributions against the weighted donor distribution gives a fuller picture of imputation performance beyond the quantile loss averages measured on the SCF. 
The distribution of wealth values imputed by QRF closely matches the original SCF distribution, suggesting that QRF performs well not only in quantile loss but also in reproducing the overall shape of the U.S. wealth distribution. Because Matching directly takes values from the donor distribution and imputes them onto the receiver, it is unsurprising that its distribution closely resembles the donor distribution. However, Matching may fail to capture the non-linear relationship between predictors and wealth, producing a marginal distribution that looks correct while assigning inaccurate values to individual households. OLS, QuantReg, and MDN fail to capture variability across the distribution and concentrate imputed values at the lower or higher ends. \begin{figure}[H] \centering \includegraphics[width=\textwidth]{figures/models_dist_comparison.png} - \caption{Log-transformed net worth distributions for all five imputation methods, compared to the weighted SCF donor distribution. Dashed lines represent median values.} + \caption{Log-transformed net worth distributions for all five imputation methods, compared to the weighted SCF donor distribution. The y-axis shows the percentage of observations falling in each bin, which normalizes for the different sample sizes between the SCF and CPS and allows direct shape comparison. Dashed lines represent median values.} \label{fig:imputation_distributions_method_comparison} \end{figure} @@ -76,7 +76,7 @@ \subsubsection{Wasserstein distance} \subsubsection{Distribution of wealth by disposable income decile} -Beyond the statistical metrics reported above, we perform a plausibility check by examining whether imputed wealth exhibits economically expected relationships with observable variables. Household wealth should increase monotonically with disposable income, reflecting the well-documented positive correlation between income and wealth accumulation. 
This check is particularly important for wealth imputation: a method that achieves competitive quantile loss but produces wealth-income patterns that violate basic economic relationships would be unreliable for downstream policy analysis, where the joint distribution of income and wealth drives eligibility and benefit calculations. +Beyond the statistical metrics reported above, we check whether imputed wealth exhibits economically expected relationships with observable variables. Household wealth should increase monotonically with disposable income, given the well-documented positive correlation between income and wealth accumulation. This check matters because a method that achieves competitive quantile loss but produces wealth-income patterns that violate basic economic relationships would be unreliable for downstream policy analysis, where eligibility and benefit calculations depend on the joint distribution of income and wealth. \begin{figure}[H] \centering @@ -85,11 +85,11 @@ \subsubsection{Distribution of wealth by disposable income decile} \label{fig:imputation_comparison_by_income_decile} \end{figure} -These results support the observations above, with QRF presenting the most consistent and plausible relationship to disposable income, with a monotonically increasing average wealth as deciles increase---a pattern consistent with economic expectations and serving as a plausibility check on the imputation. This plot also demonstrates the caveats of the other models, for example showing the lower variability in OLS's predictions, and Matching's consistent underprediction across deciles. +These results are consistent with the observations above. QRF produces a monotonically increasing average wealth across income deciles, as economic theory would predict. The plot also exposes weaknesses in the other models: OLS shows little variability in its predictions, and Matching consistently underpredicts across deciles. 
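The decile check described above amounts to binning households by disposable income and averaging imputed wealth within each bin. A minimal pandas sketch, using synthetic data and column names of our own choosing:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Synthetic stand-ins for the receiver variables, correlated so that
# wealth tends to rise with income, as the check expects.
income = rng.lognormal(mean=10.0, sigma=0.8, size=5_000)
wealth = income * rng.lognormal(mean=1.0, sigma=1.0, size=5_000)
df = pd.DataFrame({"disposable_income": income, "imputed_net_worth": wealth})

# Assign income deciles (1 = poorest, 10 = richest) and average
# imputed wealth within each decile.
df["decile"] = pd.qcut(df["disposable_income"], 10, labels=range(1, 11))
mean_by_decile = df.groupby("decile", observed=True)["imputed_net_worth"].mean()
```

A method whose `mean_by_decile` fails to increase from the bottom decile to the top would fail the plausibility check even if its marginal distribution looked reasonable.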
\subsection{Policy analysis: SSI simulation} -Following the downstream policy analysis framework described in Section~\ref{sec:policy_validation}, we feed each model's imputed net worth into PolicyEngine as \texttt{ssi\_countable\_resources} and simulate SSI eligibility and benefits. Table~\ref{tab:ssi_results} compares simulated recipient counts and total expenditure against SSA administrative totals for 2024. +As described in Section~\ref{sec:policy_validation}, we feed each model's imputed net worth into PolicyEngine as \texttt{ssi\_countable\_resources} and simulate SSI eligibility and benefits. Table~\ref{tab:ssi_results} compares simulated recipient counts and total expenditure against SSA administrative totals for 2024. \begin{table}[H] \centering @@ -115,7 +115,7 @@ \subsection{Policy analysis: SSI simulation} \subsubsection{Policy reform: SSI Savings Penalty Elimination Act} -To demonstrate why wealth imputation matters for policy analysis and not just model calibration, we simulate the SSI Savings Penalty Elimination Act \citep{ssi_spea_2025}, which would raise resource limits from \$2,000/\$3,000 to \$10,000/\$20,000. Table~\ref{tab:ssi_reform} reports SSI outcomes under both current law and the reformed thresholds for each imputation method. +To show why wealth imputation matters for policy analysis, not just model calibration, we simulate the SSI Savings Penalty Elimination Act \citep{ssi_spea_2025}, which would raise resource limits from \$2,000/\$3,000 to \$10,000/\$20,000. Table~\ref{tab:ssi_reform} reports SSI outcomes under both current law and the reformed thresholds for each imputation method. \begin{table}[H] \centering @@ -140,9 +140,9 @@ \subsubsection{Policy reform: SSI Savings Penalty Elimination Act} The reform raises the resource threshold from \$2,000 to \$10,000 (individual), so the relevant question is how dense each method's imputed wealth distribution is in the \$2,000--\$10,000 range. 
Households with imputed net worth in this interval fail the current-law resource test but would pass under the reform, gaining SSI eligibility. The estimated number of additional recipients ranges from 0.1 million (QuantReg) to 4.2 million (OLS), with corresponding additional expenditure from \$1.0 billion to \$24.1 billion. This range reflects how differently each imputation approach populates this critical region of the wealth distribution. -QRF estimates a modest reform impact (0.6 million additional recipients, \$3.4 billion additional expenditure), consistent with its faithful reproduction of the SCF wealth distribution. Most households with near-zero wealth are already below the \$2,000 threshold, so raising it to \$10,000 captures relatively few additional households. OLS, by contrast, produces the largest reform impact (4.2 million additional recipients), because its compressed, normally distributed imputations place a substantial mass of households in the \$2,000--\$10,000 range. QuantReg's near-zero reform impact (0.1 million) reflects its tendency to push imputed values away from this intermediate range. These differences underscore that the choice of imputation method is not merely a technical detail but a modeling decision that directly shapes policy impact estimates. +QRF estimates a modest reform impact (0.6 million additional recipients, \$3.4 billion additional expenditure), consistent with its faithful reproduction of the SCF wealth distribution. Most households with near-zero wealth are already below the \$2,000 threshold, so raising it to \$10,000 captures relatively few additional households. OLS, by contrast, produces the largest reform impact (4.2 million additional recipients), because its compressed, normally distributed imputations place a substantial mass of households in the \$2,000--\$10,000 range. QuantReg's near-zero reform impact (0.1 million) reflects its tendency to push imputed values away from this intermediate range. 
These differences show that the choice of imputation method is not merely a technical detail but a modeling decision that directly shapes policy impact estimates. -The baseline (no wealth data) scenario starkly illustrates why imputation is essential: because all resources default to zero, every categorically eligible individual already passes the current-law resource test, so raising the thresholds has zero effect. Without wealth data, this reform is literally un-simulable. The microsimulation would not be able to distinguish between the status quo and a fivefold increase in resource limits. +The baseline (no wealth data) scenario shows plainly why imputation is necessary. Because all resources default to zero, every categorically eligible individual already passes the current-law resource test, so raising the thresholds has zero effect. Without wealth data, the microsimulation cannot distinguish between the status quo and a fivefold increase in resource limits. \begin{figure}[H] \centering diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 302df50..ca416c8 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1289,8 +1289,8 @@ def test_compare_distributions_return_format() -> None: assert len(results) == 2 # Check data types - assert results["Variable"].dtype == "object" - assert results["Metric"].dtype == "object" + assert pd.api.types.is_string_dtype(results["Variable"]) + assert pd.api.types.is_string_dtype(results["Metric"]) assert results["Distance"].dtype in ["float64", "float32"] diff --git a/tests/test_models/test_imputers.py b/tests/test_models/test_imputers.py index 762d45d..7c50ef0 100644 --- a/tests/test_models/test_imputers.py +++ b/tests/test_models/test_imputers.py @@ -264,7 +264,7 @@ def test_imputation_categorical_targets( # Default behavior returns DataFrame directly assert isinstance(predictions, pd.DataFrame) - assert predictions["categorical"].dtype == "object" + assert 
pd.api.types.is_string_dtype(predictions["categorical"]) # Test probability predictions for models that support it if model_class.__name__ in ["OLS", "QRF", "Matching"]: @@ -279,7 +279,9 @@ def test_imputation_categorical_targets( # Check that we still get the categorical predictions assert isinstance(predictions_with_probs[0.5], pd.DataFrame) - assert predictions_with_probs[0.5]["categorical"].dtype == "object" + assert pd.api.types.is_string_dtype( + predictions_with_probs[0.5]["categorical"] + ) # Check probability format prob_info = predictions_with_probs["probabilities"]["categorical"] @@ -333,7 +335,7 @@ def test_categorical_return_probs_false( predictions = fitted_model.predict(X_test) assert isinstance(predictions, pd.DataFrame) assert "categorical" in predictions.columns - assert predictions["categorical"].dtype == "object" + assert pd.api.types.is_string_dtype(predictions["categorical"]) assert set(predictions["categorical"].unique()).issubset({"A", "B", "C"}) # Test 2: Explicit return_probs=False with quantiles should return dict of DataFrames diff --git a/tests/test_models/test_matching.py b/tests/test_models/test_matching.py index e6c8e40..7c01bff 100644 --- a/tests/test_models/test_matching.py +++ b/tests/test_models/test_matching.py @@ -208,7 +208,7 @@ def test_matching_mixed_types() -> None: predictions = fitted_model.predict(X_test, quantiles=[0.5]) assert predictions[0.5]["target_numeric"].dtype == np.float64 - assert predictions[0.5]["target_category"].dtype == object + assert pd.api.types.is_string_dtype(predictions[0.5]["target_category"]) # === Edge Cases ===
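The switch from `dtype == "object"` to `pd.api.types.is_string_dtype` in the test hunks above makes the assertions robust to pandas' dedicated string dtype, which string columns may use depending on pandas version and construction. A minimal illustration:

```python
import pandas as pd

# The same values, stored as a plain object column and as a
# dedicated-string column; only the first has dtype "object".
s_obj = pd.Series(["A", "B", "C"], dtype="object")
s_str = pd.Series(["A", "B", "C"], dtype="string")

assert s_obj.dtype == "object"
assert s_str.dtype != "object"
# is_string_dtype accepts both, so the tests pass either way.
assert pd.api.types.is_string_dtype(s_obj)
assert pd.api.types.is_string_dtype(s_str)
```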