feat(ocr): add Preprocessing NDJSON event to Java stream pipeline
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -82,6 +82,10 @@ public class OcrAsyncRunner {
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
}
|
||||
case OcrStreamEvent.Preprocessing preprocessing -> {
|
||||
updateProgress(job, "PREPROCESSING_PAGE:" + preprocessing.pageNumber()
|
||||
+ ":" + totalPages.get());
|
||||
}
|
||||
case OcrStreamEvent.Page page -> {
|
||||
for (OcrBlockResult block : page.blocks()) {
|
||||
createSingleBlock(documentId, block, userId,
|
||||
|
||||
@@ -6,6 +6,8 @@ public sealed interface OcrStreamEvent {
|
||||
|
||||
record Start(int totalPages) implements OcrStreamEvent {}
|
||||
|
||||
record Preprocessing(int pageNumber) implements OcrStreamEvent {}
|
||||
|
||||
record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}
|
||||
|
||||
record Error(int pageNumber, String message) implements OcrStreamEvent {}
|
||||
|
||||
@@ -232,6 +232,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
switch (type) {
|
||||
case "start" -> handler.accept(
|
||||
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
|
||||
case "preprocessing" -> handler.accept(
|
||||
new OcrStreamEvent.Preprocessing(node.path("pageNumber").asInt()));
|
||||
case "page" -> {
|
||||
int pageNumber = node.path("pageNumber").asInt();
|
||||
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(
|
||||
|
||||
@@ -230,6 +230,42 @@ class OcrAsyncRunnerTest {
|
||||
assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void runSingleDocument_updatesProgressOnPreprocessingEvent() {
|
||||
UUID jobId = UUID.randomUUID();
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
|
||||
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
|
||||
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
|
||||
Document doc = Document.builder().id(docId).filePath("test.pdf")
|
||||
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
|
||||
|
||||
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
|
||||
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
|
||||
.thenReturn(Optional.of(jobDoc));
|
||||
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
List<String> progressMessages = new ArrayList<>();
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(5));
|
||||
handler.accept(new OcrStreamEvent.Preprocessing(1));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
// Must have totalPages=5 from the preceding Start event, not 0
|
||||
assertThat(progressMessages.get(0)).isEqualTo("PREPROCESSING_PAGE:1:5");
|
||||
}
|
||||
|
||||
@Test
|
||||
void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
|
||||
UUID jobId = UUID.randomUUID();
|
||||
|
||||
@@ -37,11 +37,19 @@ class OcrStreamEventTest {
|
||||
assertThat(done.skippedPages()).isEqualTo(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void preprocessingRecordHoldsPageNumber() {
|
||||
var preprocessing = new OcrStreamEvent.Preprocessing(4);
|
||||
assertThat(preprocessing.pageNumber()).isEqualTo(4);
|
||||
assertThat(preprocessing).isInstanceOf(OcrStreamEvent.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void patternMatchingWorksOnSealedInterface() {
|
||||
OcrStreamEvent event = new OcrStreamEvent.Start(3);
|
||||
String result = switch (event) {
|
||||
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
|
||||
case OcrStreamEvent.Preprocessing pre -> "preprocessing:" + pre.pageNumber();
|
||||
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
|
||||
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
|
||||
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();
|
||||
|
||||
@@ -83,6 +83,25 @@ class RestClientOcrClientStreamTest {
|
||||
assertThat(events).hasSize(3);
|
||||
}
|
||||
|
||||
@Test
|
||||
void parseNdjsonStream_dispatchesPreprocessingEvent() {
|
||||
String ndjson = """
|
||||
{"type":"start","totalPages":3}
|
||||
{"type":"preprocessing","pageNumber":1}
|
||||
{"type":"page","pageNumber":1,"blocks":[]}
|
||||
{"type":"done","totalBlocks":0,"skippedPages":0}
|
||||
""";
|
||||
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
RestClientOcrClient.parseNdjsonStream(stream, events::add);
|
||||
|
||||
assertThat(events).hasSize(4);
|
||||
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Preprocessing.class);
|
||||
var preprocessing = (OcrStreamEvent.Preprocessing) events.get(1);
|
||||
assertThat(preprocessing.pageNumber()).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void parseNdjsonStream_ignoresUnknownEventTypes() {
|
||||
String ndjson = """
|
||||
|
||||
Reference in New Issue
Block a user