|
19 | 19 | */ |
20 | 20 | package org.neo4j.gds.projection; |
21 | 21 |
|
22 | | -import com.carrotsearch.hppc.LongHashSet; |
23 | | -import org.agrona.collections.MutableBoolean; |
24 | 22 | import org.junit.jupiter.api.Test; |
25 | 23 | import org.neo4j.gds.ImmutableRelationshipProjections; |
26 | | -import org.neo4j.gds.NodeLabel; |
27 | | -import org.neo4j.gds.NodeProjection; |
28 | 24 | import org.neo4j.gds.NodeProjections; |
29 | 25 | import org.neo4j.gds.Orientation; |
30 | 26 | import org.neo4j.gds.RelationshipProjection; |
|
35 | 31 | import org.neo4j.gds.core.GraphDimensions; |
36 | 32 | import org.neo4j.gds.core.ImmutableGraphDimensions; |
37 | 33 | import org.neo4j.gds.core.concurrency.Concurrency; |
38 | | -import org.neo4j.gds.core.utils.progress.JobId; |
39 | | -import org.neo4j.gds.core.utils.progress.TaskRegistry; |
40 | | -import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; |
41 | | -import org.neo4j.gds.logging.Log; |
| 34 | +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; |
42 | 35 | import org.neo4j.gds.mem.MemoryEstimation; |
43 | 36 | import org.neo4j.gds.mem.MemoryTree; |
44 | 37 |
|
45 | | -import java.util.Map; |
46 | | -import java.util.Optional; |
47 | 38 | import java.util.concurrent.atomic.AtomicReference; |
48 | 39 |
|
49 | | -import static org.assertj.core.api.Assertions.assertThat; |
| 40 | +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; |
50 | 41 | import static org.junit.jupiter.api.Assertions.assertEquals; |
51 | | -import static org.mockito.ArgumentMatchers.any; |
| 42 | +import static org.mockito.Mockito.doThrow; |
52 | 43 | import static org.mockito.Mockito.mock; |
| 44 | +import static org.mockito.Mockito.spy; |
53 | 45 | import static org.mockito.Mockito.times; |
54 | 46 | import static org.mockito.Mockito.verify; |
55 | | -import static org.mockito.Mockito.when; |
| 47 | +import static org.mockito.Mockito.verifyNoMoreInteractions; |
56 | 48 |
|
57 | 49 | class NativeFactoryTest { |
58 | 50 |
|
@@ -161,65 +153,24 @@ void memoryEstimationForMultipleProjections() { |
161 | 153 | } |
162 | 154 |
|
163 | 155 | @Test |
164 | | - void shouldCleanTaskOnValidateFailure(){ |
| 156 | + void shouldCleanTaskOnValidateFailure() { |
| 157 | + var progressTrackerMock = mock(ProgressTracker.class); |
165 | 158 |
|
166 | | - var mockRegistry = mock(TaskRegistry.class); |
| 159 | + var nativeFactorySpy = spy(new NativeFactory( |
| 160 | + mock(GraphProjectFromStoreConfig.class), |
| 161 | + mock(GraphLoaderContext.class), |
| 162 | + mock(GraphDimensions.class), |
| 163 | + progressTrackerMock |
| 164 | + )); |
167 | 165 |
|
168 | | - var mockRegistryFactory = mock(TaskRegistryFactory.class); |
169 | | - when(mockRegistryFactory.newInstance(any())).thenReturn(mockRegistry); |
170 | | - var mockgraphLoaderContext = mock(GraphLoaderContext.class); |
171 | | - when(mockgraphLoaderContext.taskRegistryFactory()).thenReturn(mockRegistryFactory); |
172 | | - when(mockgraphLoaderContext.log()).thenReturn(Log.noOpLog()); |
| 166 | + doThrow(new IllegalArgumentException("Intentionally failing validation")).when(nativeFactorySpy).validate(); |
173 | 167 |
|
174 | | - var labelSet = new LongHashSet(); |
175 | | - labelSet.add(GraphDimensions.NO_SUCH_LABEL); |
176 | | - var graphDimensions =ImmutableGraphDimensions.builder() |
177 | | - .nodeCount(100) |
178 | | - .nodeLabelTokens(labelSet) |
179 | | - .build(); |
180 | | - |
181 | | - var config = mock(GraphProjectFromStoreConfig.class); |
182 | | - when(config.relationshipProjections()).thenReturn(new RelationshipProjections() { |
183 | | - @Override |
184 | | - public Map<RelationshipType, RelationshipProjection> projections() { |
185 | | - return Map.of(); |
186 | | - } |
187 | | - }); |
188 | | - |
189 | | - when(config.nodeProjections()).thenReturn(new NodeProjections() { |
190 | | - @Override |
191 | | - public Map<NodeLabel, NodeProjection> projections() { |
192 | | - return Map.of( |
193 | | - NodeLabel.of("foo"), new NodeProjection() { |
194 | | - @Override |
195 | | - public String label() { |
196 | | - return "bar"; |
197 | | - } |
198 | | - } |
199 | | - ); |
200 | | - } |
201 | | - }); |
202 | | - |
203 | | - when(config.logProgress()).thenReturn(true); |
204 | | - when(config.jobId()).thenReturn(new JobId("foo")); |
205 | | - when(config.readConcurrency()).thenReturn(new Concurrency(1)); |
206 | | - |
207 | | - var nativeFactory = NativeFactory.nativeFactory( |
208 | | - config, |
209 | | - mockgraphLoaderContext, |
210 | | - Optional.of(graphDimensions) |
211 | | - ); |
212 | | - |
213 | | - MutableBoolean failed = new MutableBoolean(false); |
214 | | - try { |
215 | | - nativeFactory.build(); |
216 | | - } catch (Exception ignored){ |
217 | | - failed.set(true); |
218 | | - } |
219 | | - |
220 | | - assertThat(failed.get()).isTrue(); |
221 | | - verify(mockRegistry,times(1)).registerTask(any()); |
222 | | - verify(mockRegistry,times(1)).markCompleted(); |
| 168 | + assertThatIllegalArgumentException() |
| 169 | + .isThrownBy(nativeFactorySpy::build) |
| 170 | + .withMessageContaining("Intentionally failing validation"); |
223 | 171 |
|
| 172 | + verify(progressTrackerMock, times(1)).beginSubTask(); |
| 173 | + verify(progressTrackerMock, times(1)).endSubTaskWithFailure(); |
| 174 | + verifyNoMoreInteractions(progressTrackerMock); |
224 | 175 | } |
225 | 176 | } |
0 commit comments