import boto3
import math
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
import dask.dataframe as dd
from minerva import parallel_map, load_template

pp = pprint.PrettyPrinter(indent=4)


def local(path):
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)


class Athena:
    def __init__(self, handler, output=None):
        self.handler = handler
        self.client = handler.session.client("athena")
        self.output = output

    # For when you want to receive the results of something
    def query(self, sql, params=[], format='parquet'):
        q = Query(self, sql, params, format)
        q.run()
        return q

    # For when you want to send a query to run, but there aren't results
    # (like a DML query for creating databases and tables etc)
    def execute(self, sql, params=[], format=None):
        e = Execute(self, sql, params, format)
        e.run()
        return e

    def parallelize(self, *args, **kwargs):
        return Parallelize(self, *args, **kwargs)

    def delete_table(self, table, join=True):
        e = Execute(self, f"drop table {table}")
        e.run()
        if join:
            e.finish()
            s3_uri = os.path.join(self.output, table, "")
            self.handler.s3.rm(s3_uri)
        return e

    def cancel(self, query_id):
        return self.client.stop_query_execution(QueryExecutionId=query_id)
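
# Usage sketch (illustrative, not part of the module): assumes a `handler`
# object exposing a boto3 `session` and an `s3` helper with read/download/rm
# methods, as used throughout this file; the output URI and table names
# below are hypothetical.
#
#   athena = Athena(handler, output="s3://my-bucket/athena-output/")
#   athena.execute("create database if not exists reporting").finish()
#   with athena.query("select count(*) from reporting.events") as q:
#       print(q.scalar())   # blocks until the query finishes
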
class Execute:
    """
    Execute will not return results, but will execute the SQL and return
    the final state. Execute is meant to be used for DML statements such as
    CREATE DATABASE/TABLE
    """

    def __init__(self, athena, sql, params=[], format='parquet'):
        self.athena = athena
        self.handler = athena.handler
        self.client = athena.client
        self.sql = sql
        self.params = [str(p) for p in params]
        self.info_cache = None
        self.temps = []
        self.ds = None
        self.files = None
        self.format = format
        self.finished = None

    # The string of the query
    def query(self):
        return self.sql

    # Send the SQL to Athena for running
    def run(self):
        config = {"OutputLocation": os.path.join(self.athena.output, "results")}
        if self.params:
            resp = self.client.start_query_execution(
                QueryString=self.query(),
                ResultConfiguration=config,
                ExecutionParameters=self.params)
        else:
            resp = self.client.start_query_execution(
                QueryString=self.query(),
                ResultConfiguration=config)
        self.query_id = resp['QueryExecutionId']
        return resp

    # The status of the SQL (running, queued, succeeded, etc.)
    def status(self):
        return self.info()['Status']['State']

    # Get the basic information on the SQL
    def info(self):
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache

    # Block until the SQL has finished running
    def finish(self):
        if self.finished:
            return self.finished
        stat = self.status()
        while stat in ['QUEUED', 'RUNNING']:
            time.sleep(5)
            stat = self.status()
        ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
        self.runtime = datetime.timedelta(seconds=ms / 1000)
        scanned = self.info_cache['Statistics']['DataScannedInBytes']
        self.cost = 5.0 * scanned / (1024 ** 4)  # $5/TB scanned
        self.finished = stat
        return stat  # finalized state


class Query(Execute):
    # Automatically includes unloading the results to Parquet format
    def query(self):
        out = os.path.join(self.athena.output, "results", str(random.random()))
        return (f"unload ({self.sql}) to {repr(out)} "
                f"with (format = '{self.format}', compression = 'zstd', "
                f"compression_level = 4)")

    # Gets the files that are listed in the manifest (from the UNLOAD part of
    # the statement).
    # Blocks until the query has finished (because it calls `self.finish()`)
    def manifest_files(self):
        if self.files:
            return self.files
        status = self.finish()
        if status == "SUCCEEDED":
            # Because we're using `UNLOAD`, we get a manifest of the files
            # that make up our data.
            manif = self.info_cache['Statistics']['DataManifestLocation']
            files = self.handler.s3.read(manif).split("\n")
            files = [f.strip() for f in files if f.strip()]  # filter empty
            self.files = files
            return files
        else:
            print("Error")
            print(self.info_cache['Status']['AthenaError']['ErrorMessage'])
            raise Exception(self.info_cache['Status']['AthenaError']['ErrorMessage'])
            # return status  # canceled or error

    # Blocks until the query has finished running and then returns a pyarrow
    # dataset of the results.
    # Calls `self.manifest_files()`, which blocks via `self.finish()`
    def results(self):
        if self.ds:
            return self.ds
        self.temps = [self.handler.s3.download(f) for f in self.manifest_files()]
        # local = parallel_map(self.handler.s3.download, self.manifest_files())
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds

    # Read each manifest file and scatter it across the dask cluster behind
    # `client`. `size` is the partition chunk size (rows per partition).
    def distribute_results(self, client, size=100000):
        import pandas as pd
        if self.ds:
            return self.ds
        futures = []
        print(f"{len(self.manifest_files())} files in manifest")
        for fn in self.manifest_files():
            print(f"reading {fn}...")
            df = dd.from_pandas(pd.read_parquet(fn), chunksize=size)
            print(df._meta)
            print("\tloaded")
            future = client.scatter(df)
            print("\tscattered")
            futures.append(future)
        return dd.from_delayed(futures, meta=df._meta)

    # Return scalar results
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        return self.results().head(1)[0][0].as_py()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def close(self):
        if self.temps:
            for file in self.temps:
                os.remove(file)
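
# Usage sketch (illustrative; the query is hypothetical, the pyarrow Dataset
# API calls are standard):
#
#   q = athena.query("select id, ts from reporting.events limit 100")
#   table = q.results().to_table()   # blocks, downloads the unloaded files
#   df = table.to_pandas()
#   q.close()                        # remove the downloaded temp files
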
class Parallelize:
    UNION_TABLES = local("athena/union_tables.sql")

    def __init__(self, athena, dest=None, data=[], n=1):
        self.athena = athena
        self.dest = dest
        self.data = list(data)
        self.n = n
        self.tables = []
        self.queries = []
        self.runtime = None
        self.cost = None

    def __lshift__(self, res):
        self.queries.append(res)

    def __iter__(self):
        n = min(self.n, len(self.data))
        size = math.ceil(len(self.data) / n)
        self.groups = [self.data[i:i + size] for i in range(0, len(self.data), size)]
        self.current = 0
        return self

    def __next__(self):
        if self.current >= len(self.groups):
            if self.dest:
                self.union_tables(self.dest).finish()
            self.clear_temp_tables()
            raise StopIteration
        # temp table name, in case it's needed
        tmp = "temp_" + str(round(random.random() * 10_000_000))
        self.tables.append(tmp)
        obj = self.groups[self.current]
        self.current += 1
        return tmp, obj

    # runs the SQL and removes the S3 data
    def clear_temp_tables(self):
        qs = [self.athena.delete_table(table, join=False) for table in self.tables]
        for q in qs:
            q.finish()

    def finish(self):
        for q in self.queries:
            q.finish()
        self.cost = sum(q.cost for q in self.queries)
        self.runtime = max(q.runtime for q in self.queries)

    def results(self):
        self.finish()
        # Flatten the per-query file lists into a single list of local paths
        files = [f for q in self.queries for f in q.results().files]
        return pa.dataset.dataset(files)

    def union_tables(self, dest):
        self.finish()
        lines = [f"select * from {table}" for table in self.tables]
        tables = ' union all '.join(lines)
        out = os.path.join(self.athena.output, dest)
        sql = load_template(self.UNION_TABLES, dest=dest, output=out, tables=tables)
        return self.athena.execute(sql)
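
# Usage sketch for Parallelize (illustrative; the per-group SQL and the
# `ids` work items are hypothetical): iterating yields a temp table name for
# each slice of `data`; push each submitted query back in with `<<`, and the
# final iteration unions the temp tables into `dest` and drops them.
#
#   p = athena.parallelize(dest="combined", data=ids, n=4)
#   for table, group in p:
#       in_list = ",".join(str(i) for i in group)
#       p << athena.execute(f"create table {table} as "
#                           f"select * from events where id in ({in_list})")
#   print(p.cost, p.runtime)   # populated once all queries have finished
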