Source code for biothings.web.connections

import hashlib
import logging
import os
import pickle
from functools import partial

import elasticsearch
import elasticsearch_dsl
import requests
from elasticsearch import AIOHttpConnection, RequestsHttpConnection as _Conn
from tornado.ioloop import IOLoop

from biothings.utils.common import run_once

try:
    import boto3
    from requests_aws4auth import AWS4Auth

    aws_avail = True
except ImportError:
    # only needed for connecting to AWS OpenSearch
    aws_avail = False

logger = logging.getLogger(__name__)

_should_log = run_once()  # truthy only on its first call, so package versions are logged once


def _log_pkg():
    es_ver = elasticsearch.__version__
    es_dsl_ver = elasticsearch_dsl.__version__
    logger.info("Elasticsearch Package Version: %s", ".".join(map(str, es_ver)))
    logger.info("Elasticsearch DSL Package Version: %s", ".".join(map(str, es_dsl_ver)))


def _log_db(client, uri):
    # generic connection callback; "uri" is unused here but kept so the
    # signature matches what _ClientPool passes to its callback
    logger.info(client)


def _log_es(client, hosts):
    _log_db(client, hosts)

    # only perform health check with the async client
    # so that it doesn't slow down program start time
    if isinstance(client, elasticsearch.AsyncElasticsearch):

        async def log_cluster(async_client):
            cluster = await async_client.info()
            # no explicit timeout on the call above: a number of ES tasks
            # may already be queued ahead of it, so it could take the
            # cluster a while to respond

            if _should_log():
                _log_pkg()

            cluster_name = cluster["cluster_name"]
            version = cluster["version"]["number"]

            logger.info("%s: %s %s", hosts, cluster_name, version)

        IOLoop.current().add_callback(log_cluster, client)


# ------------------------
#   Low Level Functions
# ------------------------


class _AsyncConn(AIOHttpConnection):
    def __init__(self, *args, **kwargs):
        # intercept an AWS4Auth object passed as http_auth so that requests
        # can be signed manually in perform_request below; guard on aws_avail
        # so the isinstance check cannot raise NameError when the optional
        # AWS packages are not installed
        self.aws_auth = None
        _auth = kwargs.get("http_auth")
        if aws_avail and _auth and isinstance(_auth, AWS4Auth):
            self.aws_auth = _auth
            kwargs["http_auth"] = None
        super().__init__(*args, **kwargs)

    async def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None):
        req = requests.PreparedRequest()
        req.prepare(method, self.host + url, headers, None, body, params)
        self.aws_auth(req)  # sign the request
        headers.update(req.headers)
        return await super().perform_request(method, url, params, body, timeout, ignore, headers)
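

# Sketch (illustrative): how a custom connection class and signer plug into
# an elasticsearch-py 7.x client; get_es_client(aws=True) below assembles
# essentially these pieces. The endpoint and credentials are placeholders.
def _example_signed_async_client():
    from elasticsearch import AsyncElasticsearch

    auth = AWS4Auth("PLACEHOLDER_KEY_ID", "PLACEHOLDER_SECRET", "us-west-2", "es")
    return AsyncElasticsearch(
        "https://my-domain.us-west-2.es.amazonaws.com",  # placeholder endpoint
        http_auth=auth,  # intercepted by _AsyncConn.__init__ above
        connection_class=_AsyncConn,
        use_ssl=True,
        verify_certs=True,
    )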


# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html
AWS_META_URL = "http://169.254.169.254/latest/dynamic/instance-identity/document"


def get_es_client(hosts=None, async_=False, **settings):
    """Enhanced ES client initialization.

    Additionally supports these parameters:
        async_: use AsyncElasticsearch instead of Elasticsearch.
        aws: set up request signing and provide reasonable ES settings
            to access AWS OpenSearch, by default assuming it is on HTTPS.
        sniff: provide reasonable default settings to enable client-side
            load balancing to an ES cluster. This param itself is not
            an ES param.
    """
    if settings.pop("aws", False):
        if not aws_avail:
            raise ImportError('"boto3" and "requests_aws4auth" are required for AWS OpenSearch')
        # find region
        session = boto3.Session()
        region = session.region_name
        if not region:  # not in ~/.aws/config
            region = os.environ.get("AWS_REGION")
        if not region:  # not in environment variable
            try:  # assume same-region service access
                res = requests.get(AWS_META_URL, timeout=5.0)  # avoid hanging outside EC2
                region = res.json()["region"]
            except Exception:  # not running in VPC
                region = "us-west-2"  # default
        # find credentials
        credentials = session.get_credentials()
        awsauth = AWS4Auth(refreshable_credentials=credentials, region=region, service="es")
        _cc = _AsyncConn if async_ else _Conn
        settings.update(http_auth=awsauth, connection_class=_cc)
        settings.setdefault("use_ssl", True)
        settings.setdefault("verify_certs", True)
    # 'sniff' is not evaluated when the 'aws' flag is set because
    # AWS OpenSearch is internally load-balanced and does not
    # support client-side sniffing.
    elif settings.pop("sniff", False):
        settings.setdefault("sniff_on_start", True)
        settings.setdefault("sniff_on_connection_fail", True)
        settings.setdefault("sniffer_timeout", 60)
    if async_:
        from elasticsearch import AsyncElasticsearch

        client = AsyncElasticsearch
    else:
        from elasticsearch import Elasticsearch

        client = Elasticsearch
    return client(hosts, **settings)
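

# Usage sketch: a minimal illustration of the three extra flags, assuming a
# local cluster on localhost:9200 and a hypothetical OpenSearch endpoint.
def _example_get_es_client():
    plain = get_es_client("localhost:9200")  # synchronous, no extras
    sniffing = get_es_client("localhost:9200", sniff=True)  # client-side LB
    signed = get_es_client(
        "https://my-domain.us-west-2.es.amazonaws.com",  # placeholder endpoint
        async_=True,
        aws=True,  # SigV4-signed requests over HTTPS
    )
    return plain, sniffing, signed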


def get_sql_client(uri, **settings):
    from sqlalchemy import create_engine

    return create_engine(uri, **settings).connect()
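

# Usage sketch: get_sql_client() returns a live SQLAlchemy Connection, so any
# SQLAlchemy URI should work; the in-memory SQLite URI is an assumption chosen
# to keep the example dependency-free.
def _example_get_sql_client():
    from sqlalchemy import text

    conn = get_sql_client("sqlite:///:memory:")
    assert conn.execute(text("SELECT 1")).scalar() == 1
    conn.close()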


def get_mongo_client(uri, **settings):
    from pymongo import MongoClient

    return MongoClient(uri, **settings).get_default_database()
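

# Usage sketch: because get_mongo_client() resolves get_default_database(),
# the URI must name a database (pymongo raises ConfigurationError otherwise).
# The host and database below are placeholders.
def _example_get_mongo_client():
    db = get_mongo_client("mongodb://localhost:27017/biothings")
    return db.list_collection_names()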


def _not_implemented_client():
    raise NotImplementedError()


# ------------------------
#   High Level Utilities
# ------------------------


class _ClientPool:
    def __init__(self, client_factory, async_factory, callback=None):
        self._client_factory = client_factory
        self._clients = {}
        self._async_client_factory = async_factory
        self._async_clients = {}
        self.callback = callback or _log_db

    @staticmethod
    def hash(config):
        _config = pickle.dumps(config)
        _hash = hashlib.md5(_config)
        return _hash.hexdigest()

    def _get_client(self, repo, factory, uri, settings):
        key = self.hash((uri, settings))
        if key in repo:
            return repo[key]
        repo[key] = factory(uri, **settings)
        self.callback(repo[key], uri)
        return repo[key]

    def get_client(self, uri, **settings):
        return self._get_client(self._clients, self._client_factory, uri, settings)

    def get_async_client(self, uri, **settings):
        return self._get_client(self._async_clients, self._async_client_factory, uri, settings)
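

# Sketch: the pool key is an md5 digest of the pickled (uri, settings) pair,
# so configurations that pickle identically share one client. Note that
# keyword order matters: {"a": 1, "b": 2} and {"b": 2, "a": 1} pickle
# differently and would be cached as separate clients.
def _example_pool_key():
    key_a = _ClientPool.hash(("localhost:9200", {"timeout": 20}))
    key_b = _ClientPool.hash(("localhost:9200", {"timeout": 20}))
    assert key_a == key_b  # identical config -> same cache slot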


es = _ClientPool(get_es_client, partial(get_es_client, async_=True), _log_es)
sql = _ClientPool(get_sql_client, _not_implemented_client)
mongo = _ClientPool(get_mongo_client, _not_implemented_client)
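

# Usage sketch: the module-level pools above are intended as shared entry
# points; repeated calls with identical arguments return the same cached
# client rather than opening a new connection. The host is a placeholder.
def _example_module_pools():
    client_a = es.get_client("localhost:9200", timeout=20)
    client_b = es.get_client("localhost:9200", timeout=20)
    assert client_a is client_b  # served from the pool cache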