Fix #848: Pass estimation_sample_size parameter to individual trees in UpliftRandomForestClassifier #1
Walkthrough
Adds a new `estimation_sample_size` parameter to `UpliftRandomForestClassifier`, stores it on the instance, and forwards it to each `UpliftTreeClassifier` during `fit`. The public `__init__` signature is updated; tree construction calls now include `estimation_sample_size`. No other logic changes noted.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant RFC as UpliftRandomForestClassifier
    participant Pool as Job Pool / Workers
    participant UTC as UpliftTreeClassifier
    User->>RFC: __init__(..., estimation_sample_size=0.5, ...)
    note right of RFC: Store self.estimation_sample_size
    User->>RFC: fit(X, treatment, y)
    RFC->>Pool: Spawn tasks to build trees
    loop For each estimator
        Pool->>UTC: __init__(..., estimation_sample_size=RFC.estimation_sample_size)
        UTC-->>Pool: Initialized with config
        Pool->>UTC: fit(subsample)
        UTC-->>Pool: Trained tree
    end
    Pool-->>RFC: Aggregated trained trees
    RFC-->>User: Fitted model
```
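The flow above can be sketched with minimal stand-in classes. The names mirror the PR, but this is an illustration of the forwarding pattern only, not the causalml implementation:

```python
# Minimal stand-ins for the causalml classes named in the diagram above;
# only the parameter-forwarding behavior added by this PR is modeled.
class UpliftTreeClassifier:
    def __init__(self, estimation_sample_size=None):
        self.estimation_sample_size = estimation_sample_size


class UpliftRandomForestClassifier:
    def __init__(self, n_estimators=10, estimation_sample_size=None):
        self.n_estimators = n_estimators
        # Store the parameter on the forest instance.
        self.estimation_sample_size = estimation_sample_size
        self.uplift_forest = []

    def fit(self):
        # Forward the forest-level setting to every tree (the fix in #848).
        self.uplift_forest = [
            UpliftTreeClassifier(estimation_sample_size=self.estimation_sample_size)
            for _ in range(self.n_estimators)
        ]
        return self
```

Before the fix, the tree constructor calls omitted `estimation_sample_size`, so every tree fell back to its default regardless of the forest's setting.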
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
causalml/inference/tree/uplift.pyx (2)
2427-2437: Duplicate assignment of control_name — remove the second set for clarity.
`self.control_name` is set twice; the second assignment is redundant. Apply this diff:

```diff
 assert control_name is not None and isinstance(control_name, str), \
     f"control_group should be string but {control_name} is passed"
-self.control_name = control_name
 self.classes_ = [control_name]
```
2429-2432: Store estimation_sample_size on the RF — add light validation.
Preempt `train_test_split` errors by validating bounds when `honesty=True`. Apply this diff:

```diff
 self.normalization = normalization
 self.honesty = honesty
-self.estimation_sample_size = estimation_sample_size
+# Basic validation aligned with sklearn's train_test_split expectations
+if self.honesty:
+    if isinstance(estimation_sample_size, float) and not (0.0 < estimation_sample_size < 1.0):
+        raise ValueError("estimation_sample_size must be in (0, 1) when provided as a float and honesty=True.")
+    if isinstance(estimation_sample_size, int) and estimation_sample_size < 1:
+        raise ValueError("estimation_sample_size must be >= 1 when provided as an int and honesty=True.")
+self.estimation_sample_size = estimation_sample_size
```

If you prefer, mirror the same validation in `UpliftTreeClassifier.__init__` for direct usage:

```python
# In UpliftTreeClassifier.__init__
if honesty:
    if isinstance(estimation_sample_size, float) and not (0.0 < estimation_sample_size < 1.0):
        raise ValueError("estimation_sample_size must be in (0, 1) when provided as a float and honesty=True.")
    if isinstance(estimation_sample_size, int) and estimation_sample_size < 1:
        raise ValueError("estimation_sample_size must be >= 1 when provided as an int and honesty=True.")
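The suggested bounds can also be exercised in isolation as a standalone check; the helper name below is illustrative, not part of causalml:

```python
def validate_estimation_sample_size(estimation_sample_size, honesty):
    """Bounds check mirroring the diff suggested above (illustrative helper)."""
    if not honesty:
        return
    if isinstance(estimation_sample_size, float) and not (0.0 < estimation_sample_size < 1.0):
        raise ValueError(
            "estimation_sample_size must be in (0, 1) when provided as a float and honesty=True."
        )
    if isinstance(estimation_sample_size, int) and estimation_sample_size < 1:
        raise ValueError(
            "estimation_sample_size must be >= 1 when provided as an int and honesty=True."
        )
```

With `honesty=False` the parameter is unused, so the check is skipped entirely.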
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
causalml/inference/tree/uplift.pyx (17 hunks)
🔇 Additional comments (5)
causalml/inference/tree/uplift.pyx (5)
2382-2384: Docstring updated to include estimation_sample_size — looks good.
Accurately documents the new parameter and its semantics.
2399-2414: Public init signature change — verify downstream compatibility.
Adding a new kwarg with a default is generally safe. Please confirm any serialization/docs (API reference, examples) reflect the new parameter.
2470-2483: Forwarding estimation_sample_size to each tree — correct and complete.
This fixes uber#848 by ensuring per-tree honest estimation uses the forest's configured split. Please confirm the new regression test asserts that each constructed tree has `tree.estimation_sample_size == forest.estimation_sample_size`.
1591-1593: Typed memoryview aliasing — LGTM.
Casting to `N_TYPE_t[::1]` avoids repeated Python-level indexing; a minor performance win.
246-246: Minor comments/whitespace-only adjustments.
No functional impact. Also applies to: 325-325, 365-366, 407-407, 468-468, 1338-1341, 1368-1369, 1425-1447, 1932-1932, 2110-2112, 2163-2164, 2518-2518
Summary
Fixes issue uber#848 where the `estimation_sample_size` parameter was not being passed to individual trees in `UpliftRandomForestClassifier`.

Changes
- Added an `estimation_sample_size` parameter to `UpliftRandomForestClassifier.__init__`
- Forwarded `estimation_sample_size` to the `UpliftTreeClassifier` instances during forest creation

Testing
This fix enables users to properly control honest splitting behavior in random forests as intended.
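For context, honest estimation reserves part of each tree's training sample for estimating leaf values rather than growing the tree. A minimal sketch of such a split, assuming the float form of `estimation_sample_size` acts as the reserved fraction (mirroring sklearn's `test_size` semantics):

```python
import random


def honest_split(n_rows, estimation_sample_size, seed=0):
    """Partition row indices into a tree-growing set and an estimation set.

    Illustrative only: assumes a float estimation_sample_size is the
    fraction of rows reserved for honest leaf-value estimation.
    """
    rng = random.Random(seed)
    idx = list(range(n_rows))
    rng.shuffle(idx)
    n_est = round(n_rows * estimation_sample_size)
    # Remaining rows grow the tree; the reserved rows estimate leaf values.
    return idx[n_est:], idx[:n_est]
```

With the fix, each tree in the forest performs this split using the forest-level fraction instead of its own default.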