# Source code for profit.run.zeromq

""" ZeroMQ Interface

Ideas & help from the 0MQ Guide (zguide.zeromq.org; its examples are licensed under MIT)
"""

import zmq
import numpy as np
import json
from time import sleep
from logging import Logger
import os

from .interface import RunnerInterface, WorkerInterface


# === ZeroMQ Interface === #


class ZeroMQRunnerInterface(RunnerInterface, label="zeromq"):
    """Runner-Worker Interface using the lightweight message queue `ZeroMQ <https://zeromq.org/>`_

    - can use different transport systems, most commonly tcp
    - can be used efficiently on a cluster (tested)
    - expected to be inefficient for a large number of small, locally run simulations where communication
      overhead is a concern (unverified, could be mitigated by using a different transport system)
    - known issue: some workers were unable to establish a connection with three tries, reason unknown

    Parameters:
        transport: ZeroMQ transport protocol
        address: override ip address or hostname of the Runner Interface (default: localhost, automatic with Slurm)
        port: port of the Runner Interface
        connection: override for the ZeroMQ connection spec (Worker side)
        bind: override for the ZeroMQ bind spec (Runner side)
        timeout: connection timeout when waiting for an answer in seconds (Worker)
        retries: number of retries after a failed request (Worker); the Worker makes up to
            ``retries + 1`` attempts in total
        retry_sleep: sleep time in seconds between each retry (Worker)

    Attributes:
        socket (zmq.Socket): ZeroMQ backend
        logger (logging.Logger): Logger
    """

    def __init__(
        self,
        size,
        input_config,
        output_config,
        *,
        transport="tcp",
        address=None,
        port=9000,
        connection=None,
        bind=None,
        timeout=4,
        retries=3,
        retry_sleep=1,
        logger_parent: Logger = None,
    ):
        # FLAGS is a bit field written by handle_msg. ``internal_vars`` is read before
        # super().__init__ runs, so it is presumably a class-level list shared with the base
        # class. Build a new list instead of extending it in place with ``+=``; extending it
        # would add FLAGS to every other interface using the same list.
        if "FLAGS" not in [var[0] for var in self.internal_vars]:
            self.internal_vars = self.internal_vars + [("FLAGS", np.byte.__name__)]
        super().__init__(size, input_config, output_config, logger_parent=logger_parent)
        self.transport = transport
        self.address = address
        self.port = port
        self.connection = connection
        self._bind = bind
        self.timeout = timeout
        self.retries = retries
        self.retry_sleep = retry_sleep

        self.socket = zmq.Context.instance().socket(zmq.ROUTER)
        try:
            self.socket.bind(self.bind)
        except zmq.ZMQError:
            # e.g. the port is already in use: do not leak the freshly created socket
            self.socket.close(linger=0)
            raise
        self.logger.info(f"connected to {self.bind}")

    @property
    def bind(self):
        """ZeroMQ bind spec of the Runner: the ``bind`` override if given, else ``<transport>://*:<port>``."""
        if self._bind is None:
            return f"{self.transport}://*:{self.port}"
        else:
            return self._bind

    @bind.setter
    def bind(self, value):
        self._bind = value

    @property
    def config(self):
        """Configuration of this interface merged on top of the base class configuration."""
        config = {
            "transport": self.transport,
            "address": self.address,
            "port": self.port,
            "connection": self.connection,
            "bind": self._bind,
            "timeout": self.timeout,
            "retries": self.retries,
            "retry_sleep": self.retry_sleep,
        }
        return {**super().config, **config}  # super().config | config in python3.9

    def poll(self):
        """Handle all messages that are currently queued, without blocking."""
        self.logger.debug("polling: checking for messages")
        # poll does not wait for messages (timeout=10ms as 0 means wait forever)
        # waiting should be done with the runner (sleep)
        # this allows the runner to react to messages immediately
        while self.socket.poll(timeout=10, flags=zmq.POLLIN):
            msg = self.socket.recv_multipart()
            # ToDo: Heartbeats
            # ROUTER frames: [identity, b"" (delimiter), request, payload...]
            self.handle_msg(msg[0], msg[2:])

    def handle_msg(self, address: bytes, msg: list):
        """Handle one request from a Worker and reply to it.

        Parameters:
            address: ZeroMQ identity of the sending Worker; Workers identify as ``req_<run_id>``
            msg: message frames after the delimiter: the request (``READY``, ``DATA``, ``TIME`` or
                ``DIE``) optionally followed by a payload frame

        Bits set in ``internal["FLAGS"][run_id]``: 0x02 after the input was sent (READY),
        0x04 when the Worker reported DIE, 0x08 after the output was received (DATA).
        Malformed messages are logged and left unanswered, so the Worker's retry logic takes over.
        """
        run_id = self._parse_run_id(address)
        if run_id is None:
            self.logger.warning(
                f"received message from unknown client {address}: {msg}"
            )
            return
        if not msg:
            self.logger.warning(f"received unknown message {address}: {msg}")
            return

        request = msg[0]
        if request == b"READY":
            # dtype descriptions are sent as JSON so the Worker can rebuild the structured records
            input_descr = json.dumps(self.input_vars).encode()
            output_descr = json.dumps(self.output_vars).encode()
            self.logger.debug(
                f"run {run_id} READY: {input_descr} + {self.input[run_id]} + output {output_descr}"
            )
            self.socket.send_multipart(
                [address, b"", input_descr, self.input[run_id], output_descr]
            )
            self.internal["FLAGS"][run_id] |= 0x02
        elif request == b"DATA":
            try:
                data = self._decode_single(msg[1:], self.output_vars)
            except ValueError as err:
                self.logger.warning(f"run {run_id} DATA: malformed payload ({err})")
                return
            self.output[run_id] = data
            self.logger.debug(f"run {run_id} DATA: {data}")
            self.internal["DONE"][run_id] = True
            self.internal["FLAGS"][run_id] |= 0x08
            self._ack(address)
        elif request == b"TIME":
            try:
                time_value = self._decode_single(msg[1:], np.uint)
            except ValueError as err:
                self.logger.warning(f"run {run_id} TIME: malformed payload ({err})")
                return
            self.internal["TIME"][run_id] = time_value
            self.logger.debug(f"run {run_id} TIME: {time_value}")
            self._ack(address)
        elif request == b"DIE":
            self.internal["FLAGS"][run_id] |= 0x04
            self.logger.debug(f"run {run_id} DIE")
            self._ack(address)
        else:
            self.logger.warning(f"received unknown message {address}: {msg}")

    @staticmethod
    def _parse_run_id(address):
        """Return the run id from a Worker identity ``req_<run_id>``, or None if it is not one."""
        if address[:4] != b"req_":
            return None
        try:
            run_id = int(address[4:])
        except ValueError:
            return None
        # a negative id would silently address rows from the end of the arrays
        return run_id if run_id >= 0 else None

    @staticmethod
    def _decode_single(frames, dtype):
        """Decode the single ``dtype`` record carried by the first payload frame.

        Raises:
            ValueError: if there is no payload frame or it does not hold exactly one record.
        """
        if not frames:
            raise ValueError("missing payload")
        values = np.frombuffer(frames[0], dtype=dtype)  # ValueError on a size mismatch
        if values.size != 1:
            raise ValueError(f"expected exactly one record, got {values.size}")
        return values[0]

    def _ack(self, address):
        """Acknowledge a Worker request."""
        self.socket.send_multipart([address, b"", b"ACK"])  # acknowledge

    def clean(self):
        """Close the socket and destroy the global ZeroMQ context."""
        self.logger.info("cleaning: closing socket")
        self.socket.close(linger=0)
        zmq.Context.instance().destroy(linger=1)

    def __del__(self):
        # __init__ may have failed before the socket existed, or clean() / a failed bind may
        # already have closed it. In those cases there is nothing left to release. Calling
        # Context.instance() again would return a possibly newer global context (possibly used
        # by other sockets) and destroy it.
        sock = getattr(self, "socket", None)
        if sock is None or sock.closed:
            return
        sock.close(linger=0)
        zmq.Context.instance().destroy(linger=1)
class ZeroMQWorkerInterface(WorkerInterface, label="zeromq"):
    """Runner-Worker Interface using the lightweight message queue `ZeroMQ <https://zeromq.org/>`_

    counterpart to :py:class:`ZeroMQRunnerInterface`

    Each exchange uses a REQ socket with the identity ``req_<run_id>``; see :py:meth:`request`.
    """

    def __init__(
        self,
        run_id: int,
        *,
        transport="tcp",
        address=None,
        port=9000,
        connection=None,
        bind=None,
        timeout=4,
        retries=3,
        retry_sleep=1,
        logger_parent: Logger = None,
    ):
        # TODO: duplicate default values
        super().__init__(run_id, logger_parent=logger_parent)
        self.transport = transport
        self.address = address
        self.port = port
        self._connection = connection
        self.bind = bind  # not used by the Worker itself, only reported in ``config``
        self.timeout = timeout
        self.retries = retries
        self.retry_sleep = retry_sleep
        self._connected = False

    @property
    def connection(self):
        """ZeroMQ endpoint of the Runner.

        Uses the ``connection`` override if given. Otherwise it is ``<transport>://<address>:<port>``,
        where the address falls back to ``$PROFIT_RUNNER_ADDRESS`` and then to ``localhost``.
        """
        if self._connection is None:
            address = (
                self.address or os.environ.get("PROFIT_RUNNER_ADDRESS") or "localhost"
            )
            return f"{self.transport}://{address}:{self.port}"
        else:
            return self._connection

    @connection.setter
    def connection(self, value):
        self._connection = value

    @property
    def config(self):
        """Configuration of this interface merged on top of the base class configuration."""
        config = {
            "transport": self.transport,
            "address": self.address,
            "port": self.port,
            "connection": self._connection,
            # the Worker stores ``bind`` as a plain attribute; there is no ``_bind``
            "bind": self.bind,
            "timeout": self.timeout,
            "retries": self.retries,
            "retry_sleep": self.retry_sleep,
        }
        return {**super().config, **config}  # super().config | config in python3.9

    def retrieve(self):
        """Fetch the input record of this run from the Runner (``READY``)."""
        self.connect()
        try:
            self.request("READY")
        finally:
            self.disconnect()

    def transmit(self):
        """Send the run time and then the output record of this run to the Runner."""
        self.connect()
        try:
            self.request("TIME")
            self.request("DATA")
        finally:
            self.disconnect()

    def clean(self):
        """Close the socket if it is still open."""
        if self._connected:
            self.disconnect()

    def connect(self):
        """Open a REQ socket identified as ``req_<run_id>`` and connect it to the Runner."""
        if self._connected:
            # do not leak a socket that is still open
            self.disconnect()
        self.socket = zmq.Context.instance().socket(zmq.REQ)
        self.socket.setsockopt(zmq.IDENTITY, f"req_{self.run_id}".encode())
        self.socket.connect(self.connection)
        self.logger.info(f"connected to {self.connection}")
        self._connected = True

    def disconnect(self):
        """Close the socket immediately (pending messages are dropped); safe to call repeatedly."""
        sock = getattr(self, "socket", None)
        if sock is not None and not sock.closed:
            sock.close(linger=0)
        self._connected = False

    def __del__(self):
        # getattr: __init__ may have failed before ``_connected`` was set
        if getattr(self, "_connected", False):
            self.socket.close(linger=0)

    def request(self, request):
        """Send a request to the Runner and wait for its reply (0MQ - Lazy Pirate Pattern).

        ``READY`` receives the input record and sets up ``self.input`` and ``self.output``.
        ``TIME`` sends ``self.time`` and ``DATA`` sends ``self.output``; both expect an ``ACK``.
        On a timeout or a malformed reply the socket is closed and reopened and the request is
        resent. The Worker makes up to ``self.retries + 1`` attempts in total.

        Raises:
            ValueError: if ``request`` is not one of ``READY``, ``DATA``, ``TIME``.
            ConnectionError: if no valid reply arrived within all attempts.
        """
        # validate first, so that a bad call does not open a socket
        if request not in ["READY", "DATA", "TIME"]:
            raise ValueError(f'unknown request "{request}"')
        if not self._connected:
            self.logger.info("no connection")
            self.connect()

        tries = 0
        while True:
            msg = [request.encode()]
            if request == "DATA":
                msg.append(self.output)
            elif request == "TIME":
                msg.append(np.uint(self.time))
            self.socket.send_multipart(msg)
            self.logger.debug(f"send message {msg}")

            if self.socket.poll(timeout=int(1e3 * self.timeout), flags=zmq.POLLIN):
                response = None
                try:
                    response = self.socket.recv_multipart()
                    if request == "READY":
                        self._receive_input(response)
                    elif not response or response[0] != b"ACK":
                        # explicit check: ``assert`` would be stripped under ``python -O``
                        raise ValueError(f"expected ACK, got {response}")
                    else:
                        self.logger.info(f"{request}: message acknowledged")
                    return
                except (ValueError, IndexError, TypeError):
                    self.logger.debug(f"{request}: received {response}")
                    self.logger.error(f"{request}: malformed reply")
            else:
                self.logger.info(f"{request}: no response")

            tries += 1
            if tries >= self.retries + 1:
                self.logger.error(
                    f"{request}: {tries} requests unsuccessful, abandoning"
                )
                self.disconnect()
                raise ConnectionError("could not connect to RunnerInterface")
            # close and reopen the socket: a REQ socket still waiting for a reply cannot send again
            self.disconnect()
            sleep(self.retry_sleep)
            self.connect()

    def _receive_input(self, response):
        """Decode a READY reply ``[input_descr, input_data, output_descr]`` into input/output records."""
        input_descr, input_data, output_descr = response
        input_dtype = self._decode_descr(input_descr)
        output_dtype = self._decode_descr(output_descr)
        received = np.frombuffer(input_data, dtype=input_dtype)
        new_input = received[0]
        new_output = np.zeros(1, dtype=output_dtype)[0]
        self.input = new_input
        self.output = new_output
        self.logger.info("READY - received input data")
        self.logger.debug(f"received: {received}")

    @staticmethod
    def _decode_descr(raw):
        """Rebuild a numpy structured dtype description from its JSON encoding.

        JSON turns tuples into lists. Each column ``[name, dtype]`` or ``[name, dtype, shape]``
        is therefore turned back into a tuple, and a list ``shape`` becomes a tuple as well.
        """
        columns = []
        for column in json.loads(raw.decode()):
            name, dtype, *rest = column
            if rest and isinstance(rest[0], list):
                rest[0] = tuple(rest[0])
            columns.append((name, dtype, *rest))
        return columns