#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
"""This module implements the Session class which provides methods to connect to and interact with the Deephaven
server."""
from __future__ import annotations
import base64
import logging
import os
from random import random
import threading
from typing import Any, Dict, Iterable, List, Union, Tuple, NewType
import grpc
import pyarrow as pa
import pyarrow.flight as paflight
from bitstring import BitArray
from pyarrow._flight import ClientMiddlewareFactory, ClientMiddleware
from pydeephaven._app_service import AppService
from pydeephaven._arrow_flight_service import ArrowFlightService
from pydeephaven._config_service import ConfigService
from pydeephaven._console_service import ConsoleService
from pydeephaven._input_table_service import InputTableService
from pydeephaven._plugin_obj_service import PluginObjService
from pydeephaven._session_service import SessionService
from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp
from pydeephaven._table_service import TableService
from pydeephaven.ticket import SharedTicket, ExportTicket, ScopeTicket, Ticket, ServerObject, _server_object_from_proto
from pydeephaven._utils import to_list
from pydeephaven.dherror import DHError
from pydeephaven.experimental.plugin_client import PluginClient
from pydeephaven.query import Query
from pydeephaven.table import Table, InputTable
logger = logging.getLogger(__name__)
class _DhClientAuthMiddlewareFactory(ClientMiddlewareFactory):
def __init__(self, session):
super().__init__()
self._session = session
self._middleware = _DhClientAuthMiddleware(session)
def start_call(self, info):
return self._middleware
class _DhClientAuthMiddleware(ClientMiddleware):
def __init__(self, session):
super().__init__()
self._session = session
def call_completed(self, exception):
super().call_completed(exception)
def received_headers(self, headers):
super().received_headers(headers)
header_key = "authorization"
try:
if headers and header_key in headers:
header_value = headers.get(header_key)
auth_header_value = bytes(header_value[0], encoding='ascii')
if auth_header_value:
self._session._auth_header_value = auth_header_value
except Exception as e:
logger.exception(f'_DhClientAuthMiddleware.received_headers got headers={headers}')
return
def sending_headers(self):
return None
def _trace(who: str) -> None:
logger.debug(f'TRACE: {who}')
_BidiRpc = NewType("_BidiRpc", grpc.StreamStreamMultiCallable)
_NotBidiRpc = NewType(
"_NotBidiRpc",
Union[
grpc.UnaryUnaryMultiCallable,
grpc.UnaryStreamMultiCallable,
grpc.StreamUnaryMultiCallable])
[docs]class Session:
"""A Session represents a connection to the Deephaven data server. It contains a number of convenience
methods for asking the server to create tables, import Arrow data into tables, merge tables, run Python scripts, and
execute queries.
Session objects can be used in Python with statement so that whatever happens in the with statement block, they
are guaranteed to be closed upon exit.
Attributes:
tables (list[str]): names of the global tables available in the server after running scripts
is_alive (bool): check if the session is still alive (may refresh the session)
"""
def __init__(self, host: str = None,
port: int = None,
auth_type: str = "Anonymous",
auth_token: str = "",
never_timeout: bool = True,
session_type: str = 'python',
use_tls: bool = False,
tls_root_certs: bytes = None,
client_cert_chain: bytes = None,
client_private_key: bytes = None,
client_opts: List[Tuple[str, Union[int, str]]] = None,
extra_headers: Dict[bytes, bytes] = None):
"""Initializes a Session object that connects to the Deephaven server
Args:
host (str): the host name or IP address of the remote machine, default is 'localhost'
port (int): the port number that Deephaven server is listening on, default is 10000
auth_type (str): the authentication type string, can be "Anonymous', 'Basic", or any custom-built
authenticator in the server, such as "io.deephaven.authentication.psk.PskAuthenticationHandler",
default is 'Anonymous'.
auth_token (str): the authentication token string. When auth_type is 'Basic', it must be
"user:password"; when auth_type is "Anonymous', it will be ignored; when auth_type is a custom-built
authenticator, it must conform to the specific requirement of the authenticator
never_timeout (bool): never allow the session to timeout, default is True
session_type (str): the Deephaven session type. Defaults to 'python'
use_tls (bool): if True, use a TLS connection. Defaults to False
tls_root_certs (bytes): PEM encoded root certificates to use for TLS connection, or None to use system defaults.
If not None implies use a TLS connection and the use_tls argument should have been passed
as True. Defaults to None
client_cert_chain (bytes): PEM encoded client certificate if using mutual TLS. Defaults to None,
which implies not using mutual TLS.
client_private_key (bytes): PEM encoded client private key for client_cert_chain if using mutual TLS.
Defaults to None, which implies not using mutual TLS.
client_opts (List[Tuple[str,Union[int,str]]): list of tuples for name and value of options to
the underlying grpc channel creation. Defaults to None, which implies not using any channel
options.
See https://grpc.github.io/grpc/cpp/group__grpc__arg__keys.html for a list of valid options.
Example options:
[ ('grpc.target_name_override', 'idonthaveadnsforthishost'),
('grpc.min_reconnect_backoff_ms', 2000) ]
extra_headers (Dict[bytes, bytes]): additional headers (and values) to add to server requests.
Defaults to None, which implies not using any extra headers.
Raises:
DHError
"""
_trace('Session.__init__')
self._r_lock = threading.RLock() # for thread-safety when accessing/changing session global state
self._services_lock = threading.Lock() # for lazy initialization of services
self._last_export_ticket_number: int = 0
self._ticket_bitarray = BitArray(1024)
self.host = host
if not host:
self.host = os.environ.get("DH_HOST", "localhost")
self.port = port
if not port:
self.port = int(os.environ.get("DH_PORT", 10000))
self._logpfx = f'pydh.Session {id(self)} {host}:port: '
self._use_tls = use_tls
self._tls_root_certs = tls_root_certs
self._client_cert_chain = client_cert_chain
self._client_private_key = client_private_key
self._client_opts = client_opts
self._extra_headers = extra_headers if extra_headers else {}
self.is_connected = False
# We set here the initial value for the authorization header,
# which will bootstrap our authentication to the server on the first
# RPC going out. The server will give us back a bearer token to use
# in subsequent RPCs in the same authorization header. From then
# on, the value of _auth_header_value will be similar to b'Bearer X'
# where X is the bearer token provided by the server.
if auth_type == "Anonymous":
self._auth_header_value = auth_type
elif auth_type == "Basic":
auth_token_base64 = base64.b64encode(auth_token.encode("ascii")).decode("ascii")
self._auth_header_value = "Basic " + auth_token_base64
else:
self._auth_header_value = str(auth_type) + " " + auth_token
self._auth_header_value = bytes(self._auth_header_value, 'ascii')
# Counter for consecutive failures to refresh auth token, used to calculate retry backoff
self._refresh_failures = 0
self.grpc_channel = None
self._session_service = None
self._table_service = None
self._grpc_barrage_stub = None
self._console_service = None
self._flight_service = None
self._app_service = None
self._input_table_service = None
self._plugin_obj_service = None
self._never_timeout = never_timeout
self._keep_alive_timer = None
self._session_type = session_type
self._flight_client = None
self._auth_handler = None
self._config_service = None
self._connect()
def __enter__(self):
if not self.is_connected:
# double-checked locking, is_connected is checked inside _connect again, which
# may not end up connecting.
self._connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __del__(self):
self.close()
def update_metadata(self, metadata: Iterable[Tuple[str, Union[str, bytes]]]) -> None:
for header_tuple in metadata:
if header_tuple[0] == "authorization":
v = header_tuple[1]
self._auth_header_value = v if isinstance(v, bytes) else v.encode('ascii')
break
def wrap_rpc(self, stub_call: _NotBidiRpc, *args, **kwargs) -> Any:
if 'metadata' in kwargs:
raise DHError('Internal error: "metadata" in kwargs not supported in wrap_rpc.')
kwargs["metadata"] = self.grpc_metadata
# We use a future to get a chance to process initial metadata before the call
# is completed
future = stub_call.future(*args, **kwargs)
self.update_metadata(future.initial_metadata())
# Now block until we get the result (or an exception)
return future.result()
def wrap_bidi_rpc(self, stub_call: _BidiRpc, *args, **kwargs) -> Any:
if 'metadata' in kwargs:
raise DHError('Internal error: "metadata" in kwargs not supported in wrap_bidi_rpc.')
kwargs["metadata"] = self.grpc_metadata
response = stub_call(*args, **kwargs)
self.update_metadata(response.initial_metadata())
return response
@property
def tables(self):
with self._r_lock:
fields = self._fetch_fields()
return [field.field_name for field in fields if
field.application_id == 'scope' and field.typed_ticket.type == 'Table']
@property
def exportable_objects(self) -> Dict[str, ServerObject]:
with self._r_lock:
fields = self._fetch_fields()
return {field.field_name: _server_object_from_proto(field.typed_ticket) for field in fields}
@property
def grpc_metadata(self):
header_value_snap = self._auth_header_value # ensure it doesn't change while doing multiple reads
if not header_value_snap or not isinstance(header_value_snap, bytes):
logger.warning(f'{self._logpfx} internal invariant violated, _auth_header_value={header_value_snap}')
l = []
else:
l = [(b'authorization', header_value_snap)]
if self._extra_headers:
l.extend(list(self._extra_headers.items()))
return l
@property
def table_service(self) -> TableService:
if not self._table_service:
with self._services_lock:
if not self._table_service:
self._table_service = TableService(self)
return self._table_service
@property
def session_service(self) -> SessionService:
if not self._session_service:
with self._services_lock:
if not self._session_service:
self._session_service = SessionService(self)
return self._session_service
@property
def console_service(self) -> ConsoleService:
if not self._console_service:
with self._services_lock:
if not self._console_service:
self._console_service = ConsoleService(self)
return self._console_service
@property
def flight_service(self) -> ArrowFlightService:
if not self._flight_service:
with self._services_lock:
if not self._flight_service:
self._flight_service = ArrowFlightService(self, self._flight_client)
return self._flight_service
@property
def app_service(self) -> AppService:
if not self._app_service:
with self._services_lock:
if not self._app_service:
self._app_service = AppService(self)
return self._app_service
@property
def config_service(self):
if not self._config_service:
with self._services_lock:
if not self._config_service:
self._config_service = ConfigService(self)
return self._config_service
@property
def input_table_service(self) -> InputTableService:
if not self._input_table_service:
with self._services_lock:
if not self._input_table_service:
self._input_table_service = InputTableService(self)
return self._input_table_service
@property
def plugin_object_service(self) -> PluginObjService:
if not self._plugin_obj_service:
with self._services_lock:
if not self._plugin_obj_service:
self._plugin_obj_service = PluginObjService(self)
return self._plugin_obj_service
def make_export_ticket(self, ticket_no: int = None) -> ExportTicket:
if not ticket_no:
ticket_no = self.next_export_ticket_number()
return ExportTicket.export_ticket(ticket_no)
def next_export_ticket_number(self) -> int:
with self._r_lock:
self._last_export_ticket_number += 1
if self._last_export_ticket_number == 2 ** 31 - 1:
raise DHError("fatal error: out of free internal ticket")
return self._last_export_ticket_number
def _fetch_fields(self):
"""Returns a list of available fields on the server.
Raises:
DHError
"""
with self._r_lock:
list_fields = self.app_service.list_fields()
resp = next(list_fields)
if not list_fields.cancel():
raise DHError("could not cancel ListFields subscription")
return resp.created if resp.created else []
def _connect(self):
_trace(f'_connect id={id(self)}')
with self._r_lock:
if self.is_connected:
return
_trace(f'_connect id={id(self)} connecting.')
try:
scheme = "grpc+tls" if self._use_tls else "grpc"
self._flight_client = paflight.FlightClient(
location=f"{scheme}://{self.host}:{self.port}",
middleware=[_DhClientAuthMiddlewareFactory(self)],
tls_root_certs=self._tls_root_certs,
cert_chain=self._client_cert_chain,
private_key=self._client_private_key,
generic_options=self._client_opts
)
except Exception as e:
raise DHError("failed to connect to the server.") from e
self.grpc_channel = self.session_service.connect()
# This RPC will get is the configuration and will also bootstrap
# our authentication to the server by virtue of sending the right
# header: "authorization" header key and our selected header value.
# The implementation will process the initial headers coming back
# from the server which will contain the bearer token we will
# use in subsequent RPCs; the token will be included in the updated
# value for self._auth_header_value that will happen through a call
# to update_metadata.
config_dict = self.config_service.get_configuration_constants()
session_duration = config_dict.get("http.session.durationMs")
if not session_duration:
raise DHError("server configuration is missing http.session.durationMs")
self._timeout_seconds = int(session_duration.string_value)/1000.0
# Random skew to ensure multiple processes that may have
# started together don't align retries.
skew = random()
# Backoff schedule for retries after consecutive failures to refresh auth token
self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10]
if self._refresh_backoff[0] > self._timeout_seconds:
raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.')
if 0.25*self._timeout_seconds < self._refresh_backoff[-1]:
self._refresh_backoff.extend(
[skew + 0.25 * self._timeout_seconds,
skew + 0.35 * self._timeout_seconds,
skew + 0.45 * self._timeout_seconds])
for i in range(1, len(self._refresh_backoff)):
if self._refresh_backoff[i] > self._timeout_seconds:
self._refresh_backoff = self._refresh_backoff[0:i]
break
self.is_connected = True
if self._never_timeout:
self._keep_alive()
def _keep_alive(self):
_trace(f'_keep_alive')
if not self.is_connected:
return
ok = True
if self._keep_alive_timer:
ok = self._refresh_token()
if ok:
self._refresh_failures = 0
else:
self._refresh_failures += 1
if self._refresh_failures == 0:
timer_wakeup = 0.5*self._timeout_seconds
elif self._refresh_failures >= len(self._refresh_backoff):
msg = f'Failed to refresh token {self._refresh_failures} times, will stop retrying.'
logger.critical(msg)
raise DHError(msg)
else:
timer_wakeup = self._refresh_backoff[self._refresh_failures]
_trace(f'_keep_alive timer_wakeup={timer_wakeup}')
self._keep_alive_timer = threading.Timer(timer_wakeup, self._keep_alive)
self._keep_alive_timer.daemon = True
self._keep_alive_timer.start()
if not ok:
logger.warning(
f'{self._logpfx}: failed to refresh auth token (retry #{self._refresh_failures-1}).' +
f' Will retry in {timer_wakeup} seconds.')
def _refresh_token(self) -> bool:
_trace('_refresh_token')
try:
self.config_service.get_configuration_constants()
return True
except Exception as ex:
logger.warning(f'{self._logpfx} Caught exception while refreshing auth token: {ex}.')
return False
@property
def is_alive(self) -> bool:
"""Whether the session is alive."""
with self._r_lock:
if not self.is_connected:
return False
if self._never_timeout:
return True
try:
self.config_service.get_configuration_constants()
return True
except DHError as e:
self.is_connected = False
return False
[docs] def close(self) -> None:
"""Closes the Session object if it hasn't timed out already.
Raises:
DHError
"""
with self._r_lock:
if not self.is_connected:
return
self.session_service.close()
self.grpc_channel.close()
self.is_connected = False
self._last_export_ticket_number = 0
self._flight_client.close()
[docs] def release(self, ticket: ExportTicket) -> None:
"""Releases an export ticket.
Args:
ticket (Ticket): the ticket to release
"""
self.session_service.release(ticket)
# convenience/factory methods
[docs] def run_script(self, script: str, systemic: Optional[bool] = None) -> None:
"""Runs the supplied Python script on the server.
Args:
script (str): the Python script code
systemic (bool): Whether to treat the code as systemically important. Defaults to None which uses the
default system behavior
Raises:
DHError
"""
response = self.console_service.run_script(script, systemic)
if response.error_message != '':
raise DHError("could not run script: " + response.error_message)
[docs] def open_table(self, name: str) -> Table:
"""Opens a table in the global scope with the given name on the server.
Args:
name (str): the name of the table
Returns:
a Table object
Raises:
DHError
"""
ticket = ScopeTicket.scope_ticket(name)
faketable = Table(session=self, ticket=ticket)
try:
table_op = FetchTableOp()
return self.table_service.grpc_table_op(faketable, table_op)
except Exception as e:
if isinstance(e.__cause__, grpc.RpcError):
if e.__cause__.code() == grpc.StatusCode.INVALID_ARGUMENT:
raise DHError(f"no table by the name {name}") from None
raise e
finally:
# Explicitly close the table without releasing it (because it isn't ours)
faketable.ticket = None
faketable.schema = None
[docs] def bind_table(self, name: str, table: Table) -> None:
"""Binds a table to the given name on the server so that it can be referenced by that name.
Args:
name (str): name for the table
table (Table): a Table object
Raises:
DHError
"""
self.console_service.bind_table(table=table, variable_name=name)
[docs] def publish(self, source_ticket: Ticket, result_ticket: Ticket) -> None:
""" Publishes a source ticket to the result ticket.
This is low-level method that can be used to publish non-Table server objects that are previously
fetched from the server. The source ticket represents the previously fetched server object to be published, and
the result ticket, which should normally be a SharedTicket, is the ticket to publish to. The result ticket can
then be fetched by other sessions to access the object as long as the object is not released. This method is
used together with the :meth:`.fetch` method to share server objects between sessions.
Args:
source_ticket (Ticket): The source ticket to publish from.
result_ticket (Ticket): The result ticket to publish to.
Raises:
DHError: If the operation fails.
"""
self._session_service.publish(source_ticket, result_ticket)
[docs] def fetch(self, ticket: Ticket) -> ExportTicket:
"""Fetches a server object by ticket.
This is low-level method that can be used to fetch non-Table server objects. The ticket represents a
fetchable server object, e.g :class:`~.plugin_client.PluginClient`, :class:`~.plugin_client.Fetchable`.
This method is used together with the :meth:`.publish` method to share server objects between sessions.
Args:
ticket (Ticket): a ticket
Returns:
an ExportTicket object
Raises:
DHError
"""
return self._session_service.fetch(ticket)
[docs] def publish_table(self, ticket: SharedTicket, table: Table) -> None:
"""Publishes a table to the given shared ticket. The ticket can then be used by another session to fetch the
table.
Note that, the shared ticket can be fetched by other sessions to access the table as long as the table is
not released. When the table is released either through an explicit call of the close method on it, or
implicitly through garbage collection, or through the closing of the publishing session, the shared ticket will
no longer be valid.
Args:
ticket (SharedTicket): a SharedTicket object
table (Table): a Table object
Raises:
DHError
"""
self.publish(table.ticket, ticket)
[docs] def fetch_table(self, ticket: SharedTicket) -> Table:
"""Fetches a table by ticket.
Args:
ticket (SharedTicket): a ticket
Returns:
a Table object
Raises:
DHError
"""
try:
return self.table_service.fetch_etcr(ticket.pb_ticket)
except Exception as e:
raise DHError("could not fetch table by ticket") from e
[docs] def time_table(self, period: Union[int, str], start_time: Union[int, str] = None,
blink_table: bool = False) -> Table:
"""Creates a time table on the server.
Args:
period (Union[int, str]): the interval at which the time table ticks (adds a row); units are nanoseconds
or a time interval string, e.g. "PT00:00:.001" or "PT1S"
start_time (Union[int, str]): the start time for the time table in nanoseconds or as a date time
formatted string; default is None (meaning now)
blink_table (bool, optional): if the time table should be a blink table, defaults to False
Returns:
a Table object
Raises:
DHError
"""
table_op = TimeTableOp(start_time=start_time, period=period, blink_table=blink_table)
return self.table_service.grpc_table_op(None, table_op)
[docs] def empty_table(self, size: int) -> Table:
"""Creates an empty table on the server.
Args:
size (int): the size of the empty table in number of rows
Returns:
a Table object
Raises:
DHError
"""
table_op = EmptyTableOp(size=size)
return self.table_service.grpc_table_op(None, table_op)
[docs] def import_table(self, data: pa.Table) -> Table:
"""Imports the pyarrow table as a new Deephaven table on the server.
Deephaven supports most of the Arrow data types. However, if the pyarrow table contains any field with a data
type not supported by Deephaven, the import operation will fail.
Args:
data (pa.Table): a pyarrow Table object
Returns:
a Table object
Raises:
DHError
"""
return self.flight_service.import_table(data=data)
[docs] def merge_tables(self, tables: List[Table], order_by: str = None) -> Table:
"""Merges several tables into one table on the server.
Args:
tables (list[Table]): the list of Table objects to merge
order_by (str, optional): if specified the resultant table will be sorted on this column
Returns:
a Table object
Raises:
DHError
"""
table_op = MergeTablesOp(tables=tables, key_column=order_by)
return self.table_service.grpc_table_op(None, table_op)
[docs] def query(self, table: Table) -> Query:
"""Creates a Query object to define a sequence of operations on a Deephaven table.
Args:
table (Table): a Table object
Returns:
a Query object
Raises:
DHError
"""
return Query(self, table)
[docs] def plugin_client(self, server_obj: ServerObject) -> PluginClient:
"""Wraps a ticket as a PluginClient. Capabilities here vary based on the server implementation of the ObjectType,
but most will at least send a response payload to the client, possibly including references to other objects.
In some cases, depending on the server implementation, the client will also be able to send the same sort of
messages back to the server.
Part of the experimental plugin API."""
return PluginClient(self, server_obj)