Skip to content

Commit 7e7f19a

Browse files
Ensure finish() is called exactly once in ForkJoinParallelCpgPass lifecycle
1 parent 66866d6 commit 7e7f19a

2 files changed

Lines changed: 54 additions & 39 deletions

File tree

codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,14 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
7878
nDiffT = flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph)
7979
} catch {
8080
case exc: Exception =>
81-
baseLogger.error(s"Pass ${name} failed", exc)
81+
baseLogger.error(s"Pass $name failed", exc)
8282
throw exc
8383
} finally {
84-
try {
85-
finish()
86-
} finally {
87-
// the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish()
88-
// in the reported timings, and we must have our final log message if finish() throws
89-
val nanosStop = System.nanoTime()
90-
val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1)
91-
baseLogger.info(
92-
f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts."
93-
)
94-
}
84+
val nanosStop = System.nanoTime()
85+
val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1)
86+
baseLogger.info(
87+
f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts."
88+
)
9589
}
9690
}
9791

@@ -106,27 +100,12 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
106100
runOnPart(externalBuilder, parts(0).asInstanceOf[T])
107101
case _ =>
108102
val stream =
109-
if (!isParallel)
110-
java.util.Arrays
111-
.stream(parts)
112-
.sequential()
113-
else
114-
java.util.Arrays
115-
.stream(parts)
116-
.parallel()
103+
if (!isParallel) java.util.Arrays.stream(parts).sequential()
104+
else java.util.Arrays.stream(parts).parallel()
117105
val diff = stream.collect(
118-
new Supplier[DiffGraphBuilder] {
119-
override def get(): DiffGraphBuilder =
120-
Cpg.newDiffGraphBuilder
121-
},
122-
new BiConsumer[DiffGraphBuilder, AnyRef] {
123-
override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit =
124-
runOnPart(builder, part.asInstanceOf[T])
125-
},
126-
new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] {
127-
override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit =
128-
leftBuilder.absorb(rightBuilder)
129-
}
106+
() => Cpg.newDiffGraphBuilder,
107+
(builder: DiffGraphBuilder, part: AnyRef) => runOnPart(builder, part.asInstanceOf[T]),
108+
(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder) => leftBuilder.absorb(rightBuilder)
130109
)
131110
externalBuilder.absorb(diff)
132111
}
@@ -152,12 +131,12 @@ trait CpgPassBase {
152131
@deprecated("Please use createAndApply")
153132
def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit
154133

155-
/** Name of the pass. By default it is inferred from the name of the class, override if needed.
134+
/** Name of the pass. By default, it is inferred from the name of the class, override if needed.
156135
*/
157136
def name: String = getClass.getName
158137

159138
/** Runs the cpg pass, adding changes to the passed builder. Use with caution -- API is unstable. Returns max(nParts,
160-
* 1), where nParts is either the number of parallel parts, or the number of iterarator elements in case of legacy
139+
* 1), where nParts is either the number of parallel parts, or the number of iterator elements in case of legacy
161140
* passes. Includes init() and finish() logic.
162141
*/
163142
def runWithBuilder(builder: DiffGraphBuilder): Int
@@ -172,11 +151,11 @@ trait CpgPassBase {
172151
Try(runWithBuilder(builder)) match {
173152
case Success(nParts) =>
174153
baseLogger.info(
175-
f"Pass ${name} completed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms. ${builder.size - size0}%d changes generated from ${nParts}%d parts."
154+
f"Pass $name completed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms. ${builder.size - size0}%d changes generated from $nParts%d parts."
176155
)
177156
nParts
178157
case Failure(exception) =>
179-
baseLogger.warn(f"Pass ${name} failed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms", exception)
158+
baseLogger.warn(f"Pass $name failed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms", exception)
180159
-1
181160
}
182161
}

codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
package io.shiftleft.passes
22

3-
import better.files.File
43
import flatgraph.SchemaViolationException
54
import io.shiftleft.codepropertygraph.generated.Cpg
6-
import io.shiftleft.codepropertygraph.generated.nodes.NewFile
75
import io.shiftleft.codepropertygraph.generated.language.*
6+
import io.shiftleft.codepropertygraph.generated.nodes.NewFile
87
import org.scalatest.matchers.should.Matchers
98
import org.scalatest.wordspec.AnyWordSpec
109

11-
import java.nio.file.Files
10+
import scala.collection.mutable.ArrayBuffer
1211

1312
class CpgPassNewTests extends AnyWordSpec with Matchers {
1413

@@ -52,6 +51,43 @@ class CpgPassNewTests extends AnyWordSpec with Matchers {
5251
pass.createAndApply()
5352
}
5453
}
54+
55+
"call init and finish once around run" in {
56+
val cpg = Cpg.empty
57+
val events = ArrayBuffer.empty[String]
58+
val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "lifecycle-pass") {
59+
override def init(): Unit = events += "init"
60+
override def generateParts(): Array[String] = Array("p1")
61+
override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = events += "run"
62+
override def finish(): Unit = events += "finish"
63+
}
64+
65+
pass.createAndApply()
66+
67+
// all events should be in the expected order and should only occur once
68+
events.toSeq shouldBe Seq("init", "run", "finish")
69+
}
70+
71+
"call finish once when run fails" in {
72+
val cpg = Cpg.empty
73+
val events = ArrayBuffer.empty[String]
74+
val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "failing-lifecycle-pass") {
75+
override def init(): Unit = events += "init"
76+
override def generateParts(): Array[String] = Array("p1")
77+
override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = {
78+
events += "run"
79+
throw new RuntimeException("run failed")
80+
}
81+
override def finish(): Unit = events += "finish"
82+
}
83+
84+
intercept[RuntimeException] {
85+
pass.createAndApply()
86+
}
87+
88+
// all events should be in the expected order and should only occur once even if run fails
89+
events.toSeq shouldBe Seq("init", "run", "finish")
90+
}
5591
}
5692

5793
}

0 commit comments

Comments
 (0)