import os
import pprint
import random
import tempfile
import time

import boto3
import pyarrow as pa
import pyarrow.dataset

pp = pprint.PrettyPrinter(indent=4)


def _split_s3_uri(uri):
    """Split an ``s3://bucket/key`` style URI into ``(bucket, key)``.

    The key is everything after the third ``/`` and is returned verbatim
    (always ``/``-separated, regardless of the local OS).

    Raises:
        ValueError: if the URI has no bucket or no key component.
    """
    parts = uri.split("/", 3)
    if len(parts) < 4 or not parts[2] or not parts[3]:
        raise ValueError(f"not an s3://bucket/key URI: {uri!r}")
    return parts[2], parts[3]


class Athena:
    """Thin wrapper around a boto3 session for running Athena queries."""

    def __init__(self, profile, output):
        """Create the boto3 session and Athena client.

        Args:
            profile: AWS profile name passed to ``boto3.session.Session``.
            output: base location under which each query creates its own
                output location (presumably an ``s3://`` prefix -- it is
                handed to Athena as ``OutputLocation``).
        """
        self.session = boto3.session.Session(profile_name=profile)
        self.athena = self.session.client("athena")
        self.output = output

    def query(self, sql):
        """Start ``sql`` on Athena and return the running :class:`Query`."""
        q = Query(self, sql)
        q.run()
        return q

    def download(self, s3):
        """Download the object at S3 URI ``s3`` to a new local temp file.

        Returns:
            The path of the local file. The caller owns the file; it is not
            deleted automatically.

        Raises:
            ValueError: if ``s3`` has no bucket or key component.
        """
        bucket, key = _split_s3_uri(s3)
        # mkstemp gives a unique, private file in the platform temp dir
        # instead of a guessable name under a hard-coded /tmp.
        fd, tmp = tempfile.mkstemp(suffix=".bin")
        os.close(fd)
        self.session.client("s3").download_file(bucket, key, tmp)
        return tmp


class Query:
    """A single Athena ``UNLOAD`` query and access to its results."""

    # Output format requested in the UNLOAD statement.
    DATA_STYLE = 'parquet'

    def __init__(self, handler, sql):
        """Remember the owning :class:`Athena` handler and the SQL text."""
        self.handler = handler
        self.athena = handler.athena
        self.sql = sql

    def run(self):
        """Submit the query wrapped in ``UNLOAD`` and record its id.

        A random sub-location of ``handler.output`` is used both as the
        UNLOAD target and as the query's ``OutputLocation``.

        Returns:
            The raw ``start_query_execution`` response.
        """
        base = self.handler.output
        # S3 locations always use "/", so do not go through os.path.join
        # (which would insert "\\" on Windows).
        sep = "" if base.endswith("/") else "/"
        out = f"{base}{sep}{random.random()}"
        config = {"OutputLocation": out}
        # Quote as an SQL string literal; embedded single quotes are doubled.
        target = "'" + out.replace("'", "''") + "'"
        query = (
            f"unload ({self.sql}) to {target} "
            f"with (format = '{self.DATA_STYLE}')"
        )
        resp = self.athena.start_query_execution(
            QueryString=query, ResultConfiguration=config
        )
        self.query_id = resp['QueryExecutionId']
        return resp

    def status(self):
        """Return the ``Status`` dict of the query execution."""
        return self.info()['Status']

    def info(self):
        """Fetch and return the ``QueryExecution`` dict for this query."""
        res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
        return res['QueryExecution']

    def results(self):
        """Wait for the query to finish and load its output.

        Polls every 5 seconds while the state is ``QUEUED`` or ``RUNNING``.

        Returns:
            A pyarrow dataset over the downloaded result files (also stored
            on ``self.ds``) when the query succeeded; otherwise the final
            state string (canceled or error).
        """
        execution = self.info()
        state = execution['Status']['State']
        while state in ('QUEUED', 'RUNNING'):
            time.sleep(5)
            # Re-fetch the whole execution record, not just the state: the
            # manifest location only appears in it once the query is done.
            execution = self.info()
            state = execution['Status']['State']

        if state != "SUCCEEDED":
            return state  # canceled or error

        # Because we're using `UNLOAD`, we get a manifest of the files
        # that make up our data. Blank lines (e.g. an empty manifest) are
        # skipped so they are not treated as S3 URIs.
        files = [
            line.strip()
            for line in self.manifest(execution).splitlines()
            if line.strip()
        ]
        local = [self.handler.download(f) for f in files]
        self.ds = pa.dataset.dataset(local)
        return self.ds

    def manifest(self, tiedot):
        """Download and return the text of the query's data manifest.

        Args:
            tiedot: a ``QueryExecution`` dict as returned by :meth:`info`;
                its ``Statistics.DataManifestLocation`` is read.

        Raises:
            KeyError: if the execution record has no manifest location.
        """
        manif = tiedot['Statistics']['DataManifestLocation']
        tmp = self.handler.download(manif)
        try:
            with open(tmp, "r", encoding="utf-8") as f:
                return f.read()
        finally:
            # The manifest text is returned to the caller; the local copy
            # is no longer needed.
            os.remove(tmp)