forked from bellwether/minerva

parameterized queries for athena, but not for redshift.

parent c4e0b71a98
commit 52f44a79fe

6 changed files with 82 additions and 53 deletions
minerva/__init__.py

@@ -1,9 +1,13 @@
 from .parallel import parallel_map
 from .athena import Athena
 from .redshift import Redshift
+from .s3 import S3
+from .minerva import Minerva

 __all__ = [
     "Athena",
     "Redshift",
-    "parallel_map"
+    "parallel_map",
+    "Minerva",
+    "S3"
 ]
minerva/athena.py

@@ -11,45 +11,41 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)

 class Athena:
-    def __init__(self, profile, output):
-        self.session = boto3.session.Session(profile_name=profile)
-        self.athena = self.session.client("athena")
+    def __init__(self, handler, output):
+        self.handler = handler
+        self.client = handler.session.client("athena")
         self.output = output

-    def query(self, sql):
-        q = Query(self, sql)
+    def query(self, sql, params):
+        q = Query(self, sql, params)
         q.run()
         return q

     def execute(self, sql):
-        e = Execute(self, sql)
+        e = Execute(self, sql, params)
         e.run()
         return e

-    def download(self, s3):
-        bucket = s3.split("/")[2]
-        file = os.path.join(*s3.split("/")[3:])
-        tmp = f"/tmp/{random.random()}.bin"
-        self.session.client('s3').download_file(bucket, file, tmp)
-
-        return tmp

 class Execute:
     """
     Execute will not return results, but will execute the SQL and return the final state.
     Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
     """
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.athena = handler.athena
+    def __init__(self, athena, sql, params=[]):
+        self.athena = athena
+        self.handler = athena.handler
+        self.client = athena.client
         self.sql = sql
+        self.params = [str(p) for p in params]
         self.info_cache = None

     def query(self):
         return self.sql

     def run(self):
-        resp = self.athena.start_query_execution(QueryString=self.query())
+        resp = self.client.start_query_execution(QueryString=self.query(),
+                                                 ExecutionParameters=self.params)
         self.query_id = resp['QueryExecutionId']
         return resp
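The new run() path uses Athena's native parameterized queries: boto3's start_query_execution accepts ExecutionParameters, a list of strings bound in order to `?` placeholders in the SQL (hence the str(p) coercion above). Note that Athena.execute keeps the old execute(self, sql) signature while passing params on, so as committed only query() exercises the new path. A minimal standalone sketch of the underlying call (profile and SQL borrowed from test.py; it assumes the workgroup configures an output location, as this code does):

import boto3

session = boto3.session.Session(profile_name="hay")
client = session.client("athena")

resp = client.start_query_execution(
    QueryString="select count(*) as count from trajectories.kitware where agent = ?",
    # strings are bound positionally to the `?` placeholders
    ExecutionParameters=["4"],
)
print(resp["QueryExecutionId"])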
@@ -57,7 +53,7 @@ class Execute:
         return self.info()['Status']['State']

     def info(self):
-        res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
+        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
         self.info_cache = res['QueryExecution']
         return self.info_cache

@@ -74,7 +70,7 @@ class Query(Execute):
     DATA_STYLE = 'parquet'

     def query(self):
-        out = os.path.join(self.handler.output,
+        out = os.path.join(self.athena.output,
                            str(random.random()))
         query = f"unload ({self.sql}) to {repr(out)} " + \
                 f"with (format = '{self.DATA_STYLE}')"
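For illustration, with sql = "select count(*) from trajectories.kitware" and the output root used in test.py, query() renders something like (the random suffix here is made up):

unload (select count(*) from trajectories.kitware) to 's3://haystac-pmo-athena/0.8414709848078965' with (format = 'parquet')

repr(out) is what supplies the single quotes Athena's UNLOAD expects around the S3 target.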
@@ -91,7 +87,8 @@ class Query(Execute):
         # Because we're using `UNLOAD`, we get a manifest of the files
         # that make up our data.
         manif = self.info_cache['Statistics']['DataManifestLocation']
-        tmp = self.handler.download(manif)
+        print(manif)
+        tmp = self.handler.s3.download(manif)
         with open(tmp, "r") as f:
             txt = f.read()

@@ -102,10 +99,9 @@ class Query(Execute):
         else:
             return status # canceled or error

-    # TODO parallelize this
     def results(self):
-        #local = [self.handler.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.download, self.manifest_files())
+        #local = [self.handler.s3.download(f) for f in self.manifest_files()]
+        local = parallel_map(self.handler.s3.download, self.manifest_files())
         self.ds = pa.dataset.dataset(local)
         return self.ds

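results() hands back a pyarrow dataset built over the downloaded Parquet parts, which is what test.py's data.head(10) reads. A minimal sketch of consuming one (the local paths are hypothetical stand-ins for whatever parallel_map downloaded):

import pyarrow.dataset as pads

local = ["/tmp/0.123.bin", "/tmp/0.456.bin"]   # downloaded Parquet parts
ds = pads.dataset(local, format="parquet")

print(ds.head(10))              # first rows, as a pyarrow Table
df = ds.to_table().to_pandas()  # or materialize everything into pandas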

minerva/redshift.py
@@ -12,9 +12,9 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)

 class Redshift:
-    def __init__(self, profile, output, db=None, cluster=None):
-        self.session = boto3.session.Session(profile_name=profile)
-        self.redshift = self.session.client("redshift-data")
+    def __init__(self, handler, output, db=None, cluster=None):
+        self.handler = handler
+        self.client = handler.session.client("redshift-data")
         self.output = output
         self.database = db
         self.cluster = cluster
@@ -29,22 +29,16 @@
         e.run()
         return e

-    def download(self, s3):
-        bucket = s3.split("/")[2]
-        file = os.path.join(*s3.split("/")[3:])
-        tmp = f"/tmp/{random.random()}.bin"
-        self.session.client('s3').download_file(bucket, file, tmp)
-
-        return tmp

 class Execute:
     """
     Execute will not return results, but will execute the SQL and return the final state.
     Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
     """
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.redshift = handler.redshift
+    def __init__(self, redshift, sql):
+        self.redshift = redshift
+        self.handler = redshift.handler
+        self.client = redshift.client
         self.sql = sql
         self.info_cache = None

@@ -52,9 +46,9 @@
         return self.sql

     def run(self):
-        resp = self.redshift.execute_statement(Sql=self.query(),
-                                               Database=self.handler.database,
-                                               ClusterIdentifier=self.handler.cluster)
+        resp = self.client.execute_statement(Sql=self.query(),
+                                             Database=self.redshift.database,
+                                             ClusterIdentifier=self.redshift.cluster)
         self.query_id = resp['Id']
         return resp
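As the commit message says, parameters stop at Athena; the Redshift path still interpolates literals into the SQL (see test2.py below). The Redshift Data API does support binding, via named placeholders rather than Athena's positional `?`. A hedged sketch of what that could look like, not part of this commit (profile, database, and cluster names taken from test2.py):

import boto3

session = boto3.session.Session(profile_name="hay")
client = session.client("redshift-data")

resp = client.execute_statement(
    Sql="select count(*) from myspectrum_schema.kitware where agent = :agent",
    Database="dev",
    ClusterIdentifier="redshift-cluster-1",
    # redshift-data binds named parameters; values are passed as strings
    Parameters=[{"name": "agent", "value": "4"}],
)
print(resp["Id"])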

@@ -62,7 +56,7 @@
         return self.info()['Status']

     def info(self):
-        res = self.redshift.describe_statement(Id=self.query_id)
+        res = self.client.describe_statement(Id=self.query_id)
         self.info_cache = res
         return self.info_cache

@@ -79,7 +73,7 @@ class Query(Execute):
     DATA_STYLE = 'parquet'

     def query(self):
-        self.out = os.path.join(self.handler.output,
+        self.out = os.path.join(self.redshift.output,
                                 str(random.random()))
         query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
                 f"iam_role default " + \
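Unlike the Athena version, Redshift's UNLOAD takes the inner query as a quoted string (hence repr(self.sql)) and needs an IAM role, so the rendered statement starts roughly as follows (random suffix made up; the remaining options are truncated in this hunk):

unload ('select count(*) from myspectrum_schema.kitware') to 's3://haystac-pmo-athena/0.5403023058681398' iam_role default ...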

@@ -97,7 +91,7 @@
         # Because we're using `UNLOAD`, we get a manifest of the files
         # that make up our data.
         manif = self.out + "manifest"
-        tmp = self.handler.download(manif)
+        tmp = self.handler.s3.download(manif)
         with open(tmp, "r") as f:
             js = json.load(f)

@@ -110,8 +104,8 @@

     # TODO parallelize this
     def results(self):
-        #local = [self.handler.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.download, self.manifest_files())
+        #local = [self.handler.s3.download(f) for f in self.manifest_files()]
+        local = parallel_map(self.handler.s3.download, self.manifest_files())
         self.ds = pa.dataset.dataset(local)
         return self.ds

minerva/s3.py (new file)
@@ -0,0 +1,33 @@
+import boto3
+import os
+import random
+
+class S3:
+    def __init__(self, handler):
+        self.handler = handler
+        self.s3 = handler.session.client("s3")
+
+    def parse(self, uri):
+        bucket = uri.split("/")[2]
+        file = os.path.join(*uri.split("/")[3:])
+        return bucket, file
+
+    def download(self, uri, loc=None):
+        loc = loc or f"/tmp/{random.random()}.bin"
+        bucket, key = self.parse(uri)
+        self.s3.download_file(bucket, key, loc)
+        return loc
+
+    def read(self, uri):
+        bucket, key = self.parse(uri)
+        file = self.s3.get_object(Bucket=bucket, Key=key)
+        return file['Body'].read().decode('utf-8')
+
+    def upload(self, local, remote):
+        # If `remote` is a directory, pick a name for the file
+        if remote[-1] == "/":
+            remote = os.path.join(remote, os.path.basename(local))
+
+        bucket, key = self.parse(remote)
+        return self.s3.upload_file(local, bucket, key)
+
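The new class centralizes the s3:// URI handling that Athena.download and Redshift.download previously duplicated. A quick usage sketch (the URI and key are illustrative; any handler exposing a boto3 .session works):

from minerva import Minerva, S3

m = Minerva("hay")   # handler that owns the boto3 session
s3 = S3(m)
local = s3.download("s3://haystac-pmo-athena/some/part.parquet")
text = s3.read("s3://haystac-pmo-athena/some/manifest")
s3.upload(local, "s3://haystac-pmo-athena/uploads/")  # trailing "/" keeps the local basename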

test.py
@@ -1,9 +1,10 @@
-import minerva as m
+import minerva
 import pprint

 pp = pprint.PrettyPrinter(indent=4)

-athena = m.Athena("hay", "s3://haystac-pmo-athena/")
+m = minerva.Minerva("hay")
+athena = m.athena("s3://haystac-pmo-athena/")
 #query = athena.query(
 #"""SELECT *
 #FROM trajectories.kitware
@@ -12,7 +13,7 @@ athena = m.Athena("hay", "s3://haystac-pmo-athena/")
 # ST_Point(longitude, latitude)
 #)
 #""")
-query = athena.query("select count(*) as count from trajectories.kitware")
+query = athena.query("select count(*) as count from trajectories.kitware where agent = ?", [4])
 data = query.results()
 pp.pprint(data.head(10))
 print(query.runtime)

test2.py
@@ -1,12 +1,13 @@
-import minerva as m
+import minerva
 import pprint

 pp = pprint.PrettyPrinter(indent=4)

-red = m.Redshift("hay", "s3://haystac-pmo-athena/",
+m = minerva.Minerva("hay")
+red = m.redshift("s3://haystac-pmo-athena/",
                  db="dev",
                  cluster="redshift-cluster-1")
-query = red.query("select count(*) from myspectrum_schema.kitware")
+query = red.query("select count(*) from myspectrum_schema.kitware where agent = 4")
 data = query.results()
 pp.pprint(data.head(10))
 print(query.runtime)