# Copyright 2019 IBM Corporation
#
# Licensed under the GNU General Public License 3.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.txt
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import csv
import logging
import os
import urllib.request
import mysql.connector
import pandas as pd
from lale.datasets.data_schemas import add_table_name
# Configure the root logger at import time and expose a module-level logger
# that reports progress messages at INFO level.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# PySpark is an optional dependency: remember whether it could be imported so
# that code in this module can refuse 'spark' requests with a clear error.
try:
    from pyspark.sql import SparkSession

    spark_installed = True
except ImportError:
    spark_installed = False

# Keyword arguments handed to mysql.connector.connect() to reach the
# "imdb_ijs" database (see https://relational.fit.cvut.cz/dataset/IMDb).
imdb_config = {
    "user": "guest",
    "password": "relational",
    "host": "relational.fit.cvut.cz",
    "database": "imdb_ijs",
    "port": 3306,
    "raise_on_warnings": True,
}
def get_data_from_csv(datatype, data_file_name):
    """
    Reads a CSV file into a pandas or a spark dataframe.

    Parameters
    ----------
    datatype : string
      Either 'pandas' or 'spark'; the comparison is case-insensitive.
    data_file_name : string
      Path of the CSV file to read. For spark the file is read with
      header=True, an inferred schema and ',' as delimiter.

    Returns
    -------
    data_frame : pandas dataframe if datatype is 'pandas',
      spark dataframe if datatype is 'spark'.

    Raises
    ------
    ValueError
      If datatype is 'spark' but pyspark could not be imported, or if
      datatype is neither 'pandas' nor 'spark'.
    """
    requested = datatype.casefold()
    if requested == "pandas":
        return pd.read_csv(data_file_name)
    if requested == "spark":
        if not spark_installed:
            raise ValueError("Spark is not installed on this machine.")
        spark = SparkSession.builder.appName("GetDataset").getOrCreate()
        return spark.read.options(inferSchema="True", delimiter=",").csv(
            data_file_name, header=True
        )
    # The message used to name the unrelated "go_sales" dataset; keep it
    # dataset-neutral because this helper only knows about a CSV path.
    raise ValueError(
        "Can fetch the data in pandas or spark dataframes only. "
        "Pass either 'pandas' or 'spark' in datatype parameter."
    )
def _write_imdb_table_to_csv(cursor, table, data_file_name):
    """Dump all rows of `table` into `data_file_name`, preceded by a header row of column names."""
    # The first field of every "desc" row is the column name.
    cursor.execute("desc {}".format(table))
    header_list = [column[0] for column in cursor]
    cursor.execute("select * from {}".format(table))
    result = cursor.fetchall()
    # newline="" lets the csv module control line endings itself; without it
    # platforms that translate "\n" produce blank rows between records.
    with open(data_file_name, "w", encoding="utf-8", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(header_list)
        writer.writerows(result)


def fetch_imdb_dataset(datatype="pandas"):
    """
    Fetches the IMDB movie dataset from Relational Dataset Repo.
    It contains information about directors, actors, roles
    and genres of multiple movies in form of 7 CSV files.

    This method connects to the database described by `imdb_config`,
    lists its tables and stores each table that has no CSV file yet as
    '<table>.csv' under the 'imdb_data' directory located next to this
    module. It creates this directory by itself if it does not exist.
    The database connection is opened on every call, even when all CSV
    files already exist, because the table list comes from the database.

    Dataset URL: https://relational.fit.cvut.cz/dataset/IMDb

    Parameters
    ----------
    datatype : string, optional, default 'pandas'

      If 'pandas',
      the downloaded / existing CSV files are read into pandas dataframes.

      If 'spark',
      the downloaded / existing CSV files are read into spark dataframes.

      Else,
      Throws an error as it does not support any other return type.

    Returns
    -------
    imdb_list : list with one element per table of the dataset; each element
      is the result of `add_table_name(data_frame, table_name)` for the
      pandas / spark dataframe read from that table's CSV file.

    Raises
    ------
    ValueError
      If the database reports an error (the original mysql.connector error
      is chained as the cause), or if `datatype` is not supported.
    """
    download_data_dir = os.path.join(os.path.dirname(__file__), "imdb_data")
    cnx = None
    try:
        cnx = mysql.connector.connect(**imdb_config)
        cursor = cnx.cursor()
        cursor.execute("show tables")
        # Materialize the table names first: the cursor is reused for the
        # per-table queries below.
        imdb_table_list = [table[0] for table in cursor]
        imdb_list = []
        for table in imdb_table_list:
            csv_name = "{}.csv".format(table)
            data_file_name = os.path.join(download_data_dir, csv_name)
            if not os.path.exists(data_file_name):
                os.makedirs(download_data_dir, exist_ok=True)
                _write_imdb_table_to_csv(cursor, table, data_file_name)
                logger.info(" Created:%s", data_file_name)
            table_name = csv_name.split(".")[0]
            data_frame = get_data_from_csv(datatype, data_file_name)
            imdb_list.append(add_table_name(data_frame, table_name))
        logger.info(" Fetched the IMDB dataset. Process completed.")
        return imdb_list
    except mysql.connector.Error as err:
        raise ValueError(err) from err
    finally:
        # A `try ... else` clause is skipped when the try body returns, so the
        # connection has to be released here to be closed on every path.
        if cnx is not None:
            cnx.close()