/* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at http://www.apache.org/licenses/LICENSE-2.0 or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package org.apache.tinkerpop.gremlin.driver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.utils.GitProperties; import software.amazon.utils.SoftwareVersion; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; public class GremlinCluster implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(GremlinCluster.class); private final Collection defaultEndpoints; private final ClusterFactory clusterFactory; private final Collection clientClusterCollections = new CopyOnWriteArrayList<>(); private final AtomicReference> closing = new AtomicReference<>(null); private final EndpointStrategies endpointStrategies; private final AcquireConnectionConfig acquireConnectionConfig; private final MetricsConfig metricsConfig; public GremlinCluster(Collection defaultEndpoints, ClusterFactory clusterFactory, EndpointStrategies endpointStrategies, AcquireConnectionConfig acquireConnectionConfig, MetricsConfig metricsConfig) { logger.info("Version: {} {}", SoftwareVersion.FromResource, GitProperties.FromResource); logger.info("Created GremlinCluster [defaultEndpoints: {}, enableMetrics: {}]", defaultEndpoints.stream() .map(Endpoint::getAddress) .collect(Collectors.toList()), metricsConfig.enableMetrics()); this.defaultEndpoints = defaultEndpoints; this.clusterFactory = clusterFactory; this.endpointStrategies = endpointStrategies; this.acquireConnectionConfig = acquireConnectionConfig; this.metricsConfig = metricsConfig; } public GremlinClient connect(List addresses, Client.Settings settings) { return connectToEndpoints( addresses.stream() .map(a -> new DatabaseEndpoint().withAddress(a)) .collect(Collectors.toList()), settings); } public GremlinClient connectToEndpoints(Collection endpoints, Client.Settings settings) { logger.info("Connecting with: {}", endpoints.stream() .map(Endpoint::getAddress) .collect(Collectors.toList())); if (endpoints.isEmpty()) { throw new IllegalStateException("You must supply at least one endpoint"); } Cluster parentCluster = clusterFactory.createCluster(null); ClientClusterCollection clientClusterCollection = new ClientClusterCollection(clusterFactory, parentCluster); Map clustersForEndpoints = clientClusterCollection.createClustersForEndpoints(new EndpointCollection(endpoints)); List newEndpointClients = EndpointClient.create(clustersForEndpoints); EndpointClientCollection endpointClientCollection = new EndpointClientCollection( EndpointClientCollection.builder() .withEndpointClients(newEndpointClients) .setCollectMetrics(metricsConfig.enableMetrics())); clientClusterCollections.add(clientClusterCollection); return new GremlinClient( clientClusterCollection.getParentCluster(), settings, endpointClientCollection, clientClusterCollection, endpointStrategies, acquireConnectionConfig, metricsConfig ); } public GremlinClient connect(List addresses) { return connect(addresses, Client.Settings.build().create()); } public GremlinClient connectToEndpoints(List endpoints) { return connectToEndpoints(endpoints, Client.Settings.build().create()); } public GremlinClient connect() { return connectToEndpoints(defaultEndpoints, Client.Settings.build().create()); } public GremlinClient connect(Client.Settings settings) { return connectToEndpoints(defaultEndpoints, settings); } public CompletableFuture closeAsync() { if (closing.get() != null) return closing.get(); List> futures = new ArrayList<>(); for (ClientClusterCollection clientClusterCollection : clientClusterCollections) { futures.add(clientClusterCollection.closeAsync()); } closing.set(CompletableFuture.allOf(futures.toArray(new CompletableFuture[]{}))); return closing.get(); } @Override public void close() throws Exception { closeAsync().join(); } }