From 0b6df3c82688f6bee29fb28ca54d32d8cb1bff97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Wed, 11 Mar 2026 17:29:23 +0100 Subject: [PATCH 1/3] Add ForkJoinParallelCpgPassWithAccumulator for thread-local aggregation --- .../scala/io/shiftleft/passes/CpgPass.scala | 58 +++++++++ .../io/shiftleft/passes/CpgPassNewTests.scala | 110 ++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index f43571e44..a08d019b2 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -5,8 +5,10 @@ import io.shiftleft.SerializedCpg import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} import org.slf4j.{Logger, LoggerFactory, MDC} +import java.util.concurrent.ConcurrentLinkedQueue import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn +import scala.jdk.CollectionConverters.* import scala.concurrent.duration.DurationLong import scala.util.{Failure, Success, Try} @@ -137,6 +139,62 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S } +/** A [[ForkJoinParallelCpgPass]] that additionally maintains a thread-local accumulator of type [[R]] which is merged + * across all threads after processing completes. This enables map-reduce style aggregation alongside the usual + * DiffGraph-based graph modifications. + * + * Each thread gets its own accumulator instance (via [[newAccumulator]]). After all parts are processed, the + * accumulators are merged using [[mergeAccumulators]] and the result is passed to [[onAccumulatorComplete]]. + * + * @tparam T + * the part type (same as in [[ForkJoinParallelCpgPass]]) + * @tparam R + * the accumulator type + */ +abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, outName: String = "") + extends ForkJoinParallelCpgPass[T](cpg, outName) { + + /** Create a fresh, empty accumulator. Called once per thread. */ + protected def newAccumulator(): R + + /** Merge two accumulators. Must be associative. The result may reuse either argument. */ + protected def mergeAccumulators(left: R, right: R): R + + /** Process a single part, writing graph changes to `builder` and aggregated data to `acc`. */ + protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: R, part: T): Unit + + /** Called after all parts are processed with the fully merged accumulator. */ + protected def onAccumulatorComplete(acc: R): Unit = {} + + private val accumulators = new ConcurrentLinkedQueue[R]() + + private val threadLocalAcc: ThreadLocal[R] = new ThreadLocal[R]() + + final override def runOnPart(builder: DiffGraphBuilder, part: T): Unit = { + var acc = threadLocalAcc.get() + if (acc == null) { + acc = newAccumulator() + threadLocalAcc.set(acc) + accumulators.add(acc) + } + runOnPartWithAccumulator(builder, acc, part) + } + + override def init(): Unit = { + accumulators.clear() + threadLocalAcc.remove() + super.init() + } + + override def finish(): Unit = { + val merged = accumulators.asScala.reduceOption(mergeAccumulators).getOrElse(newAccumulator()) + onAccumulatorComplete(merged) + accumulators.clear() + threadLocalAcc.remove() + super.finish() + } +} + trait CpgPassBase { protected def baseLogger: Logger = LoggerFactory.getLogger(getClass) diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala index dc32d7975..467e786ba 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala @@ -90,4 +90,114 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { } } + "ForkJoinParallelCpgPassWithAccumulator" should { + "merge accumulators and invoke completion callback once" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-pass") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] = + left ++= right + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = acc += part.length + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array("a", "bb", "ccc") + override def isParallel: Boolean = false + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(6) + } + + "use a fresh accumulator when there are no parts" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-empty") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer(42) + override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] = + left ++= right + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = () + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array.empty + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(42) + } + + "clear accumulator state between runs" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-rerun") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override protected def mergeAccumulators( + left: ArrayBuffer[Int], + right: ArrayBuffer[Int] + ): ArrayBuffer[Int] = { + left ++= right + } + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = acc += part.toInt + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array("1", "2", "3") + override def isParallel: Boolean = false + } + + pass.createAndApply() + pass.createAndApply() + + completed.toSeq shouldBe Seq(6, 6) + } + + "invoke completion callback once when a part fails" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, Int] = + new ForkJoinParallelCpgPassWithAccumulator[String, Int](cpg, "acc-fail") { + override protected def newAccumulator(): Int = 0 + override protected def mergeAccumulators(left: Int, right: Int): Int = left + right + override protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: Int, part: String): Unit = { + events += "run" + throw new RuntimeException("boom") + } + override protected def onAccumulatorComplete(acc: Int): Unit = events += s"complete:$acc" + override def generateParts(): Array[String] = Array("p1") + override def isParallel: Boolean = false + override def init(): Unit = { + events += "init" + super.init() + } + override def finish(): Unit = { + events += "finish" + super.finish() + } + } + + intercept[RuntimeException] { + pass.createAndApply() + } + + events.toSeq shouldBe Seq("init", "run", "finish", "complete:0") + } + } + } From 8adfbd5558a357b82f9fa1e0632c25ff4666e85f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:11:19 +0100 Subject: [PATCH 2/3] Alternative using BiConsumer etc --- .../scala/io/shiftleft/passes/CpgPass.scala | 134 ++++++++++++++---- 1 file changed, 109 insertions(+), 25 deletions(-) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index a08d019b2..7330c37ec 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -5,10 +5,9 @@ import io.shiftleft.SerializedCpg import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} import org.slf4j.{Logger, LoggerFactory, MDC} -import java.util.concurrent.ConcurrentLinkedQueue import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn -import scala.jdk.CollectionConverters.* +import scala.compiletime.uninitialized import scala.concurrent.duration.DurationLong import scala.util.{Failure, Success, Try} @@ -146,15 +145,29 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S * Each thread gets its own accumulator instance (via [[newAccumulator]]). After all parts are processed, the * accumulators are merged using [[mergeAccumulators]] and the result is passed to [[onAccumulatorComplete]]. * + * This variant uses the `stream.collect` / `BiConsumer` API (just like [[ForkJoinParallelCpgPass]]) with a combined + * container that holds both a [[DiffGraphBuilder]] and an accumulator per fork, so no `ThreadLocal` or + * `ConcurrentLinkedQueue` is needed. + * * @tparam T * the part type (same as in [[ForkJoinParallelCpgPass]]) * @tparam R * the accumulator type */ -abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, outName: String = "") - extends ForkJoinParallelCpgPass[T](cpg, outName) { +abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, @nowarn outName: String = "") + extends CpgPassBase { + type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder - /** Create a fresh, empty accumulator. Called once per thread. */ + /** Generate Array of parts that can be processed in parallel. */ + def generateParts(): Array[? <: AnyRef] + + /** Setup large data structures, acquire external resources. */ + def init(): Unit = {} + + /** Override this to disable parallelism of passes. Useful for debugging. */ + def isParallel: Boolean = true + + /** Create a fresh, empty accumulator. Called once per fork (thread). */ protected def newAccumulator(): R /** Merge two accumulators. Must be associative. The result may reuse either argument. */ @@ -163,35 +176,106 @@ abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, /** Process a single part, writing graph changes to `builder` and aggregated data to `acc`. */ protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: R, part: T): Unit - /** Called after all parts are processed with the fully merged accumulator. */ + /** Called after all parts are processed with the fully merged accumulator. Override `finish()` if you need to release + * resources; `onAccumulatorComplete` is invoked from within the default `finish()` implementation. + */ protected def onAccumulatorComplete(acc: R): Unit = {} - private val accumulators = new ConcurrentLinkedQueue[R]() + /** Container pairing a per-fork DiffGraphBuilder with a per-fork accumulator. */ + private class BuilderWithAccumulator(val builder: DiffGraphBuilder, var acc: R) - private val threadLocalAcc: ThreadLocal[R] = new ThreadLocal[R]() + @volatile private var _accResult: R = uninitialized + @volatile private var _hasResult: Boolean = false - final override def runOnPart(builder: DiffGraphBuilder, part: T): Unit = { - var acc = threadLocalAcc.get() - if (acc == null) { - acc = newAccumulator() - threadLocalAcc.set(acc) - accumulators.add(acc) + /** Release large data structures and external resources. The default implementation calls + * [[onAccumulatorComplete]] with the merged accumulator (or a fresh one if processing failed). Subclasses that + * override this method must call `super.finish()` to ensure the accumulator callback fires. + */ + def finish(): Unit = { + val acc = if (_hasResult) _accResult else newAccumulator() + onAccumulatorComplete(acc) + _hasResult = false + } + + override def createAndApply(): Unit = { + baseLogger.info(s"Start of pass: $name") + val nanosStart = System.nanoTime() + var nParts = 0 + var nanosBuilt = -1L + var nDiff = -1 + var nDiffT = -1 + try { + val diffGraph = Cpg.newDiffGraphBuilder + nParts = runWithBuilder(diffGraph) + nanosBuilt = System.nanoTime() + nDiff = diffGraph.size + + nDiffT = flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph) + } catch { + case exc: Exception => + baseLogger.error(s"Pass ${name} failed", exc) + throw exc + } finally { + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts." + ) } - runOnPartWithAccumulator(builder, acc, part) } - override def init(): Unit = { - accumulators.clear() - threadLocalAcc.remove() - super.init() + override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = { + _hasResult = false + try { + init() + val parts = generateParts() + val nParts = parts.size + _accResult = nParts match { + case 0 => + newAccumulator() + case 1 => + val acc = newAccumulator() + runOnPartWithAccumulator(externalBuilder, acc, parts(0).asInstanceOf[T]) + acc + case _ => + val stream = + if (!isParallel) + java.util.Arrays + .stream(parts) + .sequential() + else + java.util.Arrays + .stream(parts) + .parallel() + val result = stream.collect( + new Supplier[BuilderWithAccumulator] { + override def get(): BuilderWithAccumulator = + new BuilderWithAccumulator(Cpg.newDiffGraphBuilder, newAccumulator()) + }, + new BiConsumer[BuilderWithAccumulator, AnyRef] { + override def accept(bwa: BuilderWithAccumulator, part: AnyRef): Unit = + runOnPartWithAccumulator(bwa.builder, bwa.acc, part.asInstanceOf[T]) + }, + new BiConsumer[BuilderWithAccumulator, BuilderWithAccumulator] { + override def accept(left: BuilderWithAccumulator, right: BuilderWithAccumulator): Unit = { + left.builder.absorb(right.builder) + left.acc = mergeAccumulators(left.acc, right.acc) + } + } + ) + externalBuilder.absorb(result.builder) + result.acc + } + _hasResult = true + nParts + } finally { + finish() + } } - override def finish(): Unit = { - val merged = accumulators.asScala.reduceOption(mergeAccumulators).getOrElse(newAccumulator()) - onAccumulatorComplete(merged) - accumulators.clear() - threadLocalAcc.remove() - super.finish() + @deprecated("Please use createAndApply") + override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = { + createAndApply() } } From 10a1f953869b6c8cf751c8522d48c89822c1eb10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:20:43 +0100 Subject: [PATCH 3/3] format --- .../src/main/scala/io/shiftleft/passes/CpgPass.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index 7330c37ec..464b6f742 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -184,12 +184,12 @@ abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, /** Container pairing a per-fork DiffGraphBuilder with a per-fork accumulator. */ private class BuilderWithAccumulator(val builder: DiffGraphBuilder, var acc: R) - @volatile private var _accResult: R = uninitialized + @volatile private var _accResult: R = uninitialized @volatile private var _hasResult: Boolean = false - /** Release large data structures and external resources. The default implementation calls - * [[onAccumulatorComplete]] with the merged accumulator (or a fresh one if processing failed). Subclasses that - * override this method must call `super.finish()` to ensure the accumulator callback fires. + /** Release large data structures and external resources. The default implementation calls [[onAccumulatorComplete]] + * with the merged accumulator (or a fresh one if processing failed). Subclasses that override this method must call + * `super.finish()` to ensure the accumulator callback fires. */ def finish(): Unit = { val acc = if (_hasResult) _accResult else newAccumulator()