From 2e2d229632c7c7ac1a22a53e012a7f6aadc23e8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:01:40 +0100 Subject: [PATCH] Add ForkJoinParallelCpgPassWithAccumulator Introduce a general ForkJoinParallelCpgPassWithAccumulator for fork/join CPG passes that need per-worker accumulators. --- .../scala/io/shiftleft/passes/CpgPass.scala | 248 +++++++++++++----- .../io/shiftleft/passes/CpgPassNewTests.scala | 106 ++++++++ 2 files changed, 290 insertions(+), 64 deletions(-) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index f43571e44..5f2a51dd6 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -10,12 +10,23 @@ import scala.annotation.nowarn import scala.concurrent.duration.DurationLong import scala.util.{Failure, Success, Try} -/* CpgPass - * - * Base class of a program which receives a CPG as input for the purpose of modifying it. - * */ +/** A single-threaded CPG pass. This is the simplest pass to implement: override [[run]] and add desired graph + * modifications to the provided [[DiffGraphBuilder]]. + * + * Internally implemented as a [[ForkJoinParallelCpgPass]] with a single part and parallelism disabled. + * + * @param cpg + * the code property graph to modify + * @param outName + * optional name for output + */ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelCpgPass[AnyRef](cpg, outName) { + /** The main method to implement. Add all desired graph changes (nodes, edges, properties) to the provided builder. + * + * @param builder + * the [[DiffGraphBuilder]] that accumulates graph modifications + */ def run(builder: DiffGraphBuilder): Unit final override def generateParts(): Array[? 
<: AnyRef] = Array[AnyRef](null) @@ -26,42 +37,126 @@ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelC override def isParallel: Boolean = false } +/** @deprecated Use [[CpgPass]] instead. */ @deprecated abstract class SimpleCpgPass(cpg: Cpg, outName: String = "") extends CpgPass(cpg, outName) -/* ForkJoinParallelCpgPass is a possible replacement for CpgPass and ParallelCpgPass. - * - * Instead of returning an Iterator, generateParts() returns an Array. This means that the entire collection - * of parts must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation, - * e.g. when running over all METHOD nodes and deleting some of them. - * - * Instead of streaming writes as ParallelCpgPass do, all `runOnPart` invocations read the initial state - * of the graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one go. - * - * In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model. - * The effect is identical as if one were to sequentially run `runOnParts` on all output elements of `generateParts()` - * in sequential order, with the same builder. - * - * This simplifies semantics and makes it easy to reason about possible races. - * - * Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption when porting from ParallelCpgPass. - * - * Initialization and cleanup of external resources or large datastructures can be done in the `init()` and `finish()` - * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct - * passes eagerly, and releases them only when the entire chain has run. - * */ -abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: String = "") extends CpgPassBase { +/** A parallel CPG pass using the fork/join model. 
+ * + * Instead of returning an Iterator, [[generateParts]] returns an Array. This means that the entire collection of parts + * must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation, + * e.g. when running over all METHOD nodes and deleting some of them. + * + * Instead of streaming writes as ParallelCpgPass do, all [[runOnPart]] invocations read the initial state of the + * graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one + * go. + * + * In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model. + * The effect is identical as if one were to sequentially run [[runOnPart]] on all output elements of [[generateParts]] + * in sequential order, with the same builder. + * + * This simplifies semantics and makes it easy to reason about possible races. + * + * Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption + * when porting from ParallelCpgPass. + * + * Initialization and cleanup of external resources or large datastructures can be done in the [[init]] and [[finish]] + * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct passes + * eagerly, and releases them only when the entire chain has run. + * + * This is a simplified form of [[ForkJoinParallelCpgPassWithAccumulator]] that does not use an accumulator. + * + * @tparam T + * the type of each part produced by [[generateParts]] + * @param cpg + * the code property graph to modify + * @param outname + * optional output name + */ +abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outname: String = "") + extends ForkJoinParallelCpgPassWithAccumulator[T, Null](cpg, outname) { + + /** Process a single part and record graph modifications in the provided builder. 
+ * + * @param builder + * the [[DiffGraphBuilder]] that accumulates graph modifications + * @param part + * the part to process, as produced by [[generateParts]] + */ + def runOnPart(builder: DiffGraphBuilder, part: T): Unit + + override def createAccumulator(): Null = null + override def runOnPart(builder: DiffGraphBuilder, part: T, acc: Null): Unit = runOnPart(builder, part) + override def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Null): Unit = {} + override def mergeAccumulator(left: Null, accumulator: Null): Unit = {} +} + +/** A parallel CPG pass with an accumulator for aggregating side results. + * + * This is the most general form of the fork/join pass framework. It extends [[ForkJoinParallelCpgPass]] with an + * accumulator of type [[Accumulator]] that each parallel worker maintains locally. After all parts are processed, + * worker accumulators are merged via [[mergeAccumulator]], and the final merged accumulator is passed to + * [[onAccumulatorComplete]] where additional graph changes can be recorded. + * + * @tparam T + * the type of each part produced by [[generateParts]] + * @tparam Accumulator + * the type of the accumulator used during parallel execution + * @param cpg + * the code property graph to modify + * @param outName + * optional output name + */ +abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, Accumulator <: AnyRef]( + cpg: Cpg, + @nowarn outName: String = "" +) extends CpgPassBase { type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder - // generate Array of parts that can be processed in parallel + + /** Generate an array of parts to be processed in parallel by [[runOnPart]]. */ def generateParts(): Array[? <: AnyRef] - // setup large data structures, acquire external resources + + /** Called once before [[generateParts]]. Use to set up large data structures or acquire external resources. 
*/ def init(): Unit = {} - // release large data structures and external resources + + /** Called once after all parts have been processed (in a `finally` block). Use to release resources acquired in + * [[init]]. + */ def finish(): Unit = {} - // main function: add desired changes to builder - def runOnPart(builder: DiffGraphBuilder, part: T): Unit - // Override this to disable parallelism of passes. Useful for debugging. + + /** Process a single part, recording graph changes in `builder` and side results in `accumulator`. + * + * @param builder + * the [[DiffGraphBuilder]] that accumulates graph modifications + * @param part + * the part to process + * @param accumulator + * the thread-local accumulator for this worker + */ + def runOnPart(builder: DiffGraphBuilder, part: T, accumulator: Accumulator): Unit + + /** Override and return `false` to disable parallel execution. Useful for debugging. */ def isParallel: Boolean = true + /** Create a fresh accumulator instance. Called once per parallel worker thread. */ + def createAccumulator(): Accumulator + + /** Merge the `accumulator` (right) into `left`. Called during the combine phase of fork/join. */ + def mergeAccumulator(left: Accumulator, accumulator: Accumulator): Unit + + /** Called once after all parts are processed and accumulators are merged. Use to record additional graph changes + * based on the fully merged accumulator. + * + * @param builder + * the [[DiffGraphBuilder]] for any additional modifications + * @param accumulator + * the final merged accumulator + */ + def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Accumulator): Unit + + /** Creates a new [[DiffGraphBuilder]], runs the pass (init, generateParts, runOnPart, finish), applies all + * accumulated changes to the graph, and logs timing information. Exceptions during execution are logged and + * re-thrown. 
+ */ override def createAndApply(): Unit = { baseLogger.info(s"Start of pass: $name") val nanosStart = System.nanoTime() @@ -89,41 +184,50 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S } } + /** Runs the full pass lifecycle (init, generateParts, parallel runOnPart, accumulator merge, finish) and absorbs all + * changes into `externalBuilder` without applying them to the graph. The caller is responsible for applying the + * builder. + * + * @param externalBuilder + * the builder to absorb all generated changes into + * @return + * the number of parts that were processed + */ override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = { try { init() + val parts = generateParts() val nParts = parts.size - nParts match { - case 0 => - case 1 => - runOnPart(externalBuilder, parts(0).asInstanceOf[T]) - case _ => - val stream = - if (!isParallel) - java.util.Arrays - .stream(parts) - .sequential() - else - java.util.Arrays - .stream(parts) - .parallel() - val diff = stream.collect( - new Supplier[DiffGraphBuilder] { - override def get(): DiffGraphBuilder = - Cpg.newDiffGraphBuilder - }, - new BiConsumer[DiffGraphBuilder, AnyRef] { - override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = - runOnPart(builder, part.asInstanceOf[T]) - }, - new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { - override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = - leftBuilder.absorb(rightBuilder) - } - ) - externalBuilder.absorb(diff) - } + val stream = + if (!isParallel) java.util.Arrays.stream(parts).sequential() + else java.util.Arrays.stream(parts).parallel() + + val (diff, acc) = stream.collect( + new Supplier[(DiffGraphBuilder, Accumulator)] { + override def get(): (DiffGraphBuilder, Accumulator) = + (Cpg.newDiffGraphBuilder, createAccumulator()) + }, + new BiConsumer[(DiffGraphBuilder, Accumulator), AnyRef] { + override def accept(consumedArg: (DiffGraphBuilder, Accumulator), part: 
AnyRef): Unit = { + val (diff, acc) = consumedArg + runOnPart(diff, part.asInstanceOf[T], acc) + } + }, + new BiConsumer[(DiffGraphBuilder, Accumulator), (DiffGraphBuilder, Accumulator)] { + override def accept( + leftConsumedArg: (DiffGraphBuilder, Accumulator), + rightConsumedArg: (DiffGraphBuilder, Accumulator) + ): Unit = { + val (leftDiff, leftAcc) = leftConsumedArg + val (rightDiff, rightAcc) = rightConsumedArg + leftDiff.absorb(rightDiff) + mergeAccumulator(leftAcc, rightAcc) + } + } + ) + onAccumulatorComplete(diff, acc) + externalBuilder.absorb(diff) nParts } finally { finish() @@ -137,6 +241,9 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S } +/** Base trait for all CPG passes. Defines the lifecycle methods that every pass must implement: [[createAndApply]] for + * standalone execution, and [[runWithBuilder]] for composing passes that share a single [[DiffGraphBuilder]]. + */ trait CpgPassBase { protected def baseLogger: Logger = LoggerFactory.getLogger(getClass) @@ -156,8 +263,12 @@ trait CpgPassBase { */ def runWithBuilder(builder: DiffGraphBuilder): Int - /** Wraps runWithBuilder with logging, and swallows raised exceptions. Use with caution -- API is unstable. A return - * value of -1 indicates failure, otherwise the return value of runWithBuilder is passed through. + /** Wraps [[runWithBuilder]] with logging and exception handling. Use with caution — API is unstable. + * + * @param builder + * the [[DiffGraphBuilder]] to absorb changes into + * @return + * the number of parts processed, or `-1` if the pass threw an exception */ def runWithBuilderLogged(builder: DiffGraphBuilder): Int = { baseLogger.info(s"Start of pass: $name") @@ -189,6 +300,15 @@ trait CpgPassBase { @deprecated protected def store(overlay: GeneratedMessageV3, name: String, serializedCpg: SerializedCpg): Unit = {} + /** Executes `fun` while logging the pass start and completion time (including duration via MDC). 
+ * + * @tparam A + * the return type of the wrapped computation + * @param fun + * the computation to execute + * @return + * the result of `fun` + */ protected def withStartEndTimesLogged[A](fun: => A): A = { baseLogger.info(s"Running pass: $name") val startTime = System.currentTimeMillis diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala index dc32d7975..db353d542 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala @@ -8,6 +8,7 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable class CpgPassNewTests extends AnyWordSpec with Matchers { @@ -90,4 +91,109 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { } } + "ForkJoinParallelCpgPassWithAccumulator" should { + "merge accumulators and invoke completion callback once" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-pass") { + override def createAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override def mergeAccumulator(left: ArrayBuffer[Int], accumulator: ArrayBuffer[Int]): Unit = { + left ++= accumulator + } + override def runOnPart(builder: DiffGraphBuilder, part: String, acc: ArrayBuffer[Int]): Unit = + acc += part.length + override def onAccumulatorComplete(builder: DiffGraphBuilder, acc: ArrayBuffer[Int]): Unit = + completed += acc.sum + override def generateParts(): Array[String] = Array("a", "bb", "ccc") + override def isParallel: Boolean = false + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(6) + } + + "use a fresh accumulator when there are no 
parts" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-empty") { + override def createAccumulator(): ArrayBuffer[Int] = ArrayBuffer(42) + override def mergeAccumulator(left: ArrayBuffer[Int], accumulator: ArrayBuffer[Int]): Unit = { + left ++= accumulator + } + override def runOnPart(builder: DiffGraphBuilder, part: String, acc: ArrayBuffer[Int]): Unit = () + override def onAccumulatorComplete(builder: DiffGraphBuilder, acc: ArrayBuffer[Int]): Unit = + completed += acc.sum + override def generateParts(): Array[String] = Array.empty + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(42) + } + + "clear accumulator state between runs" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-rerun") { + override def createAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override def mergeAccumulator(left: ArrayBuffer[Int], accumulator: ArrayBuffer[Int]): Unit = { + left ++= accumulator + } + override def runOnPart(builder: DiffGraphBuilder, part: String, acc: ArrayBuffer[Int]): Unit = + acc += part.toInt + override def onAccumulatorComplete(builder: DiffGraphBuilder, acc: ArrayBuffer[Int]): Unit = + completed += acc.sum + override def generateParts(): Array[String] = Array("1", "2", "3") + override def isParallel: Boolean = false + } + + pass.createAndApply() + pass.createAndApply() + + completed.toSeq shouldBe Seq(6, 6) + } + + "call finish when a part fails" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[String]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[String]](cpg, "acc-fail") 
{ + override def createAccumulator(): ArrayBuffer[String] = ArrayBuffer.empty[String] + override def mergeAccumulator(left: ArrayBuffer[String], accumulator: ArrayBuffer[String]): Unit = { + left ++= accumulator + } + override def runOnPart(builder: DiffGraphBuilder, part: String, acc: ArrayBuffer[String]): Unit = { + events += "run" + throw new RuntimeException("boom") + } + override def onAccumulatorComplete(builder: DiffGraphBuilder, acc: ArrayBuffer[String]): Unit = + events += "final" + override def generateParts(): Array[String] = Array("p1") + override def isParallel: Boolean = false + override def init(): Unit = { + events += "init" + super.init() + } + override def finish(): Unit = { + events += "finish" + super.finish() + } + } + + intercept[RuntimeException] { + pass.createAndApply() + } + + events.toSeq shouldBe Seq("init", "run", "finish") + } + } + }