diff --git a/minerva/__init__.py b/minerva/__init__.py
index 1b2a66f..d9b4770 100644
--- a/minerva/__init__.py
+++ b/minerva/__init__.py
@@ -1,7 +1,11 @@
-from .minerva import Athena, Execute, Query
+from .athena import Athena, Execute, Query
+from .redshift import Redshift
+from .parallel import parallel_map
 
 __all__ = [
     "Execute",
     "Query",
-    "Athena"
+    "Athena",
+    "Redshift",
+    "parallel_map"
 ]
diff --git a/minerva/__pycache__/__init__.cpython-310.pyc b/minerva/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 2edd91a..0000000
Binary files a/minerva/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/minerva/__pycache__/blueshift.cpython-310.pyc b/minerva/__pycache__/blueshift.cpython-310.pyc
deleted file mode 100644
index 3bddbfc..0000000
Binary files a/minerva/__pycache__/blueshift.cpython-310.pyc and /dev/null differ
diff --git a/minerva/__pycache__/minerva.cpython-310.pyc b/minerva/__pycache__/minerva.cpython-310.pyc
deleted file mode 100644
index 8beeb50..0000000
Binary files a/minerva/__pycache__/minerva.cpython-310.pyc and /dev/null differ
diff --git a/minerva/minerva.py b/minerva/athena.py
similarity index 50%
rename from minerva/minerva.py
rename to minerva/athena.py
index da0ba0e..9c423d2 100644
--- a/minerva/minerva.py
+++ b/minerva/athena.py
@@ -6,6 +6,7 @@ import pyarrow as pa
 import pyarrow.dataset
 import pprint
 import datetime
+from .parallel import parallel_map
 
 pp = pprint.PrettyPrinter(indent=4)
 
@@ -19,7 +20,7 @@ class Athena:
         q = Query(self, sql)
         q.run()
         return q
-
+
     def execute(self, sql):
         e = Execute(self, sql)
         e.run()
         return e
@@ -42,13 +43,13 @@ class Execute:
         self.handler = handler
         self.athena = handler.athena
         self.sql = sql
+        self.info_cache = None
+
+    def query(self):
+        return self.sql
 
     def run(self):
-        out = os.path.join(self.handler.output,
-                           str(random.random()))
-        config = {"OutputLocation": out}
-        resp = self.athena.start_query_execution(QueryString=self.sql,
-                                                 ResultConfiguration=config)
+        resp = self.athena.start_query_execution(QueryString=self.query())
         self.query_id = resp['QueryExecutionId']
         return resp
 
@@ -57,75 +58,52 @@ class Execute:
 
     def info(self):
         res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
-        return res['QueryExecution']
+        self.info_cache = res['QueryExecution']
+        return self.info_cache
 
-    def execute(self):
-        tiedot = self.info()
-        status = tiedot['Status']['State']
-
-        while status in ['QUEUED', 'RUNNING']:
+    def finish(self):
+        while (stat := self.status()['State']) in ['QUEUED', 'RUNNING']:
             time.sleep(5)
-            tiedot = self.info()
-            status = tiedot['Status']['State']
 
-        return status # finalized state
+        return stat # finalized state
 
 
-class Query:
+class Query(Execute):
     DATA_STYLE = 'parquet'
 
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.athena = handler.athena
-        self.sql = sql
-
-    def run(self):
+    def query(self):
         out = os.path.join(self.handler.output,
                            str(random.random()))
-        config = {"OutputLocation": out}
         query = f"unload ({self.sql}) to {repr(out)} " + \
                 f"with (format = '{self.DATA_STYLE}')"
+        return query
 
-        resp = self.athena.start_query_execution(QueryString=query,
-                                                 ResultConfiguration=config)
-        self.query_id = resp['QueryExecutionId']
-        return resp
-
-    def status(self):
-        return self.info()['Status']
-
-    def info(self):
-        res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
-        return res['QueryExecution']
-
-    def results(self):
-        tiedot = self.info()
-        status = tiedot['Status']['State']
-
-        while status in ['QUEUED', 'RUNNING']:
-            time.sleep(5)
-            tiedot = self.info()
-            status = tiedot['Status']['State']
+    def manifest_files(self):
+        status = self.finish()
 
         if status == "SUCCEEDED":
-            # Because we're using `UNLOAD`, we get a manifest of the files
-            # that make up our data.
-            files = self.manifest(tiedot).strip().split("\n")
-            files = [f.strip() for f in files if f.strip()] # filter empty
-
-            # TODO parallelize this
-            local = [self.handler.download(f) for f in files]
-            self.ds = pa.dataset.dataset(local)
-
-            ms = tiedot['Statistics']['TotalExecutionTimeInMillis']
+            # Track the runtime
+            ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
             self.runtime = datetime.timedelta(seconds=ms / 1000)
 
-            return self.ds
+            # Because we're using `UNLOAD`, we get a manifest of the files
+            # that make up our data.
+            manif = self.info_cache['Statistics']['DataManifestLocation']
+            tmp = self.handler.download(manif)
+            with open(tmp, "r") as f:
+                txt = f.read()
+
+            files = txt.strip().split("\n")
+            files = [f.strip() for f in files if f.strip()] # filter empty
+
+            return files
         else:
             return status # canceled or error
 
-    def manifest(self, tiedot):
-        manif = tiedot['Statistics']['DataManifestLocation']
-        tmp = self.handler.download(manif)
-        with open(tmp, "r") as f:
-            return f.read()
+    def results(self):
+        # Download the unloaded files in parallel and open them
+        # as a single pyarrow dataset.
+        local = parallel_map(self.handler.download, self.manifest_files())
+        self.ds = pa.dataset.dataset(local)
+        return self.ds
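For reference, a minimal sketch of the refactored Athena flow (the profile, bucket, and table names are taken from test.py and are purely illustrative): `Query.query()` wraps the SQL in an `UNLOAD` statement, the inherited `run()` submits it, and `results()` blocks until the query reaches a final state, then downloads the unloaded Parquet files and opens them as one dataset.

```python
import minerva

athena = minerva.Athena("hay", "s3://haystac-pmo-athena/")

# query() builds and submits the UNLOAD-wrapped statement
q = athena.query("SELECT * FROM trajectories.kitware")

# Polls until a final state, downloads the manifest's files in
# parallel, and opens them as a pyarrow dataset.
ds = q.results()
print(ds.to_table().num_rows)
print(q.runtime)  # datetime.timedelta, set once the query succeeds
```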
diff --git a/minerva/parallel.py b/minerva/parallel.py
new file mode 100644
index 0000000..fd44c7f
--- /dev/null
+++ b/minerva/parallel.py
@@ -0,0 +1,26 @@
+from joblib import Parallel, delayed
+
+# Instead of giving each item in the list its own worker, this splits
+# the list into at most `cores` groups and gives each group its own
+# worker, which processes its group serially.
+#
+# Example:
+#     def say(stuff):
+#         print(stuff)
+#
+#     parallel_map(say, [str(i) for i in range(10)], cores=4)
+def parallel_map(func=None, data=None, cores=8):
+    if not data:
+        return []
+
+    # Ceiling division, so every element lands in a group even when
+    # there are fewer elements than cores.
+    size = -(-len(data) // cores)
+    groups = [data[i:i + size] for i in range(0, len(data), size)]
+
+    def wrapper_func(fs):
+        return [func(f) for f in fs]
+
+    res = Parallel(n_jobs=cores)(delayed(wrapper_func)(group) for group in groups)
+
+    return [val for r in res for val in r]
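A usage sketch of `parallel_map` as defined above: ten items with `cores=4` are split into chunks of at most three, each chunk runs serially inside its own joblib worker, and the per-chunk results come back flattened and in order.

```python
from minerva.parallel import parallel_map

def square(n):
    return n * n

# Groups: [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
print(parallel_map(square, list(range(10)), cores=4))
# -> [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```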
diff --git a/minerva/redshift.py b/minerva/redshift.py
new file mode 100644
index 0000000..f7cf971
--- /dev/null
+++ b/minerva/redshift.py
@@ -0,0 +1,114 @@
+import boto3
+import json
+import os
+import random
+import time
+import pyarrow as pa
+import pyarrow.dataset
+import pprint
+from .parallel import parallel_map
+
+pp = pprint.PrettyPrinter(indent=4)
+
+class Redshift:
+    def __init__(self, profile, output, db=None, cluster=None):
+        self.session = boto3.session.Session(profile_name=profile)
+        self.redshift = self.session.client("redshift-data")
+        self.output = output
+        self.database = db
+        self.cluster = cluster
+
+    def query(self, sql):
+        q = Query(self, sql)
+        q.run()
+        return q
+
+    def execute(self, sql):
+        e = Execute(self, sql)
+        e.run()
+        return e
+
+    def download(self, s3):
+        bucket = s3.split("/")[2]
+        file = os.path.join(*s3.split("/")[3:])
+        tmp = f"/tmp/{random.random()}.bin"
+        self.session.client('s3').download_file(bucket, file, tmp)
+
+        return tmp
+
+class Execute:
+    """
+    Execute does not return results; it runs the SQL and reports the
+    final state. It is meant for statements that produce no result set,
+    such as CREATE DATABASE or CREATE TABLE.
+    """
+    def __init__(self, handler, sql):
+        self.handler = handler
+        self.redshift = handler.redshift
+        self.sql = sql
+        self.info_cache = None
+
+    def query(self):
+        return self.sql
+
+    def run(self):
+        resp = self.redshift.execute_statement(Sql=self.query(),
+                                               Database=self.handler.database,
+                                               ClusterIdentifier=self.handler.cluster)
+        self.query_id = resp['Id']
+        return resp
+
+    def status(self):
+        return self.info()['Status']
+
+    def info(self):
+        res = self.redshift.describe_statement(Id=self.query_id)
+        self.info_cache = res
+        return self.info_cache
+
+    def finish(self):
+        while (stat := self.status()) in ['SUBMITTED', 'PICKED', 'STARTED']:
+            time.sleep(5)
+
+        return stat # finalized state
+
+
+class Query(Execute):
+    DATA_STYLE = 'parquet'
+
+    def query(self):
+        self.out = os.path.join(self.handler.output,
+                                str(random.random()))
+        query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
+                f"iam_role default " + \
+                f"format as {self.DATA_STYLE} " + \
+                f"manifest"
+        return query
+
+    def manifest_files(self):
+        status = self.finish()
+
+        if status == "FINISHED":
+            # Track the runtime
+            self.runtime = self.info_cache['UpdatedAt'] - self.info_cache['CreatedAt']
+
+            # Because we're using `UNLOAD ... MANIFEST`, we get a manifest
+            # of the files that make up our data.
+            manif = self.out + "manifest"
+            tmp = self.handler.download(manif)
+            with open(tmp, "r") as f:
+                js = json.load(f)
+
+            # Filter empty strings
+            files = [e['url'].strip() for e in js['entries'] if e['url'].strip()]
+
+            return files
+        else:
+            return status # canceled or error
+
+    def results(self):
+        # Download the unloaded files in parallel and open them
+        # as a single pyarrow dataset.
+        local = parallel_map(self.handler.download, self.manifest_files())
+        self.ds = pa.dataset.dataset(local)
+        return self.ds
diff --git a/minerva/s3.py b/minerva/s3.py
new file mode 100644
index 0000000..e69de29
diff --git a/pyproject.toml b/pyproject.toml
index 62631c5..6b127d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,8 @@ name = "minerva"
 version = "0.2.1"
 description = "Easier access to AWS Athena and Redshift"
 authors = [
-    "Ari Brown "
+    "Ari Brown ",
+    "Roshan Punnoose "
 ]
 packages = [
     { include = "minerva/**/*.py"}
diff --git a/test.py b/test.py
index 8cf7a3f..eb64c67 100644
--- a/test.py
+++ b/test.py
@@ -1,9 +1,9 @@
-import minerva.minerva as a
+import minerva
 import pprint
 
 pp = pprint.PrettyPrinter(indent=4)
 
-athena = a.Athena("hay", "s3://haystac-pmo-athena/")
+athena = minerva.Athena("hay", "s3://haystac-pmo-athena/")
 #query = athena.query(
 #"""SELECT *
 #FROM trajectories.kitware
diff --git a/test2.py b/test2.py
index c108da7..41d2dea 100644
--- a/test2.py
+++ b/test2.py
@@ -1,11 +1,11 @@
-import minerva.blueshift as b
+import minerva
 import pprint
 
 pp = pprint.PrettyPrinter(indent=4)
 
-red = b.Redshift("hay", "s3://haystac-pmo-athena/",
-                 db="dev",
-                 cluster="redshift-cluster-1")
+red = minerva.Redshift("hay", "s3://haystac-pmo-athena/",
+                       db="dev",
+                       cluster="redshift-cluster-1")
 query = red.query("select count(*) from myspectrum_schema.kitware")
 print(query)
 data = query.results()
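For reference, the file that `Query.manifest_files()` in redshift.py downloads is the standard Redshift `UNLOAD ... MANIFEST` output: a JSON document whose `entries` carry the S3 URLs of the unloaded parts. A sketch of its shape, with illustrative paths (real entries may also carry extra metadata such as content lengths):

```python
# What manifest_files() parses out of the downloaded manifest:
manifest = {
    "entries": [
        {"url": "s3://haystac-pmo-athena/0.123456/0000_part_00.parquet"},
        {"url": "s3://haystac-pmo-athena/0.123456/0001_part_00.parquet"},
    ]
}

files = [e["url"] for e in manifest["entries"]]  # the list it returns
```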