@@ -170,7 +173,7 @@ testImplementation("io.flowinquiry.testcontainers:ollama:0.9.1") // Ollama s
io.flowinquiry.testcontainers
mysql
- 0.9.1
+
test
@@ -178,7 +181,7 @@ testImplementation("io.flowinquiry.testcontainers:ollama:0.9.1") // Ollama s
io.flowinquiry.testcontainers
ollama
- 0.9.1
+
test
```
diff --git a/examples/springboot-ollama/build.gradle.kts b/examples/springboot-ollama/build.gradle.kts
index 0577398..e879d43 100644
--- a/examples/springboot-ollama/build.gradle.kts
+++ b/examples/springboot-ollama/build.gradle.kts
@@ -24,6 +24,7 @@ dependencies {
implementation(libs.bundles.spring.ai)
testImplementation(platform(libs.junit.bom))
testImplementation(libs.junit.jupiter)
+ testImplementation(libs.junit.jupiter.params)
testImplementation(libs.junit.platform.launcher)
testImplementation(libs.spring.boot.starter.test)
}
diff --git a/examples/springboot-ollama/src/test/java/io/flowinquiry/testcontainers/examples/ollama/OllamaDemoAppTest.java b/examples/springboot-ollama/src/test/java/io/flowinquiry/testcontainers/examples/ollama/OllamaDemoAppTest.java
index e6445bd..7ff4221 100644
--- a/examples/springboot-ollama/src/test/java/io/flowinquiry/testcontainers/examples/ollama/OllamaDemoAppTest.java
+++ b/examples/springboot-ollama/src/test/java/io/flowinquiry/testcontainers/examples/ollama/OllamaDemoAppTest.java
@@ -9,6 +9,8 @@
import io.flowinquiry.testcontainers.ai.OllamaOptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
import org.slf4j.Logger;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
@@ -23,7 +25,7 @@
@EnableOllamaContainer(
dockerImage = "ollama/ollama",
version = "0.9.0",
- model = "smollm2:135m",
+ model = "llama3:latest",
options = @OllamaOptions(temperature = "0.7", topP = "0.5"))
@ActiveProfiles("test")
public class OllamaDemoAppTest {
@@ -53,10 +55,13 @@ public void testHealthEndpoint() {
assertTrue(response.contains("Ollama Chat Controller is up and running"));
}
- @Test
- public void testChatClient() {
+ @ParameterizedTest
+ @CsvSource({
+ "What is the result of 1+2? Give the value only, 3",
+ "How many letter 'r' in the word 'Hello'? Give the value only, 0"
+ })
+ public void testChatClient(String prompt, String expectedResult) {
log.info("Testing chat client directly");
- String prompt = "What is Spring AI?";
log.info("Sending prompt: {}", prompt);
String content = chatClient.prompt().user(prompt).call().content();
@@ -64,5 +69,7 @@ public void testChatClient() {
log.info("Received response: {}", content);
assertNotNull(content);
assertFalse(content.isEmpty());
+ assertTrue(
+ content.contains(expectedResult), "Response should contain '" + expectedResult + "'");
}
}
diff --git a/gradle.properties b/gradle.properties
index 6328e07..5e7100a 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -2,5 +2,5 @@
# https://docs.gradle.org/current/userguide/build_environment.html#sec:gradle_configuration_properties
org.gradle.configuration-cache=true
-version=0.9.1
+version=0.9.2
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 04c9fc9..529c0b9 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -18,6 +18,7 @@ spring-ai = "1.0.0"
junit-bom = { group = "org.junit", name = "junit-bom", version.ref = "junit-jupiter" }
junit-jupiter = { group = "org.junit.jupiter", name = "junit-jupiter" }
junit-jupiter-api = { group = "org.junit.jupiter", name = "junit-jupiter-api" }
+junit-jupiter-params = { group = "org.junit.jupiter", name = "junit-jupiter-params" }
junit-jupiter-engine = { group = "org.junit.jupiter", name = "junit-jupiter-engine" }
junit-platform-launcher = { group = "org.junit.platform", name = "junit-platform-launcher" }
spring-bom = { group = "org.springframework", name = "spring-framework-bom", version.ref = "spring" }
diff --git a/modules/ollama/src/main/java/io/flowinquiry/testcontainers/ai/OllamaContainerProvider.java b/modules/ollama/src/main/java/io/flowinquiry/testcontainers/ai/OllamaContainerProvider.java
index 3452ba9..65b999c 100644
--- a/modules/ollama/src/main/java/io/flowinquiry/testcontainers/ai/OllamaContainerProvider.java
+++ b/modules/ollama/src/main/java/io/flowinquiry/testcontainers/ai/OllamaContainerProvider.java
@@ -1,8 +1,10 @@
package io.flowinquiry.testcontainers.ai;
import static io.flowinquiry.testcontainers.ContainerType.OLLAMA;
+import static org.testcontainers.containers.BindMode.READ_WRITE;
import io.flowinquiry.testcontainers.ContainerType;
+import io.flowinquiry.testcontainers.Slf4jOutputConsumer;
import io.flowinquiry.testcontainers.SpringAwareContainerProvider;
import java.io.IOException;
import java.util.Properties;
@@ -10,6 +12,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.PropertiesPropertySource;
+import org.testcontainers.containers.Container;
import org.testcontainers.ollama.OllamaContainer;
/**
@@ -46,7 +49,8 @@ public ContainerType getContainerType() {
*/
@Override
protected OllamaContainer createContainer() {
- return new OllamaContainer(dockerImage + ":" + version);
+ return new OllamaContainer(dockerImage + ":" + version)
+ .withFileSystemBind("/tmp/ollama-cache", "/root/.ollama", READ_WRITE);
}
/**
@@ -60,14 +64,31 @@ protected OllamaContainer createContainer() {
@Override
public void start() {
super.start();
+
+ Logger containerLog = LoggerFactory.getLogger(OllamaContainerProvider.class);
+ container.followOutput(new Slf4jOutputConsumer(containerLog));
+
try {
log.info("Starting pull model {}", enableContainerAnnotation.model());
- container.execInContainer("ollama", "pull", enableContainerAnnotation.model());
+ pullModelIfMissing(enableContainerAnnotation.model());
} catch (IOException | InterruptedException e) {
throw new RuntimeException(e);
}
}
+ private void pullModelIfMissing(String modelName) throws IOException, InterruptedException {
+ Container.ExecResult result = container.execInContainer("ollama", "list");
+ String output = result.getStdout();
+
+ if (!output.contains(modelName)) {
+ log.info("Model '{}' not found in ollama cache. Pulling...", modelName);
+ Container.ExecResult pullResult = container.execInContainer("ollama", "pull", modelName);
+ log.info("Pull complete: {}", pullResult.getStdout());
+ } else {
+ log.info("Model '{}' already exists. Skipping pull.", modelName);
+ }
+ }
+
/**
* Applies Ollama-specific configuration to the Spring environment.
*
diff --git a/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/Slf4jOutputConsumer.java b/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/Slf4jOutputConsumer.java
new file mode 100644
index 0000000..ee6b501
--- /dev/null
+++ b/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/Slf4jOutputConsumer.java
@@ -0,0 +1,102 @@
+package io.flowinquiry.testcontainers;
+
+import org.slf4j.Logger;
+import org.slf4j.event.Level;
+import org.testcontainers.containers.output.BaseConsumer;
+import org.testcontainers.containers.output.OutputFrame;
+
+/**
+ * An implementation of {@link BaseConsumer} that routes container output to SLF4J logging. This
+ * consumer allows for different log levels to be used for STDOUT and STDERR streams.
+ *
+ * Usage example:
+ *
+ *
+ * Logger logger = LoggerFactory.getLogger(MyClass.class);
+ * GenericContainer container = new GenericContainer("some-image")
+ * .withLogConsumer(new Slf4jOutputConsumer(logger));
+ *
+ */
+public class Slf4jOutputConsumer extends BaseConsumer {
+
+ /** The SLF4J logger to which container output will be written. */
+ private final Logger logger;
+
+ /** The log level to use for STDOUT output from the container. */
+ private final Level stdoutLogLevel;
+
+ /** The log level to use for STDERR output from the container. */
+ private final Level stderrLogLevel;
+
+ /**
+ * Creates a new Slf4jOutputConsumer with default log levels. STDOUT messages will be logged at
+ * DEBUG level, and STDERR messages at ERROR level.
+ *
+ * @param logger the SLF4J logger to which container output will be written
+ */
+ public Slf4jOutputConsumer(Logger logger) {
+ this(logger, Level.DEBUG, Level.ERROR);
+ }
+
+ /**
+ * Creates a new Slf4jOutputConsumer with custom log levels for STDOUT and STDERR.
+ *
+ * @param logger the SLF4J logger to which container output will be written
+ * @param stdoutLogLevel the log level to use for STDOUT output
+ * @param stderrLogLevel the log level to use for STDERR output
+ */
+ public Slf4jOutputConsumer(Logger logger, Level stdoutLogLevel, Level stderrLogLevel) {
+ this.logger = logger;
+ this.stdoutLogLevel = stdoutLogLevel;
+ this.stderrLogLevel = stderrLogLevel;
+ }
+
+ /**
+ * Processes an output frame from a container and logs it using the configured SLF4J logger.
+ *
+ * The method:
+ *
+ *
+ * - Skips null or empty frames
+ *
- Determines the appropriate log level based on the frame type (STDOUT or STDERR)
+ *
- Logs the message with the frame type as a prefix
+ *
+ *
+ * @param outputFrame the output frame to process
+ */
+ @Override
+ public void accept(OutputFrame outputFrame) {
+ if (outputFrame == null || outputFrame.getBytes() == null) return;
+
+ String message = outputFrame.getUtf8String().trim();
+ if (message.isEmpty()) return;
+
+ Level levelToUse =
+ switch (outputFrame.getType()) {
+ case STDOUT -> stdoutLogLevel;
+ case STDERR -> stderrLogLevel;
+ case END -> null;
+ };
+
+ if (levelToUse != null) {
+ logAtLevel(levelToUse, "[{}] {}", outputFrame.getType(), message);
+ }
+ }
+
+ /**
+ * Logs a message at the specified SLF4J level.
+ *
+ * @param level the SLF4J level at which to log the message
+ * @param format the message format string
+ * @param args the arguments to be formatted into the message string
+ */
+ private void logAtLevel(Level level, String format, Object... args) {
+ switch (level) {
+ case TRACE -> logger.trace(format, args);
+ case DEBUG -> logger.debug(format, args);
+ case INFO -> logger.info(format, args);
+ case WARN -> logger.warn(format, args);
+ case ERROR -> logger.error(format, args);
+ }
+ }
+}
diff --git a/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/SpringAwareContainerProvider.java b/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/SpringAwareContainerProvider.java
index ed49c61..b212b44 100644
--- a/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/SpringAwareContainerProvider.java
+++ b/spring-testcontainers/src/main/java/io/flowinquiry/testcontainers/SpringAwareContainerProvider.java
@@ -25,6 +25,9 @@ public abstract class SpringAwareContainerProvider<
private static final Logger log = LoggerFactory.getLogger(SpringAwareContainerProvider.class);
+ private static boolean reuseContainerSupport =
+ TestcontainersConfiguration.getInstance().environmentSupportsReuse();
+
/** The version of the container image to use. */
protected String version;
@@ -43,12 +46,17 @@ public final void initContainerInstance(A enableContainerAnnotation) {
enableContainerAnnotation.annotationType().getMethod("dockerImage");
Method versionMethod = enableContainerAnnotation.annotationType().getMethod("version");
- log.info("Initializing JDBC container with image {}:{}", dockerImage, version);
+ log.info("Initializing the container with image {}:{}", dockerImage, version);
this.version = (String) versionMethod.invoke(enableContainerAnnotation);
this.dockerImage = (String) dockerImageMethod.invoke(enableContainerAnnotation);
container = createContainer();
- container.withReuse(TestcontainersConfiguration.getInstance().environmentSupportsReuse());
+ container.withReuse(reuseContainerSupport);
+ log.info(
+ "Created the container with image {}:{} with reuse {}",
+ dockerImage,
+ version,
+ reuseContainerSupport);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new IllegalArgumentException(
"Annotation "
@@ -74,7 +82,7 @@ public void start() {
/** Stops the container. This method is called when the Spring context is closed. */
@Override
public void stop() {
- if (!TestcontainersConfiguration.getInstance().environmentSupportsReuse()) {
+ if (!reuseContainerSupport) {
container.stop();
}
}