From 2f60636670dbb0b73a0db9031bf7dc474944dfa0 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Wed, 18 Mar 2026 16:36:38 -0700 Subject: [PATCH] Lazy import numpy for pyspark --- python/pyspark/statcounter.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 758d0a6fb8b01..92bf0b11a3e11 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -21,16 +21,19 @@ import math from typing import Dict, Iterable, Optional -try: - from numpy import maximum, minimum, sqrt -except ImportError: - maximum = max # type: ignore[assignment] - minimum = min # type: ignore[assignment] - sqrt = math.sqrt # type: ignore[assignment] - class StatCounter: def __init__(self, values: Optional[Iterable[float]] = None): + try: + from numpy import maximum, minimum, sqrt + except ImportError: + maximum = max # type: ignore[assignment] + minimum = min # type: ignore[assignment] + sqrt = math.sqrt # type: ignore[assignment] + + self.maximum = maximum + self.minimum = minimum + self.sqrt = sqrt if values is None: values = list() self.n = 0 # Running count of our values @@ -48,8 +51,8 @@ def merge(self, value: float) -> "StatCounter": self.n += 1 self.mu += delta / self.n self.m2 += delta * (value - self.mu) - self.maxValue = maximum(self.maxValue, value) - self.minValue = minimum(self.minValue, value) + self.maxValue = self.maximum(self.maxValue, value) + self.minValue = self.minimum(self.minValue, value) return self @@ -77,8 +80,8 @@ def mergeStats(self, other: "StatCounter") -> "StatCounter": else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) - self.maxValue = maximum(self.maxValue, other.maxValue) - self.minValue = minimum(self.minValue, other.minValue) + self.maxValue = self.maximum(self.maxValue, other.maxValue) + self.minValue = self.minimum(self.minValue, other.minValue) self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n @@ -122,14 +125,14 @@ def sampleVariance(self) -> float: # Return the standard deviation of the values. def stdev(self) -> float: - return sqrt(self.variance()) + return self.sqrt(self.variance()) # # Return the sample standard deviation of the values, which corrects for bias in estimating the # variance by dividing by N-1 instead of N. # def sampleStdev(self) -> float: - return sqrt(self.sampleVariance()) + return self.sqrt(self.sampleVariance()) def asDict(self, sample: bool = False) -> Dict[str, float]: """Returns the :class:`StatCounter` members as a ``dict``.