Skip to content

Commit ada7e11

Browse files
committed
Add support for Spring Boot's DataSourceScriptDatabaseInitializer
Intercept DataSourceScriptDatabaseInitializer via BeanPostProcessor to integrate with Zonky's optimized template-based database initialization. Uses ThreadLocal DataSource pattern for parallel prefetching support. - Add DataSourceScriptDatabaseExtension (BeanPostProcessor) - Add DataSourceScriptDatabasePreparer with field-by-field equals/hashCode - Add ThreadLocalDataSource for thread-safe DataSource delegation - Extract AbstractDelegatingDataSource from AbstractEmbeddedDatabase - Add Spring Boot 4 stubs for compile-time compatibility - Add integration tests for both SB3 and SB4 - Add zonky.test.database.sql-init.enabled property (default: true)
1 parent d684bfb commit ada7e11

11 files changed

Lines changed: 530 additions & 67 deletions

File tree

embedded-database-spring-test/src/main/java/io/zonky/test/db/config/EmbeddedDatabaseAutoConfiguration.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import io.zonky.test.db.flyway.FlywayDatabaseExtension;
2020
import io.zonky.test.db.flyway.FlywayPropertiesPostProcessor;
21+
import io.zonky.test.db.init.DataSourceScriptDatabaseExtension;
2122
import io.zonky.test.db.init.EmbeddedDatabaseInitializer;
2223
import io.zonky.test.db.init.ScriptDatabasePreparer;
2324
import io.zonky.test.db.liquibase.LiquibaseDatabaseExtension;
@@ -289,6 +290,14 @@ public BeanPostProcessor liquibasePropertiesPostProcessor() {
289290
return new LiquibasePropertiesPostProcessor();
290291
}
291292

293+
@Bean
294+
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
295+
@ConditionalOnClass(name = "org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer")
296+
@ConditionalOnMissingBean(name = "dataSourceScriptDatabaseExtension")
297+
public DataSourceScriptDatabaseExtension dataSourceScriptDatabaseExtension(Environment environment) {
298+
return new DataSourceScriptDatabaseExtension(environment);
299+
}
300+
292301
@Bean
293302
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
294303
@ConditionalOnMissingBean(name = "embeddedDatabaseInitializer")
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.init;
18+
19+
import io.zonky.test.db.context.DatabaseContext;
20+
import io.zonky.test.db.util.AopProxyUtils;
21+
import io.zonky.test.db.util.ReflectionUtils;
22+
import org.springframework.beans.factory.config.BeanPostProcessor;
23+
import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer;
24+
import org.springframework.core.Ordered;
25+
import org.springframework.core.env.Environment;
26+
27+
import javax.sql.DataSource;
28+
29+
public class DataSourceScriptDatabaseExtension implements BeanPostProcessor, Ordered {
30+
31+
private final boolean enabled;
32+
33+
public DataSourceScriptDatabaseExtension(Environment environment) {
34+
this.enabled = environment.getProperty("zonky.test.database.sql-init.enabled", boolean.class, true);
35+
}
36+
37+
@Override
38+
public int getOrder() {
39+
return Ordered.HIGHEST_PRECEDENCE + 1;
40+
}
41+
42+
@Override
43+
public Object postProcessBeforeInitialization(Object bean, String beanName) {
44+
if (enabled && bean instanceof DataSourceScriptDatabaseInitializer) {
45+
DataSourceScriptDatabaseInitializer initializer = (DataSourceScriptDatabaseInitializer) bean;
46+
DataSource dataSource = ReflectionUtils.getField(initializer, "dataSource");
47+
DatabaseContext context = AopProxyUtils.getDatabaseContext(dataSource);
48+
49+
if (context != null) {
50+
context.apply(new DataSourceScriptDatabasePreparer(initializer));
51+
return new SuppressedInitializerWrapper(initializer);
52+
}
53+
}
54+
55+
return bean;
56+
}
57+
58+
@Override
59+
public Object postProcessAfterInitialization(Object bean, String beanName) {
60+
return bean;
61+
}
62+
63+
public static class SuppressedInitializerWrapper {
64+
65+
private final DataSourceScriptDatabaseInitializer initializer;
66+
67+
public SuppressedInitializerWrapper(DataSourceScriptDatabaseInitializer initializer) {
68+
this.initializer = initializer;
69+
}
70+
71+
@Override
72+
public String toString() {
73+
return "SuppressedInitializerWrapper{initializer=" + initializer.getClass().getName() + "}";
74+
}
75+
}
76+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.init;
18+
19+
import com.cedarsoftware.util.DeepEquals;
20+
import io.zonky.test.db.preparer.DatabasePreparer;
21+
import io.zonky.test.db.util.ReflectionUtils;
22+
import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer;
23+
24+
import org.springframework.util.ReflectionUtils.FieldFilter;
25+
26+
import javax.sql.DataSource;
27+
import java.lang.reflect.Modifier;
28+
import java.util.Arrays;
29+
import java.util.HashSet;
30+
import java.util.Set;
31+
import java.util.concurrent.atomic.AtomicBoolean;
32+
import java.util.concurrent.atomic.AtomicInteger;
33+
34+
import static org.springframework.util.ReflectionUtils.makeAccessible;
35+
36+
public class DataSourceScriptDatabasePreparer implements DatabasePreparer {
37+
38+
private static final Set<String> EXCLUDED_FIELDS = new HashSet<>(Arrays.asList("dataSource", "resourceLoader"));
39+
40+
private static final FieldFilter FIELD_FILTER =
41+
field -> !Modifier.isStatic(field.getModifiers()) && !EXCLUDED_FIELDS.contains(field.getName());
42+
43+
private final DataSourceScriptDatabaseInitializer initializer;
44+
private final ThreadLocalDataSource threadLocalDataSource;
45+
46+
public DataSourceScriptDatabasePreparer(DataSourceScriptDatabaseInitializer initializer) {
47+
this.initializer = initializer;
48+
this.threadLocalDataSource = new ThreadLocalDataSource();
49+
ReflectionUtils.setField(initializer, "dataSource", threadLocalDataSource);
50+
}
51+
52+
@Override
53+
public long estimatedDuration() {
54+
return 10;
55+
}
56+
57+
@Override
58+
public void prepare(DataSource dataSource) {
59+
threadLocalDataSource.set(dataSource);
60+
try {
61+
initializer.initializeDatabase();
62+
} finally {
63+
threadLocalDataSource.clear();
64+
}
65+
}
66+
67+
@Override
68+
public boolean equals(Object o) {
69+
if (this == o) return true;
70+
if (o == null || getClass() != o.getClass()) return false;
71+
DataSourceScriptDatabasePreparer that = (DataSourceScriptDatabasePreparer) o;
72+
if (initializer.getClass() != that.initializer.getClass()) return false;
73+
AtomicBoolean equal = new AtomicBoolean(true);
74+
org.springframework.util.ReflectionUtils.doWithFields(initializer.getClass(),
75+
field -> {
76+
if (!equal.get()) return;
77+
makeAccessible(field);
78+
if (!DeepEquals.deepEquals(field.get(initializer), field.get(that.initializer))) {
79+
equal.set(false);
80+
}
81+
},
82+
FIELD_FILTER);
83+
return equal.get();
84+
}
85+
86+
@Override
87+
public int hashCode() {
88+
AtomicInteger result = new AtomicInteger(initializer.getClass().hashCode());
89+
org.springframework.util.ReflectionUtils.doWithFields(initializer.getClass(),
90+
field -> {
91+
makeAccessible(field);
92+
result.set(31 * result.get() + DeepEquals.deepHashCode(field.get(initializer)));
93+
},
94+
FIELD_FILTER);
95+
return result.get();
96+
}
97+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.init;
18+
19+
import io.zonky.test.db.provider.support.AbstractDelegatingDataSource;
20+
21+
import javax.sql.DataSource;
22+
23+
class ThreadLocalDataSource extends AbstractDelegatingDataSource {
24+
25+
private final ThreadLocal<DataSource> current = new ThreadLocal<>();
26+
27+
void set(DataSource dataSource) {
28+
current.set(dataSource);
29+
}
30+
31+
void clear() {
32+
current.remove();
33+
}
34+
35+
@Override
36+
protected DataSource getDataSource() {
37+
return current.get();
38+
}
39+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.provider.support;
18+
19+
import javax.sql.DataSource;
20+
import java.io.PrintWriter;
21+
import java.sql.Connection;
22+
import java.sql.SQLException;
23+
import java.sql.SQLFeatureNotSupportedException;
24+
import java.util.logging.Logger;
25+
26+
public abstract class AbstractDelegatingDataSource implements DataSource {
27+
28+
protected abstract DataSource getDataSource();
29+
30+
@Override
31+
public Connection getConnection() throws SQLException {
32+
return getDataSource().getConnection();
33+
}
34+
35+
@Override
36+
public Connection getConnection(String username, String password) throws SQLException {
37+
return getDataSource().getConnection(username, password);
38+
}
39+
40+
@Override
41+
public PrintWriter getLogWriter() throws SQLException {
42+
return getDataSource().getLogWriter();
43+
}
44+
45+
@Override
46+
public void setLogWriter(PrintWriter out) throws SQLException {
47+
getDataSource().setLogWriter(out);
48+
}
49+
50+
@Override
51+
public int getLoginTimeout() throws SQLException {
52+
return getDataSource().getLoginTimeout();
53+
}
54+
55+
@Override
56+
public void setLoginTimeout(int seconds) throws SQLException {
57+
getDataSource().setLoginTimeout(seconds);
58+
}
59+
60+
@Override
61+
public <T> T unwrap(Class<T> iface) throws SQLException {
62+
if (iface.isAssignableFrom(getClass())) {
63+
return iface.cast(this);
64+
}
65+
if (iface.isAssignableFrom(getDataSource().getClass())) {
66+
return iface.cast(getDataSource());
67+
}
68+
return getDataSource().unwrap(iface);
69+
}
70+
71+
@Override
72+
public boolean isWrapperFor(Class<?> iface) throws SQLException {
73+
if (iface.isAssignableFrom(getClass())) {
74+
return true;
75+
}
76+
if (iface.isAssignableFrom(getDataSource().getClass())) {
77+
return true;
78+
}
79+
return getDataSource().isWrapperFor(iface);
80+
}
81+
82+
@Override
83+
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
84+
return getDataSource().getParentLogger();
85+
}
86+
}

0 commit comments

Comments
 (0)