diff --git a/TODO.md b/TODO.md
index d37098b..242df33 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,2 +1,4 @@
-* add lambda support
-* add outfile tracking to docker containers and instances and docker groups
+* start docker containers via method on Machine()
+  * control and automate logfiles
+  * add logfile for `docker events`
+
diff --git a/minerva/__init__.py b/minerva/__init__.py
index b0bd4ff..67302d2 100644
--- a/minerva/__init__.py
+++ b/minerva/__init__.py
@@ -3,12 +3,14 @@ from .helpers import parallel_map, load_template, load_sql, cluster_pool, AWS_IN
 from .athena import Athena
 from .redshift import Redshift
 from .s3 import S3
+from .lambda_func import Lambda
 from .docker import Docker
-from .remote import Remote
+from .remote import Remote, Logset
 from .machine import Machine
 from .pier import Pier
-from .pool import Pool, TempOuts
+from .pool import Pool
+from .command import Command
 from .timing import Timing
@@ -29,7 +31,9 @@ __all__ = [
     "AWS_INSTANCES",
     "Pool",
     "Remote",
-    "TempOuts",
-    "Timing"
+    "Logset",
+    "Timing",
+    "Lambda",
+    "Command"
 ]
diff --git a/minerva/command.py b/minerva/command.py
new file mode 100644
index 0000000..bc31120
--- /dev/null
+++ b/minerva/command.py
@@ -0,0 +1,139 @@
+import threading
+import select
+import io
+import sys
+import minerva.remote
+import tempfile
+
+
+def flush_data(data, pipe, display=None):
+    pipe.write(data)
+    pipe.flush()
+    if display:
+        display.write(data)
+        display.flush()
+
+
+class Command:
+    def __init__(self, machine, command, disown=False, watch=False, logset=(None, None)):
+        self.machine = machine
+        self.command = command
+        self.disown = disown
+        self.watch = watch
+        self.logset = logset
+        self.thread = None
+
+
+    # Unfortunately, under the hood, it's running /bin/bash -c '...'
+    # You stand informed
+    #
+    # This creates a pseudo-TTY on the other end
+    #
+    # `watch` means it'll print the output live; either way the logfile
+    # paths end up in self.logset and the reader thread in self.thread
+    # `disown` means it'll run in the background
+    #
+    # https://github.com/paramiko/paramiko/issues/593#issuecomment-145377328
+    #
+    def execute(self):
+        stdin, stdout, stderr = self.machine.ssh.client.exec_command(self.command)
+
+        # this is the same for all three inputs
+        channel = stdin.channel
+
+        # a regular TemporaryFile doesn't work here, even with explicit
+        # flush(): it has no usable name on the filesystem, so the readers
+        # can't reopen it later. NamedTemporaryFile(delete=False) leaves a
+        # real file on disk.
+        #
+        # A flush is required after every write
+        # Leave the files so that the readers can work even after the writers
+        # are done
+        #
+        # Thanks to SirDonNick in #python for the help here
+
+        # Support for passing a logset object or manually specifying the outfiles
+        out = self.logset[0] or tempfile.NamedTemporaryFile(delete=False)
+        err = self.logset[1] or tempfile.NamedTemporaryFile(delete=False)
+
+        # just in case it was (None, None) when passed in, we want to save the
+        # outputs. This admittedly changes the datatype of what was supplied,
+        # from open file handles to strings
+        self.logset = (out.name, err.name)
+
+        # Taken from
+        # https://stackoverflow.com/a/78765054
+        # and then improved/cleaned up
+
+        # we do not need stdin
+        stdin.close()
+        # indicate that we're not going to write to that channel anymore
+        channel.shutdown_write()
+
+        timeout = 60
+
+        def fill_buffers(out, err):
+            # perform chunked reads to prevent stalls
+            while (not channel.closed
+                   or channel.recv_ready()
+                   or channel.recv_stderr_ready()):
+                # stop if the channel was closed prematurely and the buffers are empty
+                got_chunk = False
+
+                readq, _, _ = select.select([channel], [], [], timeout)
+
+                # select() returns three empty lists on timeout
+                if not readq:
+                    break
+                for c in readq:
+                    if c.recv_ready():
+                        flush_data(channel.recv(len(c.in_buffer)),
+                                   out,
+                                   (self.watch and sys.stdout.buffer))
+                        got_chunk = True
+
+                    if c.recv_stderr_ready():
+                        flush_data(channel.recv_stderr(len(c.in_stderr_buffer)),
+                                   err,
+                                   (self.watch and sys.stderr.buffer))
+                        got_chunk = True
+                # for c
+
+                """
+                1) make sure that there are at least 2 cycles with no data in the
+                   input buffers in order to not exit too early; e.g., `cat` on a
+                   >200k file
+                2) if no data arrived in the last loop, check if we received the exit code
+                3) check if the input buffers are empty
+                4) exit the loop
+                """
+                if (not got_chunk
+                        and channel.exit_status_ready()
+                        and not channel.recv_stderr_ready()
+                        and not channel.recv_ready()):
+                    # indicate that we're not going to read from this channel anymore
+                    channel.shutdown_read()
+                    # close the channel
+                    channel.close()
+                    # remote side is finished and our buffers are empty
+                    break
+                # if
+
+            # Don't close these because we want to reuse logfiles
+            #out.close()
+            #err.close()
+
+            # while
+
+            # close the pseudofiles
+            stdout.close()
+            stderr.close()
+
+        self.thread = threading.Thread(target = fill_buffers,
+                                       args = (out, err))
+        self.thread.start()
+
+        if not self.disown:
+            print(f"running: {self.command}")
+            print(self.logset)
+            self.thread.join()
+
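A minimal usage sketch for the new Command class, assuming `mach` is any Remote/Machine whose `ssh` connection is already open (the command and variable names are illustrative):

    from minerva import Command

    # run a command, mirroring its output locally while it runs
    cmd = Command(mach, "uname -a", watch=True)
    cmd.execute()                    # joins the reader thread unless disown=True

    out_path, err_path = cmd.logset  # after execute(), logset holds file paths
    with open(out_path, "rb") as f:
        print(f.read().decode())

Equivalently, `mach.cmd("uname -a", watch=True)` builds and executes the Command and also records it in `mach.history` (see minerva/remote.py below).
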
diff --git a/minerva/docker.py b/minerva/docker.py
index 4132485..c1c2986 100644
--- a/minerva/docker.py
+++ b/minerva/docker.py
@@ -1,13 +1,12 @@
 import threading
 
 class Docker:
-    def __init__(self, container, machine=None, variables={}, stdout=None, stderr=None):
+    def __init__(self, container, machine=None, variables={}, logset=(None, None)):
         self.machine = machine
         self.uri = container
         self.variables = variables
         self.finished = False
-        self.stdout = stdout
-        self.stderr = stderr
+        self.logset = logset
         self.out = {"stdout": None, "stderr": None}
 
         self.registry = container.split("/")[0]
@@ -42,12 +41,12 @@ class Docker:
             self.machine.login()
 
         if self.registry.endswith("amazonaws.com"):
-            self.machine.aws_docker_login(self.registry)
+            self.machine.aws_docker_login(self.registry, logset=self.logset)
 
         res = self.machine.docker_run(self.uri,
                                       cmd = cmd,
                                       env = self.variables,
-                                      output = (self.stdout, self.stderr))
+                                      logset = self.logset)
 
         #self.out["stdout"] = res[0].name
         #self.out["stderr"] = res[1].name
@@ -61,6 +60,8 @@ class Docker:
         self.finished = True
         print(f"finished on {self.machine.name}")
 
+        return res
+
     def terminate(self):
         self.machine.terminate()
 
diff --git a/minerva/lambda.py b/minerva/lambda.py
deleted file mode 100644
index a9889e8..0000000
--- a/minerva/lambda.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import json
-
-class Lambda:
-    def __init__(self, handler, name):
-        self.handler = handler
-        self.name = name
-        self.client = handler.session.client("lambda")
-
-    def invoke(self, payload):
-        self.client.invoke(InvocationType = "RequestResponse",
-                           FunctionName = self.name,
-                           Payload = json.dumps(payload) or "{}")
-
diff --git a/minerva/lambda_func.py b/minerva/lambda_func.py
new file mode 100644
index 0000000..0112e72
--- /dev/null
+++ b/minerva/lambda_func.py
@@ -0,0 +1,21 @@
+import json
+
+# TODO change the default timeout
+class Lambda:
+    def __init__(self, handler, name):
+        self.handler = handler
+        self.name = name
+        self.client = handler.session.client("lambda")
+
+    def invoke(self, payload, asynchronous=False):
+        asyn = (asynchronous and "Event") or "RequestResponse"
+
+        # json.dumps never returns a falsy string, so default the payload
+        # itself rather than the serialized result
+        return self.client.invoke(InvocationType = asyn,
+                                  FunctionName = self.name,
+                                  Payload = json.dumps(payload or {}))
+
+
+    def update(self, **kwargs):
+        return self.client.update_function_code(FunctionName = self.name,
+                                                **kwargs)
+
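For reference, `invoke` here maps directly onto the standard boto3 call: `InvocationType="RequestResponse"` blocks and returns the function's result, while `"Event"` only queues the invocation. A minimal standalone sketch (the function name is hypothetical):

    import json
    import boto3

    client = boto3.client("lambda")

    # synchronous: waits for the function and returns its result in "Payload"
    res = client.invoke(InvocationType = "RequestResponse",
                        FunctionName = "my-function",
                        Payload = json.dumps({"key": "value"}))
    print(json.load(res["Payload"]))

    # asynchronous: queues the event and returns immediately (StatusCode 202)
    client.invoke(InvocationType = "Event",
                  FunctionName = "my-function",
                  Payload = json.dumps({"key": "value"}))
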
diff --git a/minerva/machine.py b/minerva/machine.py
index 4b6502c..b6058f2 100644
--- a/minerva/machine.py
+++ b/minerva/machine.py
@@ -3,6 +3,7 @@ import math
 import datetime
 #from pexpect import pxssh
 from fabric import Connection
+import paramiko.ssh_exception
 import shlex
 import threading
 import os
@@ -16,11 +17,13 @@ class Machine(minerva.Remote):
                  instance_type = "t2.micro",
                  variables = {},
                  username = None,
-                 key_pair = None,
+                 key_pair = (None, None),
                  name = "Minerva Instance",
                  public = True,
                  disk_size = 8):
 
+        super().__init__(None, username, key_pair[1], name)
+
         self.pier = pier
         self.ami = ami
         self.instance_type = instance_type
@@ -36,6 +39,7 @@ class Machine(minerva.Remote):
         self.terminated = False
         self.public = public
         self.disk_size = disk_size
+        self.ip = None # tracking which IP we're using for our connection
 
     def create(self):
@@ -46,7 +50,7 @@ class Machine(minerva.Remote):
         res = self.pier.ec2.run_instances(
             ImageId = self.ami,
             InstanceType = self.instance_type,
-            KeyName = self.key_pair or self.pier.key_pair_name,
+            KeyName = self.key_pair[0] or self.pier.key_pair_name,
             MinCount = 1,
             MaxCount = 1,
             TagSpecifications = [{'ResourceType': 'instance',
@@ -103,10 +107,7 @@ class Machine(minerva.Remote):
                 reason = f"{self.info['InstanceId']} took too long to start ({i} attempts)"
                 raise Exception(reason)
 
-        # Final wait, now that the server is up and running -- need
-        # some time for daemons to start
-        time.sleep(35)
-        self.ready = True
+        self.started = datetime.datetime.now()
 
     # alternatively, could maybe implement this with SSM so that we can access
@@ -127,15 +128,29 @@ class Machine(minerva.Remote):
         print(f"\t{self.name} ({self.info['InstanceId']}\t- {self.instance_type}) => {self.public_ip} ({self.private_ip})")
 
-        ip = self.public_ip or self.private_ip
-        self.ssh = Connection(ip,
+        self.ip = self.public_ip or self.private_ip
+        self.ssh = Connection(self.ip,
                               self.username,
                               connect_kwargs = {
-                                  "key_filename": self.pier.key_path
+                                  "key_filename": self.key_pair[1] #self.pier.key_path
                               }
                               )
-        self.ssh.open()
-        self.started = datetime.datetime.now()
+
+        i = 0
+        max_wait = 120
+        # Time for the server to get SSH up and running
+        while True:
+            try:
+                self.ssh.open()
+                break
+
+            except paramiko.ssh_exception.NoValidConnectionsError:
+                time.sleep(10)
+                i += 1
+
+                if i > (max_wait / 10):
+                    reason = f"{self.info['InstanceId']} took too long to start ssh ({i} attempts)"
+                    raise Exception(reason)
 
         return True
diff --git a/minerva/minerva.py b/minerva/minerva.py
index 1480817..231afde 100644
--- a/minerva/minerva.py
+++ b/minerva/minerva.py
@@ -28,3 +28,7 @@ class Minerva:
         return m.Pier(self, *args, **kwargs)
 
+    def lambda_func(self, *args, **kwargs):
+        return m.Lambda(self, *args, **kwargs)
+
+
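The retry loop added to `connect()` above covers the gap between an EC2 instance reporting `running` and sshd actually accepting connections. The same pattern as a standalone helper, using the timings from the diff (`open_with_retry` and `conn` are illustrative names):

    import time
    import paramiko.ssh_exception

    def open_with_retry(conn, delay=10, max_wait=120):
        # conn: a fabric.Connection; retry open() until sshd answers
        attempts = 0
        while True:
            try:
                conn.open()
                return
            except paramiko.ssh_exception.NoValidConnectionsError:
                attempts += 1
                if attempts > max_wait / delay:
                    raise Exception(f"ssh took too long to start ({attempts} attempts)")
                time.sleep(delay)
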
diff --git a/minerva/pier.py b/minerva/pier.py
index ebb64e0..24de4cf 100644
--- a/minerva/pier.py
+++ b/minerva/pier.py
@@ -7,6 +7,8 @@ import pprint
 from minerva.machine import Machine
 from minerva.cluster import Cluster
+from fabric import Connection
+
 pp = pprint.PrettyPrinter(indent=4)
 
 # Used for interacting with AWS
@@ -51,6 +53,31 @@ class Pier:
         return Machine(self, **kwargs)
 
+    # TODO make this fetch the instance id from the IP (filtered from the list
+    # of all instances)
+    def from_ip(self, ip, username, key_path=None, name=None):
+        mach = Machine(self, username=username)
+
+        mach.username = username
+        mach.public_ip = ip
+        mach.ip = ip
+
+        if key_path:
+            mach.key_pair = ("", key_path)
+        else:
+            mach.key_pair = (self.key_pair_name, self.key_path)
+
+        mach.ssh = Connection(ip,
+                              username,
+                              connect_kwargs = {
+                                  "key_filename": mach.key_pair[1]
+                              }
+                              )
+        mach.ssh.open()
+
+        return mach
+
+
     def t3_med(self, num):
         return self.machine(ami = "ami-0a538467cc9da9bb2",
                             instance_type = "t3.medium",
diff --git a/minerva/pool.py b/minerva/pool.py
index 3724592..5ebc23e 100644
--- a/minerva/pool.py
+++ b/minerva/pool.py
@@ -74,30 +74,3 @@ class Pool:
         return sum([mach.cost() for mach in self.machines])
 
-class TempOuts:
-    def __init__(self, directory, prefix):
-        self.directory = directory
-        self.prefix = prefix
-        self.stdout = None
-        self.stderr = None
-
-
-    def __enter__(self):
-        try:
-            os.mkdir(self.directory)
-        except:
-            pass
-
-        path = os.path.join(self.directory, self.prefix)
-
-        self.stdout = open(f"{path}_stdout.out", "ab")
-        self.stderr = open(f"{path}_stderr.out", "ab")
-
-        return (self.stdout, self.stderr)
-
-
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.stdout.close()
-        self.stderr.close()
-
-
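TempOuts' job is taken over by the per-machine Logset added in minerva/remote.py below; a sketch of the replacement pattern, assuming `machine` is a connected Machine (the job id and image are illustrative):

    # old pattern (removed above):
    #   with TempOuts("/tmp/outs", "job1") as (out, err):
    #       machine.cmd("docker pull ubuntu", output=(out, err))

    # new pattern: logfiles land under /tmp/<machine ip>/<job id>/
    with machine.stream_logs(job_id="job1") as logset:
        machine.cmd("docker pull ubuntu", logset=logset)
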
diff --git a/minerva/remote.py b/minerva/remote.py
index 87ddfcc..86d4fb0 100644
--- a/minerva/remote.py
+++ b/minerva/remote.py
@@ -1,29 +1,25 @@
 from fabric import Connection
 import os
-import sys
-import threading
-import select
-import tempfile
-import io
 import shlex
+import random
 
-def flush_data(data, pipe, display=None):
-    pipe.write(data)
-    pipe.flush()
-    if display:
-        display.write(data)
-        display.flush()
+import minerva.docker as d
+import minerva.command as c
 
 # Bare machine, not necessarily associated with AWS
 class Remote:
     def __init__(self,
                  ip,
                  username,
-                 key_path):
+                 key_path,
+                 name = None):
         self.ip = ip
+        self.name = name or ip
         self.username = username
-        self.key_path = os.path.expanduser(key_path) # full path
+        self.key_path = key_path and os.path.expanduser(key_path) # full path
         self.ssh = None
+        self.logsets = []
+        self.history = []
 
     def login(self):
         if self.ssh:
@@ -44,127 +40,22 @@ class Remote:
         return "; ".join([*base, *varz])
 
-    # Unfortunately, under the hood, it's running /bin/bash -c '...'
-    # You stand informed
-    #
-    # This creates a pseudo-TTY on the other end
-    #
-    # `watch` means it'll print the output live, else it'll return the
-    # output (stdout, stderr) streams and the thread
-    # `disown` means it'll run in the background
-    #
-    # https://github.com/paramiko/paramiko/issues/593#issuecomment-145377328
-    #
-    def cmd(self, command, hide=True, disown=False, watch=False, output=(None, None)):
-        # TODO this is necessary to load paramiko details
-        #self.ssh.run("echo hello world", warn=True, hide=hide, disown=disown)
-
-        stdin, stdout, stderr = self.ssh.client.exec_command(command)
-
-        # this is the same for all three inputs
-        channel = stdin.channel
-
-        # regular TemporaryFile doesn't work for some reason, even with
-        # explicit flush(). I think it's because it doesn't actually create
-        # a file on disk until enough input has been gathered.
-        #
-        # A flush is required after every write
-        # Leave the files so that the readers can work even after the writers
-        # are done
-        #
-        # Thanks to SirDonNick in #python for the help here
-        out = output[0] or tempfile.NamedTemporaryFile(delete=False)
-        err = output[1] or tempfile.NamedTemporaryFile(delete=False)
-
-        print(command)
-        print(f"\t{out.name} -- {err.name}")
-
-        # Taken from
-        # https://stackoverflow.com/a/78765054
-        # and then improved/cleaned up
-
-        # we do not need stdin.
-        stdin.close()
-        # indicate that we're not going to write to that channel anymore
-        channel.shutdown_write()
-
-        ## read stdout/stderr to prevent read block hangs
-        #flush_data(channel.recv(len(channel.in_buffer)),
-        #           out,
-        #           (watch and sys.stdout.buffer))
-
-        #flush_data(channel.recv_stderr(len(channel.in_stderr_buffer)),
-        #           err,
-        #           (watch and sys.stderr.buffer))
-
-        timeout = 60
-
-        def fill_buffers(out, err):
-            # perform chunked read to prevent stalls
-            while (not channel.closed
-                   or channel.recv_ready()
-                   or channel.recv_stderr_ready()):
-                # stop if channel was closed prematurely and buffers are empty
-                got_chunk = False
-
-                readq, _, _ = select.select([channel], [], [], timeout)
-
-                # returns three empty lists on timeout
-                if not readq:
-                    break
-                for c in readq:
-                    if c.recv_ready():
-                        flush_data(channel.recv(len(c.in_buffer)),
-                                   out,
-                                   (watch and sys.stdout.buffer))
-                        got_chunk = True
-
-                    if c.recv_stderr_ready():
-                        flush_data(channel.recv_stderr(len(c.in_stderr_buffer)),
-                                   err,
-                                   (watch and sys.stderr.buffer))
-                        got_chunk = True
-                # for c
-
-                """
-                1) make sure that there are at least 2 cycles with no data in the input
-                buffers in order to not exit too early; i.e., cat on a >200k file
-                2) if no data arrived in the last loop, check if we received exit code
-                3) check if input buffers are empty
-                4) exit the loop
-                """
-                if (not got_chunk
-                        and channel.exit_status_ready()
-                        and not channel.recv_stderr_ready()
-                        and not channel.recv_ready()):
-                    # indicate that we're not going to read from this channel anymore
-                    channel.shutdown_read()
-                    # close the channel
-                    channel.close()
-                    # remote side is finished and our buffers are empty
-                    break
-                # if
-                out.close()
-                err.close()
-            # while
-
-            # close the pseudofiles
-            stdout.close()
-            stderr.close()
-
-        thread = threading.Thread(target = fill_buffers,
-                                  args = (out, err))
-        thread.start()
-
-        if not disown:
-            thread.join()
-
-        return (open(out.name, "rb"), open(err.name, "rb"), thread)
+    def cmd(self, command, disown = False,
+            watch = False,
+            logset = (None, None)):
 
+        command = c.Command(self, command, disown = disown,
+                            watch = watch,
+                            logset = logset)
+        self.history.append(command)
+        command.execute()
 
+        return command
 
-    def write_env_file(self, variables, fname="~/env.list", output=(None, None)):
+    # maybe turn this into a `cat > filename` and write directly to stdin
+    def write_env_file(self, variables, fname="~/env.list", logset=(None, None)):
         vals = "\n".join([f"{var}={val}" for var, val in variables.items()])
-        self.cmd(f"echo {shlex.quote(vals)} > {fname}", output=output)
+        self.cmd(f"echo {shlex.quote(vals)} > {fname}", logset=logset)
 
         return fname
 
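For the record, what `write_env_file` ships over the wire: `shlex.quote` wraps the whole multi-line blob in single quotes, so the remote shell writes it verbatim. A local sketch with illustrative variables:

    import shlex

    variables = {"AWS_REGION": "us-east-1", "JOB_NAME": "nightly build"}
    vals = "\n".join([f"{var}={val}" for var, val in variables.items()])
    print(f"echo {shlex.quote(vals)} > ~/env.list")
    # echo 'AWS_REGION=us-east-1
    # JOB_NAME=nightly build' > ~/env.list
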
@@ -173,22 +64,70 @@ class Remote:
         return docker
 
-    def aws_docker_login(self, ecr, output=(None, None)):
+    def aws_docker_login(self, ecr, logset=(None, None)):
         return self.cmd(f"aws ecr get-login-password --region {self.pier.session.region_name} | " +
                         f"docker login --username AWS --password-stdin {ecr}",
-                        output=output)
+                        logset=logset)
 
-    def docker_run(self, uri, cmd="", env={}, output=(None, None)):
+    def docker_run(self, uri, cmd="", env={}, logset=(None, None)):
         if env:
-            fname = self.write_env_file(env)
+            fname = self.write_env_file(env, logset=logset)
             environ = f"--env-file {fname}"
         else:
             environ = ""
 
-        return self.cmd(f"docker run -t {environ} {uri} {cmd}", output=output)
+        return self.cmd(f"docker run -t {environ} {uri} {cmd}", logset=logset)
 
-    def docker_pull(self, uri, output=(None, None)):
-        return self.cmd(f"docker pull {uri}", output=output)
+    def docker_pull(self, uri, logset=(None, None)):
+        return self.cmd(f"docker pull {uri}", logset=logset)
+
+
+    def docker(self, *args, **kwargs):
+        return d.Docker(machine = self, *args, **kwargs)
+
+
+    def stream_logs(self, job_id=None, hold_open=False):
+        ls = Logset(self, job_id, hold_open)
+        self.logsets.append(ls)
+        return ls
+
+
+    def track_docker_events(self):
+        with self.stream_logs(hold_open = True) as logset:
+            print(f"docker events at: {logset[0].name}")
+            self.docker_events = self.cmd("docker events", disown=True, logset=logset)
+
+
+class Logset:
+    def __init__(self, machine, job_id=None, hold_open=False):
+        self.job_id = job_id or "job-%0.6f" % random.random()
+        self.machine = machine
+        self.directory = f"/tmp/{machine.ip}"
+        self.stdout = None
+        self.stderr = None
+        self.hold_open = hold_open # useful for getting a logset for a background command
+
+
+    def __enter__(self):
+        try:
+            os.mkdir(self.directory)
+        except FileExistsError:
+            pass
+
+        # this had better not already exist
+        path = os.path.join(self.directory, self.job_id)
+        os.mkdir(path)
+
+        self.stdout = open(os.path.join(path, "stdout.out"), "ab")
+        self.stderr = open(os.path.join(path, "stderr.out"), "ab")
+
+        return (self.stdout, self.stderr)
+
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if not self.hold_open:
+            self.stdout.close()
+            self.stderr.close()
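
To close, a sketch of Logset's two modes, mirroring how `track_docker_events` uses it above (`machine` and the job ids are illustrative):

    from minerva import Logset

    # default: both logfiles are closed when the block exits
    with Logset(machine, job_id="job-42") as (out, err):
        machine.cmd("ls -la", logset=(out, err))

    # hold_open=True: the handles survive __exit__, so a disowned background
    # command can keep writing after the block ends
    with Logset(machine, job_id="events", hold_open=True) as logset:
        machine.cmd("docker events", disown=True, logset=logset)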