diff --git a/api/src/main/java/io/grpc/InternalServiceProviders.java b/api/src/main/java/io/grpc/InternalServiceProviders.java index c72e01db67a..debc786a82a 100644 --- a/api/src/main/java/io/grpc/InternalServiceProviders.java +++ b/api/src/main/java/io/grpc/InternalServiceProviders.java @@ -17,7 +17,9 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; +import java.util.Iterator; import java.util.List; +import java.util.ServiceLoader; @Internal public final class InternalServiceProviders { @@ -27,12 +29,17 @@ private InternalServiceProviders() { /** * Accessor for method. */ - public static T load( + @Deprecated + public static List loadAll( Class klass, - Iterable> hardcoded, + Iterable> hardCodedClasses, ClassLoader classLoader, PriorityAccessor priorityAccessor) { - return ServiceProviders.load(klass, hardcoded, classLoader, priorityAccessor); + return loadAll( + klass, + ServiceLoader.load(klass, classLoader).iterator(), + () -> hardCodedClasses, + priorityAccessor); } /** @@ -40,10 +47,10 @@ public static T load( */ public static List loadAll( Class klass, - Iterable> hardCodedClasses, - ClassLoader classLoader, + Iterator serviceLoader, + Supplier>> hardCodedClasses, PriorityAccessor priorityAccessor) { - return ServiceProviders.loadAll(klass, hardCodedClasses, classLoader, priorityAccessor); + return ServiceProviders.loadAll(klass, serviceLoader, hardCodedClasses::get, priorityAccessor); } /** @@ -71,4 +78,8 @@ public static boolean isAndroid(ClassLoader cl) { } public interface PriorityAccessor extends ServiceProviders.PriorityAccessor {} + + public interface Supplier { + T get(); + } } diff --git a/api/src/main/java/io/grpc/LoadBalancerRegistry.java b/api/src/main/java/io/grpc/LoadBalancerRegistry.java index f6b69f978b8..a8fbc102f5f 100644 --- a/api/src/main/java/io/grpc/LoadBalancerRegistry.java +++ b/api/src/main/java/io/grpc/LoadBalancerRegistry.java @@ -26,6 +26,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -42,7 +43,6 @@ public final class LoadBalancerRegistry { private static final Logger logger = Logger.getLogger(LoadBalancerRegistry.class.getName()); private static LoadBalancerRegistry instance; - private static final Iterable> HARDCODED_CLASSES = getHardCodedClasses(); private final LinkedHashSet allProviders = new LinkedHashSet<>(); @@ -101,8 +101,10 @@ public static synchronized LoadBalancerRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( LoadBalancerProvider.class, - HARDCODED_CLASSES, - LoadBalancerProvider.class.getClassLoader(), + ServiceLoader + .load(LoadBalancerProvider.class, LoadBalancerProvider.class.getClassLoader()) + .iterator(), + LoadBalancerRegistry::getHardCodedClasses, new LoadBalancerPriorityAccessor()); instance = new LoadBalancerRegistry(); for (LoadBalancerProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index aed5eca9abf..9b782dcf48e 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -29,6 +29,7 @@ import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.concurrent.ThreadSafe; @@ -100,8 +101,10 @@ public static synchronized ManagedChannelRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( ManagedChannelProvider.class, - getHardCodedClasses(), - ManagedChannelProvider.class.getClassLoader(), + ServiceLoader + .load(ManagedChannelProvider.class, ManagedChannelProvider.class.getClassLoader()) + .iterator(), + ManagedChannelRegistry::getHardCodedClasses, new ManagedChannelPriorityAccessor()); instance = new ManagedChannelRegistry(); for (ManagedChannelProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/NameResolverRegistry.java b/api/src/main/java/io/grpc/NameResolverRegistry.java index 26eb5552b9b..c7e5cf30714 100644 --- a/api/src/main/java/io/grpc/NameResolverRegistry.java +++ b/api/src/main/java/io/grpc/NameResolverRegistry.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -125,8 +126,10 @@ public static synchronized NameResolverRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( NameResolverProvider.class, - getHardCodedClasses(), - NameResolverProvider.class.getClassLoader(), + ServiceLoader + .load(NameResolverProvider.class, NameResolverProvider.class.getClassLoader()) + .iterator(), + NameResolverRegistry::getHardCodedClasses, new NameResolverPriorityAccessor()); if (providerList.isEmpty()) { logger.warning("No NameResolverProviders found via ServiceLoader, including for DNS. This " diff --git a/api/src/main/java/io/grpc/ServerRegistry.java b/api/src/main/java/io/grpc/ServerRegistry.java index 5b9c8c558e7..1ec7030b82b 100644 --- a/api/src/main/java/io/grpc/ServerRegistry.java +++ b/api/src/main/java/io/grpc/ServerRegistry.java @@ -24,6 +24,7 @@ import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.concurrent.ThreadSafe; @@ -93,8 +94,9 @@ public static synchronized ServerRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( ServerProvider.class, - getHardCodedClasses(), - ServerProvider.class.getClassLoader(), + ServiceLoader.load(ServerProvider.class, ServerProvider.class.getClassLoader()) + .iterator(), + ServerRegistry::getHardCodedClasses, new ServerPriorityAccessor()); instance = new ServerRegistry(); for (ServerProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/ServiceProviders.java b/api/src/main/java/io/grpc/ServiceProviders.java index ac4b27d8783..861688be9fb 100644 --- a/api/src/main/java/io/grpc/ServiceProviders.java +++ b/api/src/main/java/io/grpc/ServiceProviders.java @@ -17,10 +17,13 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Supplier; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.ServiceConfigurationError; import java.util.ServiceLoader; @@ -29,42 +32,44 @@ private ServiceProviders() { // do not instantiate } - /** - * If this is not Android, returns the highest priority implementation of the class via - * {@link ServiceLoader}. - * If this is Android, returns an instance of the highest priority class in {@code hardcoded}. - */ - public static T load( - Class klass, - Iterable> hardcoded, - ClassLoader cl, - PriorityAccessor priorityAccessor) { - List candidates = loadAll(klass, hardcoded, cl, priorityAccessor); - if (candidates.isEmpty()) { - return null; - } - return candidates.get(0); - } - /** * If this is not Android, returns all available implementations discovered via * {@link ServiceLoader}. * If this is Android, returns all available implementations in {@code hardcoded}. * The list is sorted in descending priority order. + * + *

{@code serviceLoader} should be created with {@code ServiceLoader.load(MyClass.class, + * MyClass.class.getClassLoader()).iterator()} in order to be detected by R8 so that R8 full mode + * will keep the constructors for the provider classes. */ public static List loadAll( Class klass, - Iterable> hardcoded, - ClassLoader cl, + Iterator serviceLoader, + Supplier>> hardcoded, final PriorityAccessor priorityAccessor) { - Iterable candidates; - if (isAndroid(cl)) { - candidates = getCandidatesViaHardCoded(klass, hardcoded); + Iterator candidates; + if (serviceLoader instanceof ListIterator) { + // A rewriting tool has replaced the ServiceLoader with a List of some sort (R8 uses + // ArrayList, AppReduce uses singletonList). We prefer to use such iterators on Android as + // they won't need reflection like the hard-coded list does. In addition, the provider + // instances will have already been created, so it seems we should use them. + // + // R8: https://r8.googlesource.com/r8/+/490bc53d9310d4cc2a5084c05df4aadaec8c885d/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java + // AppReduce: service_loader_pass.cc + candidates = serviceLoader; + } else if (isAndroid(klass.getClassLoader())) { + // Avoid getResource() on Android, which must read from a zip which uses a lot of memory + candidates = getCandidatesViaHardCoded(klass, hardcoded.get()).iterator(); + } else if (!serviceLoader.hasNext()) { + // Attempt to load using the context class loader and ServiceLoader. + // This allows frameworks like http://aries.apache.org/modules/spi-fly.html to plug in. + candidates = ServiceLoader.load(klass).iterator(); } else { - candidates = getCandidatesViaServiceLoader(klass, cl); + candidates = serviceLoader; } List list = new ArrayList<>(); - for (T current: candidates) { + while (candidates.hasNext()) { + T current = candidates.next(); if (!priorityAccessor.isAvailable(current)) { continue; } @@ -101,15 +106,14 @@ static boolean isAndroid(ClassLoader cl) { } /** - * Loads service providers for the {@code klass} service using {@link ServiceLoader}. + * For testing only: Loads service providers for the {@code klass} service using {@link + * ServiceLoader}. Does not support spi-fly and related tricks. */ @VisibleForTesting public static Iterable getCandidatesViaServiceLoader(Class klass, ClassLoader cl) { Iterable i = ServiceLoader.load(klass, cl); - // Attempt to load using the context class loader and ServiceLoader. - // This allows frameworks like http://aries.apache.org/modules/spi-fly.html to plug in. if (!i.iterator().hasNext()) { - i = ServiceLoader.load(klass); + return null; } return i; } diff --git a/api/src/test/java/io/grpc/ServiceProvidersTest.java b/api/src/test/java/io/grpc/ServiceProvidersTest.java index 7d4388a5bb9..f971ed42646 100644 --- a/api/src/test/java/io/grpc/ServiceProvidersTest.java +++ b/api/src/test/java/io/grpc/ServiceProvidersTest.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -23,12 +24,15 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import io.grpc.InternalServiceProviders.PriorityAccessor; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.ServiceConfigurationError; +import java.util.ServiceLoader; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,7 +40,6 @@ /** Unit tests for {@link ServiceProviders}. */ @RunWith(JUnit4.class) public class ServiceProvidersTest { - private static final List> NO_HARDCODED = Collections.emptyList(); private static final PriorityAccessor ACCESSOR = new PriorityAccessor() { @Override @@ -51,6 +54,19 @@ public int getPriority(ServiceProvidersTestAbstractProvider provider) { }; private final String serviceFile = "META-INF/services/io.grpc.ServiceProvidersTestAbstractProvider"; + private boolean failingHardCodedAccessed; + private final Supplier>> failingHardCoded = new Supplier>>() { + @Override + public Iterable> get() { + failingHardCodedAccessed = true; + throw new AssertionError(); + } + }; + + @After + public void tearDown() { + assertThat(failingHardCodedAccessed).isFalse(); + } @Test public void contextClassLoaderProvider() { @@ -69,8 +85,8 @@ public void contextClassLoaderProvider() { Thread.currentThread().setContextClassLoader(rcll); assertEquals( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); } finally { Thread.currentThread().setContextClassLoader(ccl); } @@ -85,8 +101,7 @@ public void noProvider() { serviceFile, "io/grpc/ServiceProvidersTestAbstractProvider-doesNotExist.txt"); Thread.currentThread().setContextClassLoader(cl); - assertNull(ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR)); + assertNull(load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR)); } finally { Thread.currentThread().setContextClassLoader(ccl); } @@ -98,11 +113,11 @@ public void multipleProvider() throws Exception { "io/grpc/ServiceProvidersTestAbstractProvider-multipleProvider.txt"); assertSame( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); - List providers = ServiceProviders.loadAll( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + List providers = loadAll( + ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); assertEquals(3, providers.size()); assertEquals(Available7Provider.class, providers.get(0).getClass()); assertEquals(Available5Provider.class, providers.get(1).getClass()); @@ -116,8 +131,8 @@ public void unavailableProvider() { "io/grpc/ServiceProvidersTestAbstractProvider-unavailableProvider.txt"); assertEquals( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); } @Test @@ -125,8 +140,7 @@ public void unknownClassProvider() { ClassLoader cl = new ReplacingClassLoader(getClass().getClassLoader(), serviceFile, "io/grpc/ServiceProvidersTestAbstractProvider-unknownClassProvider.txt"); try { - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Exception expected"); } catch (ServiceConfigurationError e) { // noop @@ -140,8 +154,7 @@ public void exceptionSurfacedToCaller_failAtInit() { try { // Even though there is a working provider, if any providers fail then we should fail // completely to avoid returning something unexpected. - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (ServiceConfigurationError expected) { // noop @@ -154,8 +167,7 @@ public void exceptionSurfacedToCaller_failAtPriority() { "io/grpc/ServiceProvidersTestAbstractProvider-failAtPriorityProvider.txt"); try { // The exception should be surfaced to the caller - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (FailAtPriorityProvider.PriorityException expected) { // noop @@ -168,8 +180,7 @@ public void exceptionSurfacedToCaller_failAtAvailable() { "io/grpc/ServiceProvidersTestAbstractProvider-failAtAvailableProvider.txt"); try { // The exception should be surfaced to the caller - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (FailAtAvailableProvider.AvailableException expected) { // noop @@ -244,6 +255,30 @@ class RandomClass {} assertFalse(candidates.iterator().hasNext()); } + private static T load( + Class klass, + Supplier>> hardCoded, + ClassLoader cl, + PriorityAccessor priorityAccessor) { + List candidates = loadAll(klass, hardCoded, cl, priorityAccessor); + if (candidates.isEmpty()) { + return null; + } + return candidates.get(0); + } + + private static List loadAll( + Class klass, + Supplier>> hardCoded, + ClassLoader classLoader, + PriorityAccessor priorityAccessor) { + return ServiceProviders.loadAll( + klass, + ServiceLoader.load(klass, classLoader).iterator(), + hardCoded, + priorityAccessor); + } + private static class BaseProvider extends ServiceProvidersTestAbstractProvider { private final boolean isAvailable; private final int priority; diff --git a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java index 9dfefaf1a65..9dd77a400cd 100644 --- a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java +++ b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java @@ -29,6 +29,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -109,8 +110,10 @@ public static synchronized XdsCredentialsRegistry getDefaultRegistry() { if (instance == null) { List providerList = InternalServiceProviders.loadAll( XdsCredentialsProvider.class, - getHardCodedClasses(), - XdsCredentialsProvider.class.getClassLoader(), + ServiceLoader + .load(XdsCredentialsProvider.class, XdsCredentialsProvider.class.getClassLoader()) + .iterator(), + XdsCredentialsRegistry::getHardCodedClasses, new XdsCredentialsProviderPriorityAccessor()); if (providerList.isEmpty()) { logger.warning("No XdsCredsRegistry found via ServiceLoader, including for GoogleDefault, "