## Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending#""" Deephaven's learn module provides utilities for efficient data transfer between Deephaven tables and Python objects,as well as a framework for using popular machine-learning / deep-learning libraries with Deephaven tables."""fromtypingimportList,Union,Callable,TypeimportjpyfromdeephavenimportDHErrorfromdeephaven.tableimportTable_JLearnInput=jpy.get_type("io.deephaven.integrations.learn.Input")_JLearnOutput=jpy.get_type("io.deephaven.integrations.learn.Output")_JLearnComputer=jpy.get_type("io.deephaven.integrations.learn.Computer")_JLearnScatterer=jpy.get_type("io.deephaven.integrations.learn.Scatterer")
class Input:
    """ Input describes how data is gathered from one or more Deephaven table columns into a Python object. """

    def __init__(self, col_names: Union[str, List[str]], gather_func: Callable):
        """ Builds an Input by wrapping a newly created Java learn Input.

        Args:
            col_names (Union[str, List[str]]): a single column name, or a list of column names, to gather input from.
            gather_func (Callable): the function that turns the gathered column data into an object.
        """
        java_input = _JLearnInput(col_names, gather_func)
        # Kept as a public attribute: other code in this module reads ``.input`` directly.
        self.input = java_input

    def __str__(self):
        """ Returns the printable representation produced by the wrapped Java Input's toString(). """
        wrapped = self.input
        return wrapped.toString()
class Output:
    """ Output describes how data is scattered from a Python object into a new Deephaven table column. """

    def __init__(self, col_name: str, scatter_func: Callable, col_type: Type):
        """ Builds an Output by wrapping a newly created Java learn Output.

        Args:
            col_name (str): name of the new column that will hold the results.
            scatter_func (Callable): the function that takes data from an object and places it into the
                table column.
            col_type (Type): desired data type of the new output column; None means no explicit type cast.
                This constructor supplies no default, so the argument must always be passed.
        """
        java_output = _JLearnOutput(col_name, scatter_func, col_type)
        # Kept as a public attribute: other code in this module reads ``.output`` directly.
        self.output = java_output

    def __str__(self):
        """ Returns the printable representation produced by the wrapped Java Output's toString(). """
        wrapped = self.output
        return wrapped.toString()
def _validate(inputs: List[Input], outputs: Union[List[Output], None], table: Table):
    """ Ensures that all input columns exist in the table, and that no output column names already exist in the table.

    Args:
        inputs (List[Input]): list of Inputs to validate.
        outputs (Union[List[Output], None]): list of Outputs to validate; when None, only the inputs are checked.
        table (Table): table to check Input and Output columns against.

    Raises:
        ValueError: if at least one of the Input columns does not exist in the table.
        ValueError: if there are duplicates in the Output column names.
        ValueError: if at least one of the Output columns already exists in the table.
    """
    table_columns = set(table.definition.keys())

    input_columns = set()
    for input_ in inputs:
        # Fetch the column-name array once per Input rather than once per element. It is indexed
        # explicitly (instead of iterated) to keep the same access pattern the Java array was used with.
        col_names = input_.input.getColNames()
        input_columns.update(col_names[i] for i in range(len(col_names)))

    missing = input_columns - table_columns
    if missing:
        raise ValueError(f"Cannot find columns {missing} in the table.")

    if outputs is None:
        return

    # Single pass duplicate detection: a name seen a second time is a repeat.
    output_columns = set()
    repeats = set()
    for output in outputs:
        col_name = output.output.getColName()
        if col_name in output_columns:
            repeats.add(col_name)
        else:
            output_columns.add(col_name)

    if repeats:
        raise ValueError(f"Cannot assign the same column name {repeats} to multiple columns.")

    overlap = output_columns & table_columns
    if overlap:
        raise ValueError(f"The columns {overlap} already exist in the table. Please choose Output column names that are "
                         f"not already in the table.")


def _create_non_conflicting_col_name(table: Table, base_col_name: str) -> str:
    """ Creates a column name that is not present in the table.

    If base_col_name is unused it is returned unchanged; otherwise the first unused name of the form
    base_col_name + N (N = 0, 1, 2, ...) is returned.

    Args:
        table (Table): table to check column name against.
        base_col_name (str): base name to create a column from.

    Returns:
        column name that is not present in the table.
    """
    table_col_names = set(table.definition.keys())
    if base_col_name not in table_col_names:
        return base_col_name

    # Try numbered suffixes in order. The suffix is actually advanced each round, so the name stays short
    # instead of growing by one more "0" on every collision.
    suffix = 0
    candidate = f"{base_col_name}{suffix}"
    while candidate in table_col_names:
        suffix += 1
        candidate = f"{base_col_name}{suffix}"
    return candidate
def learn(table: Table = None, model_func: Callable = None, inputs: List[Input] = [], outputs: List[Output] = [],
          batch_size: int = None) -> Table:
    """ Learn gathers data from multiple rows of the input table, performs a calculation, and scatters values from the
    calculation into an output table. This is a common computing paradigm for artificial intelligence, machine learning,
    and deep learning.

    Args:
        table (Table): the Deephaven table to perform computations on.
        model_func (Callable): function that performs computations on the table.
        inputs (List[Input]): list of Input objects that determine how data gets extracted from the table.
        outputs (List[Output]): list of Output objects that determine how data gets scattered back into the results
            table. Passing None skips the scatter step; the computation is still evaluated, via a temporary
            result column that is dropped again.
        batch_size (int): maximum number of rows for which model_func is evaluated at once. Must be specified.

    Returns:
        a Table with added columns containing the results of evaluating model_func.

    Raises:
        DHError
    """
    try:
        _validate(inputs, outputs, table)

        if batch_size is None:
            raise ValueError("Batch size cannot be inferred. Please specify a batch size.")

        java_inputs = [input_.input for input_ in inputs]
        # NOTE: the formula strings below mention ``__computer`` and ``__scatterer`` by name, so these two
        # locals must keep exactly these names.
        __computer = _JLearnComputer(table.j_table, model_func, java_inputs, batch_size)

        future_col = _create_non_conflicting_col_name(table, "__FutureOffset")
        clear_col = _create_non_conflicting_col_name(table, "__CleanComputer")

        compute_formula = f"{future_col} = __computer.compute(k)"
        clear_formula = f"{clear_col} = __computer.clear()"

        if outputs is None:
            result_col = _create_non_conflicting_col_name(table, "__Result")
            computed = table.update(formulas=[compute_formula, f"{result_col} = {future_col}.getFuture().get()"])
            # calling __computer.clear() in a separate update ensures calculations are complete before the
            # computer is cleared
            cleared = computed.update(formulas=[clear_formula])
            return cleared.drop_columns(cols=[future_col, clear_col, result_col])

        __scatterer = _JLearnScatterer([output.output for output in outputs])
        computed = table.update(formulas=[compute_formula])
        scattered = computed.update(formulas=[__scatterer.generateQueryStrings(future_col)])
        cleared = scattered.update(formulas=[clear_formula])
        # Only the temporary bookkeeping columns are removed; the scattered output columns stay.
        return cleared.drop_columns(cols=[future_col, clear_col])
    except Exception as err:
        raise DHError(err, "failed to complete the learn function.") from err