diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java index 2757ee19..2bda9892 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java @@ -173,7 +173,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { TrainingResultJson result = addTrainingAuth( trainingRestClient.post() - .uri("/train") + .uri("/train-sender") .contentType(MediaType.MULTIPART_FORM_DATA)) .body(body) .retrieve() diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java new file mode 100644 index 00000000..17f084fa --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java @@ -0,0 +1,41 @@ +package org.raddatz.familienarchiv.service; + +import com.sun.net.httpserver.HttpServer; +import org.junit.jupiter.api.Test; + +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import static org.assertj.core.api.Assertions.assertThat; + +class RestClientOcrClientTrainingTest { + + @Test + void trainSenderModel_postsToTrainSenderEndpoint() throws Exception { + List capturedPaths = new CopyOnWriteArrayList<>(); + HttpServer server = HttpServer.create(new InetSocketAddress(0), 0); + server.createContext("/", exchange -> { + capturedPaths.add(exchange.getRequestURI().getPath()); + byte[] response = "{\"loss\":0.1,\"accuracy\":0.95,\"cer\":0.05,\"epochs\":3}" + .getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, response.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(response); + } + }); + server.start(); + + try { + int port = server.getAddress().getPort(); + RestClientOcrClient client = new RestClientOcrClient("http://localhost:" + port, ""); + client.trainSenderModel(new byte[]{1, 2, 3}, "/app/models/sender_test.mlmodel"); + assertThat(capturedPaths).contains("/train-sender"); + } finally { + server.stop(0); + } + } +}