Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions python/pyspark/statcounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@
import math
from typing import Dict, Iterable, Optional

try:
from numpy import maximum, minimum, sqrt
except ImportError:
maximum = max # type: ignore[assignment]
minimum = min # type: ignore[assignment]
sqrt = math.sqrt # type: ignore[assignment]


class StatCounter:
def __init__(self, values: Optional[Iterable[float]] = None):
try:
from numpy import maximum, minimum, sqrt
except ImportError:
maximum = max # type: ignore[assignment]
minimum = min # type: ignore[assignment]
sqrt = math.sqrt # type: ignore[assignment]

self.maximum = maximum
self.minimum = minimum
self.sqrt = sqrt
Comment on lines +27 to +36
Copy link
Contributor

@Yicong-Huang Yicong-Huang Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those methods are pretty standard, do we necessarily need to use the numpy version of them?
we are falling back to math anyway when numpy is not detected

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original PR (long long time ago) was to fix the issue that max, min and math.sqrt does not deal with numpy. It might not be an issue now. We might not have to keep it but that would be a different thing to fix.

if values is None:
values = list()
self.n = 0 # Running count of our values
Expand All @@ -48,8 +51,8 @@ def merge(self, value: float) -> "StatCounter":
self.n += 1
self.mu += delta / self.n
self.m2 += delta * (value - self.mu)
self.maxValue = maximum(self.maxValue, value)
self.minValue = minimum(self.minValue, value)
self.maxValue = self.maximum(self.maxValue, value)
self.minValue = self.minimum(self.minValue, value)

return self

Expand Down Expand Up @@ -77,8 +80,8 @@ def mergeStats(self, other: "StatCounter") -> "StatCounter":
else:
self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)

self.maxValue = maximum(self.maxValue, other.maxValue)
self.minValue = minimum(self.minValue, other.minValue)
self.maxValue = self.maximum(self.maxValue, other.maxValue)
self.minValue = self.minimum(self.minValue, other.minValue)

self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
self.n += other.n
Expand Down Expand Up @@ -122,14 +125,14 @@ def sampleVariance(self) -> float:

# Return the standard deviation of the values.
def stdev(self) -> float:
return sqrt(self.variance())
return self.sqrt(self.variance())

#
# Return the sample standard deviation of the values, which corrects for bias in estimating the
# variance by dividing by N-1 instead of N.
#
def sampleStdev(self) -> float:
return sqrt(self.sampleVariance())
return self.sqrt(self.sampleVariance())

def asDict(self, sample: bool = False) -> Dict[str, float]:
"""Returns the :class:`StatCounter` members as a ``dict``.
Expand Down