# Source code for lalegpl.lib.r.arules_cba_classifier

# Copyright 2019 IBM Corporation
#
# Licensed under the GNU General Public License 3.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.txt
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
import lale
import lale.helpers
from lalegpl.lib.r.util import install_r_package, create_r_dataframe
import pandas
try:
  import rpy2.robjects
  import rpy2.robjects.packages
except ImportError:
  raise ImportError("""ArulesCBAClassifier needs a Python package called `rpy2`. 
  You can install it using `pip install rpy2` or install lalegpl[full] which will install it for you.""")

import numpy as np

class ArulesCBAClassifier_Impl:
    """Classifier that delegates training and prediction to R's ``arulesCBA``.

    The constructor keyword arguments are stored unchanged and forwarded to
    the R ``CBA`` function when :meth:`fit` is called; entries whose value is
    ``None`` are dropped before the call.
    """

    def __init__(self, **hyperparams):
        """Remember the hyperparameters; no R code runs here."""
        self._hyperparams = hyperparams

    def fit(self, X, y):
        """Train a CBA model on ``X`` / ``y`` and return ``self``.

        Args:
            X: Feature table, passed as-is to ``create_r_dataframe``.
            y: Target labels. When it is a ``pandas.Series`` its ``name`` is
                used as the left-hand side of the R formula; otherwise the
                name ``"target"`` is used.
                NOTE(review): this presumes ``create_r_dataframe`` names an
                unnamed target column ``"target"`` -- confirm in
                ``lalegpl.lib.r.util``.

        Returns:
            This instance, with the fitted R model stored in ``_r_model``.
        """
        arules_pkg = install_r_package('arulesCBA')
        y_name = y.name if isinstance(y, pandas.Series) else "target"
        # "<target> ~ ." asks R to predict the target from all other columns.
        formula = rpy2.robjects.Formula(f'{y_name} ~ .')
        r_train = create_r_dataframe(X, y)
        # None means "not specified": leave those arguments out of the call.
        hps = {k: v for k, v in self._hyperparams.items() if v is not None}
        self._r_model = arules_pkg.CBA(formula=formula, data=r_train, **hps)
        return self

    def predict(self, X):
        """Predict class labels for ``X`` with the model stored by :meth:`fit`.

        Args:
            X: Feature table, passed as-is to ``create_r_dataframe``.

        Returns:
            A one-dimensional integer ``numpy`` array of predicted labels.
        """
        stats_pkg = rpy2.robjects.packages.importr('stats')
        r_test = create_r_dataframe(X)
        r_predictions = stats_pkg.predict(self._r_model, r_test)
        return self._decode_factor(r_predictions, r_predictions.levels)

    @staticmethod
    def _decode_factor(codes, levels):
        """Map 1-based factor ``codes`` to their ``levels`` as an int array.

        The builtin ``int`` is used as dtype: the ``np.int`` alias it replaces
        was deprecated in NumPy 1.20 and removed in 1.24.
        """
        labels = [levels[code - 1] for code in codes]
        return np.array(labels, dtype=int)
# JSON schema for the arguments of ArulesCBAClassifier_Impl.fit.
_input_schema_fit = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Input data schema for training. So far only works for strings.",
    "type": "object",
    "required": ["X", "y"],
    "additionalProperties": False,
    "properties": {
        "X": {
            "description": "Features; the outer array is over samples.",
            "type": "array",
            "items": {
                "type": "array",
                "items": {"type": "number"},
            },
        },
        "y": {
            "description": "Target class labels; the array is over samples.",
            "type": "array",
            "items": {"type": "number"},
        },
    },
}

# JSON schema for the arguments of ArulesCBAClassifier_Impl.predict.
_input_schema_predict = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Input data schema for predictions. So far only works for strings",
    "type": "object",
    "required": ["X"],
    "additionalProperties": False,
    "properties": {
        "X": {
            "description": "Features; the outer array is over samples.",
            "type": "array",
            "items": {
                "type": "array",
                "items": {"type": "number"},
            },
        },
    },
}

# JSON schema for the value returned by ArulesCBAClassifier_Impl.predict.
_output_predict_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": (
        "Output data schema for predictions (target class labels). "
        "So far only works for strings."
    ),
    "type": "array",
    "items": {"type": "number"},
}

# JSON schema for the constructor keyword arguments.
# The apriori `parameter` and `control` lists and the `verbose` flag are not
# exposed as hyperparameters here.
_hyperparams_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Hyperparameter schema.",
    "allOf": [
        {
            "description": (
                "This first sub-object lists all constructor arguments with their "
                "types, one at a time, omitting cross-argument constraints."
            ),
            "type": "object",
            "additionalProperties": False,
            "required": [
                "support",
                "confidence",
                "disc_method",
                "balanceSupport",
                "pruning",
            ],
            "relevantToOptimizer": ["confidence", "disc_method"],
            "properties": {
                "support": {
                    "description": "Minimum support for creating association rules.",
                    "type": "number",
                    "default": 0.2,
                    "minimum": 0,
                    "maximumForOptimizer": 1.0,
                },
                "confidence": {
                    "description": "Minimum confidence for creating association rules.",
                    "type": "number",
                    "default": 0.8,
                    "minimum": 0,
                    "maximumForOptimizer": 1.0,
                },
                "disc_method": {
                    "description": "Discretization method for factorizing numeric input.",
                    "default": "mdlp",
                    "enum": [
                        "mdlp",
                        "caim",
                        "cacc",
                        "ameva",
                        "chi2",
                        "chimerge",
                        "extendedchi2",
                        "modchi2",
                    ],
                },
                "balanceSupport": {
                    "description": "If true, class imbalance is counteracted by using the minimum support only for the majority class.",
                    "type": "boolean",
                    "default": False,
                },
                "pruning": {
                    "enum": ["M1", "M2"],
                    "default": "M1",
                },
            },
        },
    ],
}

# Top-level schema bundling the data schemas, the hyperparameter schema and
# the operator tags; this is the object handed to make_operator below.
_combined_schemas = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Combined schema for expected data and hyperparameters.",
    "documentation_url": "https://cran.r-project.org/web/packages/arulesCBA/",
    "type": "object",
    "tags": {
        "pre": ["categoricals"],
        "op": ["estimator", "classifier", "interpretable"],
        "post": [],
    },
    "properties": {
        "input_fit": _input_schema_fit,
        "input_predict": _input_schema_predict,
        "output_predict": _output_predict_schema,
        "hyperparams": _hyperparams_schema,
    },
}

# Running the module as a script only validates the schema definition.
if __name__ == "__main__":
    lale.helpers.validate_is_schema(_combined_schemas)

from lale.operators import make_operator

# Public operator: the implementation class paired with its schemas.
ArulesCBAClassifier = make_operator(ArulesCBAClassifier_Impl, _combined_schemas)