diff --git a/examples/cancel_query.py b/examples/cancel_query.py new file mode 100644 index 0000000..bcfe65c --- /dev/null +++ b/examples/cancel_query.py @@ -0,0 +1,17 @@ +import minerva + +m = minerva.Minerva("hay-te") + +athena = m.athena("s3://haystac-te-athena/") +file = "/tmp/queries.txt" + +with open(file, 'r') as f: + txt = f.read() + +for line in txt.split("\n"): + if not line: + continue + + print(line) + athena.cancel(line) + diff --git a/minerva/athena.py b/minerva/athena.py index b919848..099bd27 100644 --- a/minerva/athena.py +++ b/minerva/athena.py @@ -1,4 +1,5 @@ import boto3 +import math import os import random import time @@ -8,9 +9,20 @@ import pprint import datetime import dask.dataframe as dd from minerva import parallel_map +from mako.template import Template pp = pprint.PrettyPrinter(indent=4) +# Get full path of fname +def local(fname): + return os.path.join(os.path.abspath(os.path.dirname(__file__)), fname) + +def load_sql(path, **params): + with open(path, 'r') as f: + query = f.read() + + return Template(query).render(**params) + class Athena: def __init__(self, handler, output=None): self.handler = handler @@ -32,6 +44,25 @@ class Athena: e.run() return e + def parallelize(self, *args, **kwargs): + p = Parallelize(self, *args, **kwargs) + return p + + def delete_table(self, table, join=True): + e = Execute(self, f"drop table {table}") + e.run() + + if join: + e.finish() + + s3_uri = os.path.join(self.output, table, "") + self.handler.s3.rm(s3_uri) + + return e + + def cancel(self, query_id): + return self.client.stop_query_execution(QueryExecutionId=query_id) + class Execute: """ @@ -39,16 +70,17 @@ class Execute: Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE """ def __init__(self, athena, sql, params=[], format='parquet'): - self.athena = athena - self.handler = athena.handler - self.client = athena.client - self.sql = sql - self.params = [str(p) for p in params] + self.athena = athena + self.handler = athena.handler + self.client = athena.client + self.sql = sql + self.params = [str(p) for p in params] self.info_cache = None - self.temps = [] - self.ds = None - self.files = None - self.format = format + self.temps = [] + self.ds = None + self.files = None + self.format = format + self.finished = None # The string of the query @@ -86,6 +118,9 @@ class Execute: # Block until the SQL has finished running def finish(self): + if self.finished: + return self.finished + stat = self.status() while stat in ['QUEUED', 'RUNNING']: time.sleep(5) @@ -97,6 +132,8 @@ class Execute: scanned = self.info_cache['Statistics']['DataScannedInBytes'] self.cost = 5.0 * scanned / (1024 ** 4) # $5/TB scanned + self.finished = stat + return stat # finalized state @@ -190,3 +227,81 @@ class Query(Execute): for file in self.temps: os.remove(file) + +class Parallelize: + UNION_TABLES = local("athena/union_tables.sql") + + def __init__(self, athena, dest=None, data=[], n=1): + self.athena = athena + self.dest = dest + self.data = list(data) + self.n = n + self.tables = [] + self.queries = [] + self.runtime = None + self.cost = None + + + def __lshift__(self, res): + self.queries.append(res) + + + def __iter__(self): + n = min(self.n, len(self.data)) + size = math.ceil(len(self.data) / n) + + self.groups = [self.data[i:i + size] for i in range(0, len(self.data), size)] + self.current = 0 + + return self + + + def __next__(self): + if self.current >= len(self.groups): + if self.dest: + self.union_tables(self.dest).finish() + self.clear_temp_tables() + raise StopIteration + + # temp table name, in case it's needed + tmp = "temp_" + str(round(random.random() * 10_000_000)) + self.tables.append(tmp) + + obj = self.groups[self.current] + self.current += 1 + + return tmp, obj + + + # runs the SQL and removes the S3 data + def clear_temp_tables(self): + qs = [self.athena.delete_table(table, join=False) for table in self.tables] + for q in qs: + q.finish() + + + def finish(self): + for q in self.queries: + q.finish() + + self.cost = sum([q.cost for q in self.queries]) + self.runtime = max([q.runtime for q in self.queries]) + + def results(self): + self.finish() + + return pa.dataset.dataset([q.results().files for q in self.queries]) + + + def union_tables(self, dest): + self.finish() + + lines = [f"select * from {table}" for table in self.tables] + tables = ' union all '.join(lines) + + out = os.path.join(self.athena.output, dest) + sql = load_sql(self.UNION_TABLES, dest = dest, + output = out, + tables = tables) + return self.athena.execute(sql) + diff --git a/minerva/minerva.py b/minerva/minerva.py index 3a4095a..c0c7028 100644 --- a/minerva/minerva.py +++ b/minerva/minerva.py @@ -18,3 +18,7 @@ class Minerva: def pier(self, *args, **kwargs): return m.Pier(self, *args, **kwargs) + + def s3(self, *args, **kwargs): + return m.S3(self, *args, **kwargs) + diff --git a/minerva/s3.py b/minerva/s3.py index e4010fb..69556f6 100644 --- a/minerva/s3.py +++ b/minerva/s3.py @@ -36,3 +36,8 @@ class S3: #self.s3.list_objects_v2(Bucket=bucket, Prefix=key) return self.resource.Bucket(bucket).objects.filter(Prefix=key, **kwargs) + def rm(self, uri, **kwargs): + fs = self.ls(uri) + fs.delete() + return fs +