diff --git a/minerva/__init__.py b/minerva/__init__.py
index 4c0d10b..0b41ca7 100644
--- a/minerva/__init__.py
+++ b/minerva/__init__.py
@@ -1,9 +1,13 @@
 from .parallel import parallel_map
 from .athena import Athena
 from .redshift import Redshift
+from .s3 import S3
+from .minerva import Minerva
 
 __all__ = [
     "Athena",
     "Redshift",
-    "parallel_map"
+    "parallel_map",
+    "Minerva",
+    "S3"
 ]
diff --git a/minerva/athena.py b/minerva/athena.py
index 85b2332..3ed9a0d 100644
--- a/minerva/athena.py
+++ b/minerva/athena.py
@@ -11,45 +11,43 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)
 
 class Athena:
-    def __init__(self, profile, output):
-        self.session = boto3.session.Session(profile_name=profile)
-        self.athena = self.session.client("athena")
+    def __init__(self, handler, output):
+        self.handler = handler
+        self.client = handler.session.client("athena")
         self.output = output
 
-    def query(self, sql):
-        q = Query(self, sql)
+    def query(self, sql, params=()):
+        q = Query(self, sql, params)
         q.run()
         return q
 
-    def execute(self, sql):
-        e = Execute(self, sql)
+    def execute(self, sql, params=()):
+        e = Execute(self, sql, params)
         e.run()
         return e
 
-    def download(self, s3):
-        bucket = s3.split("/")[2]
-        file = os.path.join(*s3.split("/")[3:])
-        tmp = f"/tmp/{random.random()}.bin"
-        self.session.client('s3').download_file(bucket, file, tmp)
-
-        return tmp
 
 
 class Execute:
     """ Execute will not return results, but will execute the SQL and return the final state.
     Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
     """
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.athena = handler.athena
+    def __init__(self, athena, sql, params=()):
+        self.athena = athena
+        self.handler = athena.handler
+        self.client = athena.client
         self.sql = sql
+        self.params = [str(p) for p in params]
         self.info_cache = None
 
     def query(self):
         return self.sql
 
     def run(self):
-        resp = self.athena.start_query_execution(QueryString=self.query())
+        # Athena rejects an empty ExecutionParameters list, so only send it
+        # when parameters were actually supplied.
+        extra = {"ExecutionParameters": self.params} if self.params else {}
+        resp = self.client.start_query_execution(QueryString=self.query(), **extra)
         self.query_id = resp['QueryExecutionId']
         return resp
 
@@ -57,7 +55,7 @@ class Execute:
         return self.info()['Status']['State']
 
     def info(self):
-        res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
+        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
         self.info_cache = res['QueryExecution']
         return self.info_cache
 
@@ -74,7 +72,7 @@ class Query(Execute):
     DATA_STYLE = 'parquet'
 
     def query(self):
-        out = os.path.join(self.handler.output,
+        out = os.path.join(self.athena.output,
                            str(random.random()))
         query = f"unload ({self.sql}) to {repr(out)} " + \
                 f"with (format = '{self.DATA_STYLE}')"
@@ -91,7 +89,7 @@ class Query(Execute):
         # Because we're using `UNLOAD`, we get a manifest of the files
         # that make up our data.
         manif = self.info_cache['Statistics']['DataManifestLocation']
-        tmp = self.handler.download(manif)
+        tmp = self.handler.s3.download(manif)
 
         with open(tmp, "r") as f:
             txt = f.read()
@@ -102,10 +100,8 @@ class Query(Execute):
         else:
             return status # canceled or error
 
-    # TODO parallelize this
     def results(self):
-        #local = [self.handler.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.download, self.manifest_files())
+        local = parallel_map(self.handler.s3.download, self.manifest_files())
         self.ds = pa.dataset.dataset(local)
 
         return self.ds
diff --git a/minerva/redshift.py b/minerva/redshift.py
index 92cacb2..beffce1 100644
--- a/minerva/redshift.py
+++ b/minerva/redshift.py
@@ -12,9 +12,9 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)
 
 class Redshift:
-    def __init__(self, profile, output, db=None, cluster=None):
-        self.session = boto3.session.Session(profile_name=profile)
-        self.redshift = self.session.client("redshift-data")
+    def __init__(self, handler, output, db=None, cluster=None):
+        self.handler = handler
+        self.client = handler.session.client("redshift-data")
         self.output = output
         self.database = db
         self.cluster = cluster
@@ -29,22 +29,16 @@ class Redshift:
         e.run()
         return e
 
-    def download(self, s3):
-        bucket = s3.split("/")[2]
-        file = os.path.join(*s3.split("/")[3:])
-        tmp = f"/tmp/{random.random()}.bin"
-        self.session.client('s3').download_file(bucket, file, tmp)
-
-        return tmp
 
 
 class Execute:
     """ Execute will not return results, but will execute the SQL and return the final state.
     Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
     """
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.redshift = handler.redshift
+    def __init__(self, redshift, sql):
+        self.redshift = redshift
+        self.handler = redshift.handler
+        self.client = redshift.client
         self.sql = sql
         self.info_cache = None
 
@@ -52,9 +46,9 @@ class Execute:
         return self.sql
 
     def run(self):
-        resp = self.redshift.execute_statement(Sql=self.query(),
-                                               Database=self.handler.database,
-                                               ClusterIdentifier=self.handler.cluster)
+        resp = self.client.execute_statement(Sql=self.query(),
+                                             Database=self.redshift.database,
+                                             ClusterIdentifier=self.redshift.cluster)
         self.query_id = resp['Id']
         return resp
 
@@ -62,7 +56,7 @@ class Execute:
         return self.info()['Status']
 
     def info(self):
-        res = self.redshift.describe_statement(Id=self.query_id)
+        res = self.client.describe_statement(Id=self.query_id)
         self.info_cache = res
         return self.info_cache
 
@@ -79,7 +73,7 @@ class Query(Execute):
     DATA_STYLE = 'parquet'
 
     def query(self):
-        self.out = os.path.join(self.handler.output,
+        self.out = os.path.join(self.redshift.output,
                                 str(random.random()))
         query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
                 f"iam_role default " + \
@@ -97,7 +91,7 @@ class Query(Execute):
         # Because we're using `UNLOAD`, we get a manifest of the files
         # that make up our data.
         manif = self.out + "manifest"
-        tmp = self.handler.download(manif)
+        tmp = self.handler.s3.download(manif)
 
         with open(tmp, "r") as f:
             js = json.load(f)
@@ -110,8 +104,6 @@ class Query(Execute):
 
-    # TODO parallelize this
     def results(self):
-        #local = [self.handler.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.download, self.manifest_files())
+        local = parallel_map(self.handler.s3.download, self.manifest_files())
         self.ds = pa.dataset.dataset(local)
 
         return self.ds
diff --git a/minerva/s3.py b/minerva/s3.py
index e69de29..4a694ce 100644
--- a/minerva/s3.py
+++ b/minerva/s3.py
@@ -0,0 +1,39 @@
+import boto3
+import os
+import random
+
+class S3:
+    """Thin wrapper around a boto3 S3 client for s3://bucket/key URIs."""
+
+    def __init__(self, handler):
+        self.handler = handler
+        self.s3 = handler.session.client("s3")
+
+    def parse(self, uri):
+        """Split an s3://bucket/key URI into a (bucket, key) pair."""
+        bucket = uri.split("/")[2]
+        file = os.path.join(*uri.split("/")[3:])
+        return bucket, file
+
+    def download(self, uri, loc=None):
+        """Download `uri` to local path `loc` (a random /tmp file by default)."""
+        loc = loc or f"/tmp/{random.random()}.bin"
+        bucket, key = self.parse(uri)
+        self.s3.download_file(bucket, key, loc)
+        return loc
+
+    def read(self, uri):
+        """Return the object at `uri` decoded as UTF-8 text."""
+        bucket, key = self.parse(uri)
+        file = self.s3.get_object(Bucket=bucket, Key=key)
+        return file['Body'].read().decode('utf-8')
+
+    def upload(self, local, remote):
+        """Upload local file `local` to `remote`; trailing "/" means directory."""
+        # If `remote` is a directory, pick a name for the file
+        if remote.endswith("/"):
+            remote = os.path.join(remote, os.path.basename(local))
+
+        bucket, key = self.parse(remote)
+        return self.s3.upload_file(local, bucket, key)
+
diff --git a/test.py b/test.py
index 2819059..a87c713 100644
--- a/test.py
+++ b/test.py
@@ -1,9 +1,10 @@
-import minerva as m
+import minerva
 import pprint
 
 pp = pprint.PrettyPrinter(indent=4)
 
-athena = m.Athena("hay", "s3://haystac-pmo-athena/")
+m = minerva.Minerva("hay")
+athena = m.athena("s3://haystac-pmo-athena/")
 #query = athena.query(
 #"""SELECT *
 #FROM trajectories.kitware
@@ -12,7 +13,7 @@ athena = m.Athena("hay", "s3://haystac-pmo-athena/")
 #    ST_Point(longitude, latitude)
 #)
 #""")
-query = athena.query("select count(*) as count from trajectories.kitware")
+query = athena.query("select count(*) as count from trajectories.kitware where agent = ?", [4])
 data = query.results()
 pp.pprint(data.head(10))
 print(query.runtime)
diff --git a/test2.py b/test2.py
index eea8b39..ae6fcba 100644
--- a/test2.py
+++ b/test2.py
@@ -1,12 +1,13 @@
-import minerva as m
+import minerva
 import pprint
 
 pp = pprint.PrettyPrinter(indent=4)
 
-red = m.Redshift("hay", "s3://haystac-pmo-athena/",
-                 db="dev",
-                 cluster="redshift-cluster-1")
-query = red.query("select count(*) from myspectrum_schema.kitware")
+m = minerva.Minerva("hay")
+red = m.redshift("s3://haystac-pmo-athena/",
+                 db="dev",
+                 cluster="redshift-cluster-1")
+query = red.query("select count(*) from myspectrum_schema.kitware where agent = 4")
 data = query.results()
 pp.pprint(data.head(10))
 print(query.runtime)