{ "cells": [ { "cell_type": "markdown", "id": "368e0ff0-d7f6-4ba2-8891-de2a39db4b8e", "metadata": {}, "source": [ "# ChatGPT Python Code Writer for Analyzing Datasets\n", "With examples using Starbucks location data! ☕️ " ] }, { "cell_type": "markdown", "id": "b4a0410b-805e-4d1a-8758-a9e57691c4cc", "metadata": {}, "source": [ "## Step 1: Import (a lot of) libraries" ] }, { "cell_type": "code", "execution_count": 79, "id": "95e21d55-52e5-4f34-8988-f30bdbb51847", "metadata": {}, "outputs": [], "source": [ "# Since we cannot be certain of what kind of code ChatGPT will come up with, \n", "# we start by importing a bunch of libraries it may use in its response. \n", "# Lots of these are for making pretty maps. \n", "\n", "# Basics\n", "import pandas as pd\n", "import openai\n", "import os\n", "\n", "# Mapping\n", "from geopy import distance\n", "from geopy.geocoders import Nominatim\n", "import geopandas as gpd\n", "from shapely.geometry import Point, Polygon\n", "from geopy.distance import geodesic\n", "import folium\n", "\n", "# Charts\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "# Parsing text, displaying results in markdown\n", "from IPython.display import display, Markdown, Latex\n", "import re" ] }, { "cell_type": "markdown", "id": "774cbaf9-02ff-49fe-9d49-2dfa51799ee6", "metadata": {}, "source": [ "## Step 2: Set up your [OpenAI key](https://platform.openai.com/)" ] }, { "cell_type": "code", "execution_count": 80, "id": "c3f6aae9-33a5-4559-a49a-ae9812a17664", "metadata": {}, "outputs": [], "source": [ "## OPENAI KEY ##\n", "openai.api_key = os.environ.get('OPENAI_KEY')" ] }, { "cell_type": "markdown", "id": "fd087429-5278-45ca-86e5-3f11f8b48fd3", "metadata": {}, "source": [ "## Step 3. Load your data and parse inputs for the OpenAI query\n", "This sample dataset is a directory of Starbucks locations, scraped from the Starbucks store locator webpage by GitHub user [chrismeller](https://github.com/chrismeller/). 
Geospatial coordinates have been truncated, just in case..." ] }, { "cell_type": "code", "execution_count": 81, "id": "6e27b49f-3637-4776-b2a3-9f019966d618", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0.1Unnamed: 0BrandStore NumberStore NameOwnership TypeStreet AddressCityState/ProvinceCountryPostcodePhone NumberTimezoneLongitudeLatitude
000Starbucks47370-257954Meritxell, 96LicensedAv. Meritxell, 96Andorra la Vella7ADAD500376818720GMT+1:00 Europe/Andorra1.5342.51
1111Starbucks1579-122101HCT Abu Dhabi Women's College BlockLicensedNajda Street, Higher Colleges of TechnologyAbu DhabiAZAE316726426280GMT+04:00 Asia/Dubai54.3724.49
2212Starbucks32595-122105Standard Chartered BuildingLicensedKhalidiya St., Beside Union Cooperative SocietyAbu DhabiAZAE316726359275GMT+04:00 Asia/Muscat55.6924.19
3320Starbucks32767-131566Shangri-La SouqLicensedShangri-La Souk, Um Al NarAbu DhabiAZAE316725581641GMT+04:00 Asia/Dubai54.5124.42
4445Starbucks32640-131563Tawam HospitalLicensedAl Ain Abu Dhabi Rd, Khalifa Bin Zayed, Al Mak...Al AinAZAE316737677581GMT+04:00 Asia/Muscat55.6524.19
\n", "
" ], "text/plain": [ " Unnamed: 0.1 Unnamed: 0 Brand Store Number \\\n", "0 0 0 Starbucks 47370-257954 \n", "1 1 11 Starbucks 1579-122101 \n", "2 2 12 Starbucks 32595-122105 \n", "3 3 20 Starbucks 32767-131566 \n", "4 4 45 Starbucks 32640-131563 \n", "\n", " Store Name Ownership Type \\\n", "0 Meritxell, 96 Licensed \n", "1 HCT Abu Dhabi Women's College Block Licensed \n", "2 Standard Chartered Building Licensed \n", "3 Shangri-La Souq Licensed \n", "4 Tawam Hospital Licensed \n", "\n", " Street Address City \\\n", "0 Av. Meritxell, 96 Andorra la Vella \n", "1 Najda Street, Higher Colleges of Technology Abu Dhabi \n", "2 Khalidiya St., Beside Union Cooperative Society Abu Dhabi \n", "3 Shangri-La Souk, Um Al Nar Abu Dhabi \n", "4 Al Ain Abu Dhabi Rd, Khalifa Bin Zayed, Al Mak... Al Ain \n", "\n", " State/Province Country Postcode Phone Number Timezone \\\n", "0 7 AD AD500 376818720 GMT+1:00 Europe/Andorra \n", "1 AZ AE 3167 26426280 GMT+04:00 Asia/Dubai \n", "2 AZ AE 3167 26359275 GMT+04:00 Asia/Muscat \n", "3 AZ AE 3167 25581641 GMT+04:00 Asia/Dubai \n", "4 AZ AE 3167 37677581 GMT+04:00 Asia/Muscat \n", "\n", " Longitude Latitude \n", "0 1.53 42.51 \n", "1 54.37 24.49 \n", "2 55.69 24.19 \n", "3 54.51 24.42 \n", "4 55.65 24.19 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Read in your data (can be multiple datasets)\n", "data = pd.read_csv('../data/directory.csv')\n", "\n", "# Generate a list of columns for each dataset, for later use in generating code\n", "columns = list(data)\n", "\n", "# If you can do it under the API token limit, generate head (or synthetic head) data for each file, to give the OpenAI API context. 
\n", "# Store as a .json so it transmits properly.\n", "head = data.head()\n", "json_head = head.to_json(orient='records')\n", "\n", "# Displaying the data head so you have a sense of what's in it.\n", "display(head)" ] }, { "cell_type": "markdown", "id": "30d4d983-de2e-41ef-9b70-b312fd632ea7", "metadata": {}, "source": [ "## Step 4. Prime the AI -- define the role and response parameters for the AI" ] }, { "cell_type": "code", "execution_count": 105, "id": "b09c4542-ed98-4a5b-a8f3-cca99daa20b5", "metadata": {}, "outputs": [], "source": [ "# Prepare your base prompt, which tells the OpenAI oracle what its role is and any\n", "# information it should use to prepare its response. \n", "\n", "BASE_PROMPT = [\n", " {\"role\": \"system\", \"content\": f\"You are a python programmer. Write a program that uses the column names\"\n", " f\"--{columns}-- from the Starbucks dataframe that can be used to answer the question. If helpful, here \"\n", " f\"are sample data: {json_head}. The data can be found here: ../data/directory.csv\"\n", " }]\n", "\n", "# Set up global variables so that you can ask follow up questions and receive answers based on prior responses.\n", "messages = []\n", "messages += BASE_PROMPT\n", "message_response = \"x\"" ] }, { "cell_type": "markdown", "id": "ee9ae9fa-08cd-4166-8be2-4b21ef3efe10", "metadata": {}, "source": [ "## Step 5. Write a function for sending base prompt and question to OpenAI and storing response." 
] }, { "cell_type": "code", "execution_count": 114, "id": "239eede2-18d8-4289-94e3-0dc9ddda576e", "metadata": {}, "outputs": [], "source": [ "def question(prompt, model=\"gpt-3.5-turbo\"):\n", "    \"\"\"Send a prompt to the OpenAI chat API and display the reply as Markdown.\n", "\n", "    The prompt is sent together with the conversation held in the global\n", "    `messages` list. The prompt and the reply are appended to `messages` only\n", "    after the API call returns, so a failed request does not leave an\n", "    unanswered user message in the history. The reply text is also stored in\n", "    the global `message_response` for later parsing.\n", "\n", "    prompt: the question to ask, as a string.\n", "    model: name of the chat model to query.\n", "    \"\"\"\n", "    global messages\n", "    global message_response\n", "\n", "    user_message = {\"role\": \"user\", \"content\": prompt}\n", "\n", "    # Call the OpenAI API with the history plus the new prompt, and parse the response.\n", "    response = openai.ChatCompletion.create(model=model, messages=messages + [user_message])\n", "    message_response = response[\"choices\"][0][\"message\"][\"content\"]\n", "\n", "    # Store both turns for follow up questions, now that the call has succeeded.\n", "    messages += [user_message, {\"role\": \"assistant\", \"content\": message_response}]\n", "\n", "    # Display the question and results using Markdown\n", "    display(Markdown(\"### \" + prompt))\n", "    display(Markdown(message_response))\n" ] }, { "cell_type": "markdown", "id": "bafd98d2-96da-4f35-9e73-f4d252c758f6", "metadata": {}, "source": [ "## Step 6. Ask a question!" ] }, { "cell_type": "code", "execution_count": 107, "id": "3536fb34-e337-4874-9a79-e932647f3cbc", "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "### How many Starbucks are in the EU?" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "To answer this question using the given dataframe, we need to filter the data where the \"Country\" column is in the EU countries list, and then count the number of rows. 
Here's the code to do this:\n", "\n", "```python\n", "import pandas as pd\n", "\n", "# Read the Starbucks dataframe\n", "starbucks_df = pd.read_csv('../data/directory.csv')\n", "\n", "# List of EU countries\n", "eu_countries = ['AT', 'BE', 'BG', 'CY', 'CZ', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GR', 'HR', 'HU', 'IE', 'IT', 'LT', 'LU', 'LV', 'MT', 'NL', 'PL', 'PT', 'RO', 'SE', 'SI', 'SK']\n", "\n", "# Filter Starbucks stores in the EU\n", "starbucks_eu_df = starbucks_df[starbucks_df['Country'].isin(eu_countries) & (starbucks_df['Brand'] == 'Starbucks')]\n", "\n", "# Count the number of Starbucks stores in the EU\n", "num_of_starbucks_in_eu = len(starbucks_eu_df)\n", "\n", "# Print the result\n", "print(\"There are\", num_of_starbucks_in_eu, \"Starbucks stores in the EU.\")\n", "```\n", "\n", "Note that this code assumes that the \"Country\" column contains the two-letter country codes, such as \"AD\" for Andorra, \"AE\" for United Arab Emirates, etc." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "prompt = \"How many Starbucks are in the EU?\"\n", "question (prompt)" ] }, { "cell_type": "markdown", "id": "a88bd641-b18f-43d2-8478-d5186e03807e", "metadata": {}, "source": [ "### Optional: Ask a follow up question. " ] }, { "cell_type": "code", "execution_count": 111, "id": "8ff3df0e-33aa-4568-8602-e54c12289067", "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "### Can you turn the response into a chart, by country and using colors and fonts that are used in the Starbucks logo?" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "Sure! 
Here's the modified code that includes a horizontal bar chart with colors and fonts based on the Starbucks logo, grouped by country:\n", "\n", "```python\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# Read the Starbucks dataframe\n", "starbucks_df = pd.read_csv('../data/directory.csv')\n", "\n", "# List of EU countries\n", "eu_countries = ['AT', 'BE', 'BG', 'CY', 'CZ', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GR', 'HR', 'HU', 'IE', 'IT', 'LT', 'LU', 'LV', 'MT', 'NL', 'PL', 'PT', 'RO', 'SE', 'SI', 'SK']\n", "\n", "# Filter Starbucks stores in the EU\n", "starbucks_eu_df = starbucks_df[starbucks_df['Country'].isin(eu_countries) & (starbucks_df['Brand'] == 'Starbucks')]\n", "\n", "# Count the number of Starbucks stores by country\n", "starbucks_by_country = starbucks_eu_df.groupby('Country')['Brand'].count().reset_index().rename(columns={'Brand': 'Num_of_starbucks'})\n", "\n", "# Sort the data by number of stores\n", "starbucks_by_country = starbucks_by_country.sort_values('Num_of_starbucks', ascending=False)\n", "\n", "# Set the color scheme and font\n", "sns.set_palette(sns.color_palette([\"#00704A\"]))\n", "sns.set(font=\"Helvetica\")\n", "\n", "# Create the bar chart\n", "plt.figure(figsize=(12,8))\n", "sns.barplot(x='Num_of_starbucks', y='Country', data=starbucks_by_country, color=\"#00704A\")\n", "plt.title('Number of Starbucks Stores in the EU by Country', fontsize=18, fontweight='bold')\n", "plt.xlabel('Number of Stores', fontsize=14)\n", "plt.ylabel('Country', fontsize=14)\n", "plt.xticks(fontsize=14)\n", "plt.yticks(fontsize=14)\n", "plt.show()\n", "```\n", "\n", "This code groups the data by country using the `groupby()` function, and creates a horizontal bar chart using `barplot()` function. The color code `#00704A` is the main green color used in the Starbucks logo, and the chart is sorted by the number of stores in each country. 
The final chart will look like this:\n", "\n", "![Starbucks chart by country](https://i.imgur.com/XEboX9P.png)\n", "\n", "I hope this helps!" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Optional follow up question. You can add these until you run out of tokens. \n", "prompt = \"Can you turn the response into a chart, by country and using colors and fonts that are used in the Starbucks logo?\"\n", "question(prompt)" ] }, { "cell_type": "markdown", "id": "f92b2d6f-3ba7-4067-9073-57df9eda93f1", "metadata": {}, "source": [ "### Pause!\n", "Sometimes the response will try to answer the prompt, in addition to providing the code. These answers are often (but not always) nonsense! Make sure you review and run the code yourself to be certain of the result. \n", "\n", "Now, here is what is really exciting about OpenAI: the API combines information about the dataset with its own knowledge. For example, when asking about the number of Starbucks in each EU country, the AI will generate a list of countries to use. Or, if asking the AI to use Starbucks colors for charts, it will know to use the famous Starbucks green. 😃 Amazing, and also very risky -- always double-check the 'additional knowledge' being provided. " ] }, { "cell_type": "markdown", "id": "b372cb4b-614d-4e1c-aba5-b4709c67def2", "metadata": {}, "source": [ "## Step 7. Parse the OpenAI response to pull out the Python code." 
] }, { "cell_type": "code", "execution_count": 112, "id": "ea7ba57b-e2bd-4be6-b294-7fafe892c701", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "import matplotlib.pyplot as plt\n", "# Read the Starbucks dataframe\n", "starbucks_df = pd.read_csv('../data/directory.csv')\n", "\n", "# List of EU countries\n", "eu_countries = ['AT', 'BE', 'BG', 'CY', 'CZ', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GR', 'HR', 'HU', 'IE', 'IT', 'LT', 'LU', 'LV', 'MT', 'NL', 'PL', 'PT', 'RO', 'SE', 'SI', 'SK']\n", "\n", "# Filter Starbucks stores in the EU\n", "starbucks_eu_df = starbucks_df[starbucks_df['Country'].isin(eu_countries) & (starbucks_df['Brand'] == 'Starbucks')]\n", "\n", "# Count the number of Starbucks stores by country\n", "starbucks_by_country = starbucks_eu_df.groupby('Country')['Brand'].count().reset_index().rename(columns={'Brand': 'Num_of_starbucks'})\n", "\n", "# Sort the data by number of stores\n", "starbucks_by_country = starbucks_by_country.sort_values('Num_of_starbucks', ascending=False)\n", "\n", "# Set the color scheme and font\n", "sns.set_palette(sns.color_palette([\"#00704A\"]))\n", "sns.set(font=\"Helvetica\")\n", "\n", "# Create the bar chart\n", "plt.figure(figsize=(12,8))\n", "sns.barplot(x='Num_of_starbucks', y='Country', data=starbucks_by_country, color=\"#00704A\")\n", "plt.title('Number of Starbucks Stores in the EU by Country', fontsize=18, fontweight='bold')\n", "plt.xlabel('Number of Stores', fontsize=14)\n", "plt.ylabel('Country', fontsize=14)\n", "plt.xticks(fontsize=14)\n", "plt.yticks(fontsize=14)\n", "plt.show()\n" ] } ], "source": [ "# Now, we parse the ChatGPT response to pull out the Python code. 
Full disclosure:\n", "# I used ChatGPT to write this code, so don't ask me too many questions about it.\n", "\n", "text = message_response\n", "\n", "# Define the regular expression pattern to match the Python code\n", "pattern = r\"```(?:python)?\\n([\\s\\S]*?)\\n```\"\n", "\n", "# Extract all Python code blocks from the text\n", "python_blocks = re.findall(pattern, text)\n", "\n", "# Combine all Python code blocks into a single string\n", "python_code = \"\\n\".join(python_blocks)\n", "\n", "# Remove import statements from the Python code.\n", "# The pattern is anchored to the start of a line, so the word 'import' inside a comment is left alone,\n", "# and it accepts dotted modules (import matplotlib.pyplot as plt) and parenthesized name lists.\n", "import_pattern = r\"^[ \\t]*(?:from[ \\t]+[\\w.]+[ \\t]+)?import[ \\t]+(?:\\([^)]*\\)|[^\\n(]+)(?:\\n|\\Z)\"\n", "python_code = re.sub(import_pattern, \"\", python_code, flags=re.MULTILINE)\n", "print (python_code)\n" ] }, { "cell_type": "markdown", "id": "c3804423-2daf-4b9a-aaf6-ac6a961d5a9e", "metadata": {}, "source": [ "## Step 8. Run the code (DANGER!)" ] }, { "cell_type": "code", "execution_count": 113, "id": "46b21445-b5b3-452c-8930-b5440332f318", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The following code will execute the OPENAI-generated Python scripts. CHECK THE CODE FIRST!\n", "# NOTE: exec() runs the generated code in this notebook's namespace, with the same access you have.\n", "\n", "try:\n", "    exec(python_code)\n", "except FileNotFoundError as e:\n", "    # Report the file that was actually missing rather than assuming it was directory.csv.\n", "    print(\"FileNotFoundError:\", e, \"| current working directory:\", os.getcwd())\n", "except Exception as e:\n", "    # Include the exception type; the message alone is often empty or cryptic.\n", "    print(f\"Error ({type(e).__name__}):\", e)" ] }, { "cell_type": "code", "execution_count": null, "id": "dce8e2f4-063c-407d-b5dc-8ed968465a75", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }