dropped pycache files, renaming others

Ari Brown 2023-08-02 14:51:47 -04:00
parent 8765e3c428
commit 386cafe0dd
11 changed files with 184 additions and 68 deletions

minerva/__init__.py
View file

@@ -1,7 +1,11 @@
 from .minerva import Athena, Execute, Query
+from .redshift import Redshift
+from .parallel import parallel_map

 __all__ = [
     "Execute",
     "Query",
-    "Athena"
+    "Athena",
+    "Redshift",
+    "parallel_map"
 ]

minerva/minerva.py
View file

@@ -6,6 +6,7 @@ import pyarrow as pa
 import pyarrow.dataset
 import pprint
 import datetime
+from .parallel import parallel_map

 pp = pprint.PrettyPrinter(indent=4)
@@ -42,13 +43,13 @@ class Execute:
         self.handler = handler
         self.athena = handler.athena
         self.sql = sql
+        self.info_cache = None
+
+    def query(self):
+        return self.sql

     def run(self):
-        out = os.path.join(self.handler.output,
-                           str(random.random()))
-        config = {"OutputLocation": out}
-        resp = self.athena.start_query_execution(QueryString=self.sql,
-                                                 ResultConfiguration=config)
+        resp = self.athena.start_query_execution(QueryString=self.query())
         self.query_id = resp['QueryExecutionId']
         return resp

@@ -57,75 +58,52 @@ class Execute:
     def info(self):
         res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
-        return res['QueryExecution']
+        self.info_cache = res['QueryExecution']
+        return self.info_cache

-    def execute(self):
-        tiedot = self.info()
-        status = tiedot['Status']['State']
-        while status in ['QUEUED', 'RUNNING']:
+    def finish(self):
+        while (stat := self.info()['Status']['State']) in ['QUEUED', 'RUNNING']:
             time.sleep(5)
-            tiedot = self.info()
-            status = tiedot['Status']['State']
-        return status  # finalized state
+        return stat  # finalized state


-class Query:
+class Query(Execute):
     DATA_STYLE = 'parquet'

-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.athena = handler.athena
-        self.sql = sql
-
-    def run(self):
+    def query(self):
         out = os.path.join(self.handler.output,
                            str(random.random()))
-        config = {"OutputLocation": out}
         query = f"unload ({self.sql}) to {repr(out)} " + \
                 f"with (format = '{self.DATA_STYLE}')"
+        return query

-        resp = self.athena.start_query_execution(QueryString=query,
-                                                 ResultConfiguration=config)
-        self.query_id = resp['QueryExecutionId']
-        return resp
-
-    def status(self):
-        return self.info()['Status']
-
-    def info(self):
-        res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
-        return res['QueryExecution']
-
-    def results(self):
-        tiedot = self.info()
-        status = tiedot['Status']['State']
-        while status in ['QUEUED', 'RUNNING']:
-            time.sleep(5)
-            tiedot = self.info()
-            status = tiedot['Status']['State']
-
+    def manifest_files(self):
+        status = self.finish()
         if status == "SUCCEEDED":
-            # Because we're using `UNLOAD`, we get a manifest of the files
-            # that make up our data.
-            files = self.manifest(tiedot).strip().split("\n")
-            files = [f.strip() for f in files if f.strip()]  # filter empty
-
-            # TODO parallelize this
-            local = [self.handler.download(f) for f in files]
-            self.ds = pa.dataset.dataset(local)
-
-            ms = tiedot['Statistics']['TotalExecutionTimeInMillis']
+            # Track the runtime
+            ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
             self.runtime = datetime.timedelta(seconds=ms / 1000)
-            return self.ds
+
+            # Because we're using `UNLOAD`, we get a manifest of the files
+            # that make up our data.
+            manif = self.info_cache['Statistics']['DataManifestLocation']
+            tmp = self.handler.download(manif)
+            with open(tmp, "r") as f:
+                txt = f.read()
+            files = txt.strip().split("\n")
+            files = [f.strip() for f in files if f.strip()]  # filter empty
+            return files
         else:
            return status  # canceled or error

-    def manifest(self, tiedot):
-        manif = tiedot['Statistics']['DataManifestLocation']
-        tmp = self.handler.download(manif)
-        with open(tmp, "r") as f:
-            return f.read()
+    def results(self):
+        local = parallel_map(self.handler.download, self.manifest_files())
+        self.ds = pa.dataset.dataset(local)
+        return self.ds

21 minerva/parallel.py Normal file
View file

@@ -0,0 +1,21 @@
from joblib import Parallel, delayed


# Instead of taking each object in the list and giving it its own thread,
# this splits the list into `cores` groups and gives each group its own
# thread, where the group is then processed in series within its thread.
#
# Example:
#   def say(stuff):
#       print(stuff)
#
#   parallel_map(say, [str(i) for i in range(10)], cores=4)
def parallel_map(func=None, data=None, cores=8):
    size = max(1, len(data) // cores)  # guard the slice step when len(data) < cores
    groups = [data[i:i + size] for i in range(0, len(data), size)]

    def wrapper_func(fs):
        return [func(f) for f in fs]

    res = Parallel(n_jobs=cores)(delayed(wrapper_func)(group) for group in groups)
    return [val for r in res for val in r]
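
A quick sanity check of the grouping logic (a sketch, not part of the commit; assumes joblib is installed):

    from minerva.parallel import parallel_map

    def double(x):
        return x * 2

    # Ten items on four cores -> size = 10 // 4 = 2, i.e. five groups of two.
    # joblib's Parallel yields group results in submission order, so the
    # flattened output preserves the input order.
    out = parallel_map(double, list(range(10)), cores=4)
    assert out == [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]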

114 minerva/redshift.py Normal file
View file

@@ -0,0 +1,114 @@
import boto3
import json
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
from .parallel import parallel_map

pp = pprint.PrettyPrinter(indent=4)


class Redshift:
    def __init__(self, profile, output, db=None, cluster=None):
        self.session = boto3.session.Session(profile_name=profile)
        self.redshift = self.session.client("redshift-data")
        self.output = output
        self.database = db
        self.cluster = cluster

    def query(self, sql):
        q = Query(self, sql)
        q.run()
        return q

    def execute(self, sql):
        e = Execute(self, sql)
        e.run()
        return e

    def download(self, s3):
        bucket = s3.split("/")[2]
        file = os.path.join(*s3.split("/")[3:])
        tmp = f"/tmp/{random.random()}.bin"
        self.session.client('s3').download_file(bucket, file, tmp)
        return tmp


class Execute:
    """
    Execute does not return results; it runs the SQL and returns the final
    state. It is meant for DDL statements such as CREATE DATABASE/TABLE.
    """

    def __init__(self, handler, sql):
        self.handler = handler
        self.redshift = handler.redshift
        self.sql = sql
        self.info_cache = None

    def query(self):
        return self.sql

    def run(self):
        resp = self.redshift.execute_statement(Sql=self.query(),
                                               Database=self.handler.database,
                                               ClusterIdentifier=self.handler.cluster)
        self.query_id = resp['Id']
        return resp

    def status(self):
        return self.info()['Status']

    def info(self):
        res = self.redshift.describe_statement(Id=self.query_id)
        self.info_cache = res
        return self.info_cache

    def finish(self):
        while (stat := self.status()) in ['SUBMITTED', 'PICKED', 'STARTED']:
            time.sleep(5)
        return stat  # finalized state


class Query(Execute):
    DATA_STYLE = 'parquet'

    def query(self):
        self.out = os.path.join(self.handler.output,
                                str(random.random()))
        query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
                f"iam_role default " + \
                f"format as {self.DATA_STYLE} " + \
                f"manifest"
        return query

    def manifest_files(self):
        status = self.finish()
        if status == "FINISHED":
            # Track the runtime
            self.runtime = self.info_cache['UpdatedAt'] - self.info_cache['CreatedAt']

            # Because we're using `UNLOAD ... MANIFEST`, we get a manifest of
            # the files that make up our data.
            manif = self.out + "manifest"
            tmp = self.handler.download(manif)
            with open(tmp, "r") as f:
                js = json.load(f)

            # Filter empty strings
            files = [e['url'].strip() for e in js['entries'] if e['url'].strip()]
            return files
        else:
            return status  # canceled or error

    def results(self):
        local = parallel_map(self.handler.download, self.manifest_files())
        self.ds = pa.dataset.dataset(local)
        return self.ds

0 minerva/s3.py Normal file
View file

pyproject.toml
View file

@ -3,7 +3,8 @@ name = "minerva"
version = "0.2.1" version = "0.2.1"
description = "Easier access to AWS Athena and Redshift" description = "Easier access to AWS Athena and Redshift"
authors = [ authors = [
"Ari Brown <ari@airintech.com>" "Ari Brown <ari@airintech.com>",
"Roshan Punnoose <roshan.punnoose@jhuapl.edu>"
] ]
packages = [ packages = [
{ include = "minerva/**/*.py"} { include = "minerva/**/*.py"}

View file

@@ -1,9 +1,9 @@
-import minerva.minerva as a
+import minerva
 import pprint

 pp = pprint.PrettyPrinter(indent=4)

-athena = a.Athena("hay", "s3://haystac-pmo-athena/")
+athena = minerva.Athena("hay", "s3://haystac-pmo-athena/")
 #query = athena.query(
 #"""SELECT *
 #FROM trajectories.kitware

View file

@@ -1,11 +1,11 @@
-import minerva.blueshift as b
+import minerva
 import pprint

 pp = pprint.PrettyPrinter(indent=4)

-red = b.Redshift("hay", "s3://haystac-pmo-athena/",
-                 db="dev",
-                 cluster="redshift-cluster-1")
+red = minerva.Redshift("hay", "s3://haystac-pmo-athena/",
+                       db="dev",
+                       cluster="redshift-cluster-1")

 query = red.query("select count(*) from myspectrum_schema.kitware")
 print(query)
 data = query.results()