# Review notes (leading text before the first "diff --git" header).
#
# - This patch had been flattened onto three physical lines; the line breaks
#   are restored here. Blank context lines were placed so that every hunk
#   matches the counts in its @@ header.
# - Leading indentation was lost in the flattened text and is reconstructed
#   as a best guess from the Python structure -- TODO confirm against the
#   target files before applying (or apply with whitespace ignored).
# - athena.py, distribute_results: `df` is bound only inside the loop, so the
#   final `meta=df` raises UnboundLocalError when manifest_files() yields
#   nothing.
# - athena.py, distribute_results: `pd` is not imported in any hunk shown
#   here; presumably it is imported above line 6 -- verify.
# - athena.py, distribute_results: the early return hands back self.ds, which
#   the hunk context shows being assigned from pa.dataset.dataset(...), while
#   the normal path returns dd.from_delayed(...). Callers can therefore get
#   two different types.
# - athena.py, distribute_results: files are read with pd.read_csv although
#   the query is built with format = DATA_STYLE; NOTE(review): confirm that
#   DATA_STYLE is a CSV-compatible format.
# - cluster.py: connect() is removed; check for remaining callers.

diff --git a/minerva/athena.py b/minerva/athena.py
index 5073a92..a39e942 100644
--- a/minerva/athena.py
+++ b/minerva/athena.py
@@ -6,6 +6,7 @@ import pyarrow as pa
 import pyarrow.dataset
 import pprint
 import datetime
+import dask.dataframe as dd
 from minerva import parallel_map

 pp = pprint.PrettyPrinter(indent=4)
@@ -16,12 +17,14 @@ class Athena:
         self.client = handler.session.client("athena")
         self.output = output

+
     # For when you want to receive the results of something
     def query(self, sql, params=[]):
         q = Query(self, sql, params)
         q.run()
         return q

+
     # For when you want to send a query to run, but there aren't results
     # (like a DML query for creating databases and tables etc)
     def execute(self, sql, params=[]):
@@ -29,6 +32,7 @@ class Athena:
         e.run()
         return e

+
 class Execute:
     """
     Execute will not return results, but will execute the SQL and return the final state.
@@ -44,10 +48,12 @@ class Execute:
         self.temps = []
         self.ds = None

+
     # The string of the query
     def query(self):
         return self.sql

+
     # Send the SQL to Athena for running
     def run(self):
         if self.__class__ == Query:
@@ -66,16 +72,19 @@ class Execute:
         self.query_id = resp['QueryExecutionId']
         return resp

+
     # The status of the SQL (running, queued, succeeded, etc.)
     def status(self):
         return self.info()['Status']['State']

+
     # Get the basic information on the SQL
     def info(self):
         res = self.client.get_query_execution(QueryExecutionId=self.query_id)
         self.info_cache = res['QueryExecution']
         return self.info_cache

+
     # Block until the SQL has finished running
     def finish(self):
         stat = self.status()
@@ -104,6 +113,7 @@ class Query(Execute):
                 f"with (format = '{self.DATA_STYLE}')"
         return query

+
     # Gets the files that are listed in the manifest (from the UNLOAD part of
     # the statement)
     # Blocks until the query has finished (because it calls `self.finish()`)
@@ -128,6 +138,7 @@ class Query(Execute):
             raise Exception(self.info_cache['Status']['AthenaError']['ErrorMessage'])
             #return status # canceled or error

+
     # Blocks until the query has finished running and then returns you a pyarrow
     # dataset of the results.
     # Calls `self.manifest_files()` which blocks via `self.finish()`
@@ -140,17 +151,34 @@ class Query(Execute):
         self.ds = pa.dataset.dataset(self.temps)
         return self.ds

+
+    def distribute_results(self, client):
+        if self.ds:
+            return self.ds
+
+        futures = []
+        for fn in self.manifest_files():
+            df = pd.read_csv(fn)
+            future = client.scatter(df)
+            futures.append(future)
+
+        return dd.from_delayed(futures, meta=df)
+
+
     # Return scalar results
     # Abstracts away a bunch of keystrokes
     def scalar(self):
         return self.results().head(1)[0][0].as_py()

+
     def __enter__(self):
         return self

+
     def __exit__(self, exception_type, exception_value, exception_traceback):
         self.close()

+
     def close(self):
         if self.temps:
             for file in self.temps:
diff --git a/minerva/cluster.py b/minerva/cluster.py
index 02ffbc5..8391a84 100644
--- a/minerva/cluster.py
+++ b/minerva/cluster.py
@@ -40,9 +40,6 @@ class Cluster:
         self.create()
         self.login()
         self.start_dask()
-        #self.connect()
-
-        #return self.client


     # Begin the startup process in the background
@@ -69,11 +66,6 @@ class Cluster:
             w.cmd(f"dask worker {self.scheduler.private_ip}:8786", disown=True)


-    def connect(self):
-        self.client = Client(self.location)
-        return self.client
-
-
     def terminate(self):
         self.scheduler.terminate()
         for w in self.workers:
diff --git a/minerva/machine.py b/minerva/machine.py
index c0be007..9647dd0 100644
--- a/minerva/machine.py
+++ b/minerva/machine.py
@@ -75,6 +75,8 @@ class Machine:
         self.thread.join()

     def wait(self, n):
+        time.sleep(n) # give time for AWS to register that the instance has been created
+
         i = 0
         # Time for the server to show as "running"
         # and time for the server to finish getting daemons running