import hashlib
import logging
import os
import pickle
from functools import partial
import elasticsearch
import elasticsearch_dsl
import requests
from elasticsearch import AIOHttpConnection, RequestsHttpConnection as _Conn
from tornado.ioloop import IOLoop
from biothings.utils.common import run_once

try:
    import boto3
    from requests_aws4auth import AWS4Auth

    aws_avail = True
except ImportError:
    # only needed for connecting to AWS OpenSearch
    aws_avail = False
logger = logging.getLogger(__name__)
_should_log = run_once()


def _log_pkg():
    es_ver = elasticsearch.__version__
    es_dsl_ver = elasticsearch_dsl.__version__
    logger.info("Elasticsearch Package Version: %s", ".".join(map(str, es_ver)))
    logger.info("Elasticsearch DSL Package Version: %s", ".".join(map(str, es_dsl_ver)))


def _log_db(client, uri):
    logger.info(client)


def _log_es(client, hosts):
    _log_db(client, hosts)
    # only perform the health check with the async client
    # so that it doesn't slow down program start time
    if isinstance(client, elasticsearch.AsyncElasticsearch):

        async def log_cluster(async_client):
            # no timeout is specified on the info() call below because
            # a number of ES tasks may be scheduled ahead of it, and the
            # cluster could take a while to respond
            cluster = await async_client.info()
            if _should_log():
                _log_pkg()
            cluster_name = cluster["cluster_name"]
            version = cluster["version"]["number"]
            logger.info("%s: %s %s", hosts, cluster_name, version)

        IOLoop.current().add_callback(log_cluster, client)

# ------------------------
# Low Level Functions
# ------------------------
class _AsyncConn(AIOHttpConnection):
    def __init__(self, *args, **kwargs):
        self.aws_auth = None

        _auth = kwargs.get("http_auth")
        # only intercept AWS4Auth; guard on aws_avail so the isinstance
        # check cannot raise a NameError when the optional deps are missing
        if _auth and aws_avail and isinstance(_auth, AWS4Auth):
            self.aws_auth = _auth
            kwargs["http_auth"] = None

        super().__init__(*args, **kwargs)

    async def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None):
        if self.aws_auth:
            # sign the request (AWS Signature Version 4)
            req = requests.PreparedRequest()
            req.prepare(method, self.host + url, headers, None, body, params)
            self.aws_auth(req)
            headers = headers or {}
            headers.update(req.headers)
        return await super().perform_request(method, url, params, body, timeout, ignore, headers)
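
# For reference, AWS4Auth follows the "requests" auth-callable protocol:
# calling it on a PreparedRequest mutates the request in place, adding the
# SigV4 "Authorization" and "x-amz-date" headers (plus "x-amz-security-token"
# for temporary credentials). A minimal sketch, with placeholder credentials
# and a hypothetical endpoint:
#
#   auth = AWS4Auth("AKIA...", "SECRET...", "us-west-2", "es")
#   req = requests.PreparedRequest()
#   req.prepare("GET", "https://search-mydomain.us-west-2.es.amazonaws.com/_cluster/health")
#   auth(req)  # req.headers now carries the signature headers
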
# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
AWS_META_URL = "http://169.254.169.254/latest/dynamic/instance-identity/document"


def get_es_client(hosts=None, async_=False, **settings):
    """Enhanced ES client initialization.

    Additionally supports these parameters:
      async_: use AsyncElasticsearch instead of Elasticsearch.
      aws: set up request signing and provide reasonable ES settings
           to access AWS OpenSearch; assumes HTTPS by default.
      sniff: provide reasonable default settings to enable client-side
           load balancing across an ES cluster. This parameter itself
           is not an ES parameter.
    """
    if settings.pop("aws", False):
        if not aws_avail:
            raise ImportError('"boto3" and "requests_aws4auth" are required for AWS OpenSearch')
        # find the region
        session = boto3.Session()
        region = session.region_name
        if not region:  # not in ~/.aws/config
            region = os.environ.get("AWS_REGION")
        if not region:  # not in an environment variable
            try:  # assume same-region service access
                res = requests.get(AWS_META_URL, timeout=5)  # avoid hanging outside EC2
                region = res.json()["region"]
            except Exception:  # not running in a VPC
                region = "us-west-2"  # default
        # find credentials
        credentials = session.get_credentials()
        awsauth = AWS4Auth(refreshable_credentials=credentials, region=region, service="es")
        _cc = _AsyncConn if async_ else _Conn
        settings.update(http_auth=awsauth, connection_class=_cc)
        settings.setdefault("use_ssl", True)
        settings.setdefault("verify_certs", True)
    # not evaluated when the 'aws' flag is set because
    # AWS OpenSearch is internally load-balanced
    # and does not support client-side sniffing
    elif settings.pop("sniff", False):
        settings.setdefault("sniff_on_start", True)
        settings.setdefault("sniff_on_connection_fail", True)
        settings.setdefault("sniffer_timeout", 60)
    if async_:
        from elasticsearch import AsyncElasticsearch

        client = AsyncElasticsearch
    else:
        from elasticsearch import Elasticsearch

        client = Elasticsearch
    return client(hosts, **settings)
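
# Usage sketches (hosts below are placeholders, not defaults):
#
#   client = get_es_client("localhost:9200", sniff=True)   # client-side load balancing
#   client = get_es_client(
#       "search-mydomain.us-west-2.es.amazonaws.com",       # hypothetical domain
#       aws=True,                                           # signed HTTPS requests
#   )
#   async_client = get_es_client("localhost:9200", async_=True)
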

def get_sql_client(uri, **settings):
    from sqlalchemy import create_engine

    return create_engine(uri, **settings).connect()

def get_mongo_client(uri, **settings):
    from pymongo import MongoClient

    # the URI must include a default database, e.g. mongodb://host/dbname
    return MongoClient(uri, **settings).get_default_database()
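
# Usage sketches (URIs are placeholders):
#
#   conn = get_sql_client("sqlite:///:memory:")  # a SQLAlchemy Connection
#   db = get_mongo_client("mongodb://localhost:27017/biothings")  # default db "biothings"
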

def _not_implemented_client():
    raise NotImplementedError()
# ------------------------
# High Level Utilities
# ------------------------

class _ClientPool:
    def __init__(self, client_factory, async_factory, callback=None):
        self._client_factory = client_factory
        self._clients = {}
        self._async_client_factory = async_factory
        self._async_clients = {}
        self.callback = callback or _log_db

    @staticmethod
    def hash(config):
        _config = pickle.dumps(config)
        _hash = hashlib.md5(_config)
        return _hash.hexdigest()
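
    # Note: the key is derived from pickling (uri, settings); since dicts
    # preserve insertion order, get_client(uri, a=1, b=2) and
    # get_client(uri, b=2, a=1) pickle differently and would create two
    # separate clients. For example:
    #
    #   _ClientPool.hash(("localhost:9200", {"timeout": 60}))
    #   # -> a stable hex digest for this exact (uri, settings) pair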

    def _get_client(self, repo, factory, uri, settings):
        # cache_key avoids shadowing the built-in hash()
        cache_key = self.hash((uri, settings))
        if cache_key in repo:
            return repo[cache_key]
        repo[cache_key] = factory(uri, **settings)
        self.callback(repo[cache_key], uri)
        return repo[cache_key]

    def get_client(self, uri, **settings):
        return self._get_client(self._clients, self._client_factory, uri, settings)

    def get_async_client(self, uri, **settings):
        return self._get_client(self._async_clients, self._async_client_factory, uri, settings)

es = _ClientPool(get_es_client, partial(get_es_client, async_=True), _log_es)
sql = _ClientPool(get_sql_client, _not_implemented_client)
mongo = _ClientPool(get_mongo_client, _not_implemented_client)
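
# The module-level pools memoize clients by their (uri, settings) signature,
# so repeated lookups return the same underlying client. A usage sketch
# (host is a placeholder):
#
#   client = es.get_client("localhost:9200")
#   client is es.get_client("localhost:9200")  # True, served from the pool
#   async_client = es.get_async_client("localhost:9200")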