diff --git a/lecture_material/21-classification/21-classification.ipynb b/lecture_material/21-classification/21-classification.ipynb index f04482d33881d52351d29e1931ca24ea54d2f7f5..c51fb06e486a453c162321caf4e732ec1db3ce94 100644 --- a/lecture_material/21-classification/21-classification.ipynb +++ b/lecture_material/21-classification/21-classification.ipynb @@ -1337,7 +1337,7 @@ { "data": { "text/plain": [ - "[<matplotlib.lines.Line2D at 0x7f5a78793820>]" + "[<matplotlib.lines.Line2D at 0x7f504c7f2850>]" ] }, "execution_count": 22, @@ -2561,7 +2561,7 @@ { "data": { "text/plain": [ - "<matplotlib.contour.QuadContourSet at 0x7f5a76693040>" + "<matplotlib.contour.QuadContourSet at 0x7f504a6ec340>" ] }, "execution_count": 52, @@ -2593,7 +2593,7 @@ { "data": { "text/plain": [ - "<matplotlib.contour.QuadContourSet at 0x7f5a764ddfa0>" + "<matplotlib.contour.QuadContourSet at 0x7f504a5bb9d0>" ] }, "execution_count": 53, @@ -2770,6 +2770,14 @@ "df[xcols]" ] }, + { + "cell_type": "markdown", + "id": "22058e9c-c601-462f-9d81-6ce9ffb02248", + "metadata": {}, + "source": [ + "#### `predictions = F(sepl, sepw)`" + ] + }, { "cell_type": "code", "execution_count": 56, @@ -2794,6 +2802,7 @@ } ], "source": [ + "# Creating range of values for sepl (X) and sepw (Y)\n", "sepl, sepw = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))\n", "sepl" ] @@ -2914,6 +2923,7 @@ } ], "source": [ + "# Predicting setosa (True / False) labels using LogisticRegression model\n", "Z_predictions = cls_model.predict(predict_df)\n", "Z_predictions" ] @@ -2949,7 +2959,7 @@ { "data": { "text/plain": [ - "<matplotlib.contour.QuadContourSet at 0x7f5a76379490>" + "<matplotlib.contour.QuadContourSet at 0x7f504a450b20>" ] }, "execution_count": 60, @@ -3007,6 +3017,7 @@ } ], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, Z_predictions.reshape(sepl.shape))\n", "ax = plt.gca() # get current axes subplot\n", "df[df[\"setosa\"]].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -3084,6 +3095,8 @@ } ], "source": [ + "# we need numeric values for Z\n", + "# hence .index usage instead of actual values of variety predictions\n", "predictions = np.array([classes_.index(name) for name in mult_model.predict(predict_df)])\n", "predictions" ] @@ -3186,12 +3199,14 @@ } ], "source": [ + "# Creating range of values for sepl (X) and sepw (Y)\n", "sepl, sepw = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))\n", "predict_df = pd.DataFrame({\n", " 'sepal length (cm)': sepl.reshape(-1),\n", " 'sepal width (cm)': sepw.reshape(-1),\n", " 'const': 1\n", "})\n", + "# Predicting setosa (True / False) labels using LogisticRegression model\n", "Z_predictions = model.predict(predict_df)\n", "Z_predictions" ] @@ -3224,6 +3239,7 @@ } ], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, Z_predictions.reshape(sepl.shape))\n", "ax = plt.gca()\n", "df[df[\"setosa\"]].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -3233,6 +3249,14 @@ " color=\"0.8\", label=\"Not Setosa\")" ] }, + { + "cell_type": "markdown", + "id": "61357da5-900a-4c3f-ae56-c41c0958cf33", + "metadata": {}, + "source": [ + "#### Multi-classification model to predict `variety` using `Pipeline` of `PolynomialFeatures` and `LogisticRegression` models" + ] + }, { "cell_type": "code", "execution_count": 69, @@ -3355,6 +3379,7 @@ } ], "source": [ + "# Predicting variety labels using LogisticRegression model\n", "predictions = np.array([classes_.index(name) for name in model.predict(predict_df)])\n", 
"predictions" ] @@ -3387,6 +3412,7 @@ } ], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, predictions.reshape(sepl.shape))\n", "ax = plt.gca() # get current axes\n", "df[df[\"variety\"] == \"setosa\"].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -4104,6 +4130,16 @@ "accuracy_score(test[\"y\"], model.predict(test[[\"x\"]]))" ] }, + { + "cell_type": "markdown", + "id": "b1149507-b365-45cd-8d8b-e2907e89cbcc", + "metadata": {}, + "source": [ + "What are the range of values of `accuracy_score` for a classification model?\n", + "- `0 to 1` because it is a fraction of predicted labels / actual labels => remember you can either get it correct or wrong.\n", + "- Recall as opposed to classification, for regression range for R^2 score is `-infinity to 1` => that is because the model can introduce non-existent variance within the data while trying to fit it!" + ] + }, { "cell_type": "markdown", "id": "1bdc164a", @@ -4112,6 +4148,23 @@ "#### Confusion Matrices" ] }, + { + "cell_type": "markdown", + "id": "a1c52e70-b134-49da-b230-f03056752ecb", + "metadata": {}, + "source": [ + "#### `confusion_matrix(y_true, y_pred)`\n", + "\n", + "- computes confusion matrix for classification:\n", + " - row dimension represents actual value\n", + " - column dimension represents predicted value\n", + "- documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html\n", + "\n", + "<div>\n", + "<img src=\"Confusion_matrix.png\" width=\"500\"/>\n", + "</div>" + ] + }, { "cell_type": "code", "execution_count": 95, @@ -4148,19 +4201,6 @@ "confusion_matrix(actual, predicted)" ] }, - { - "cell_type": "markdown", - "id": "3efa94b0", - "metadata": {}, - "source": [ - "#### `confusion_matrix(y_true, y_pred)`\n", - "\n", - "- computes confusion matrix for classification:\n", - " - row dimension represents actual value\n", - " - column dimension represents predicted value\n", - "- documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html" - ] - }, { "cell_type": "code", "execution_count": 97, @@ -4182,6 +4222,7 @@ } ], "source": [ + "# notice that we have a \"horse\" label even though the data doesn't have any info about horse\n", "labels = [\"dog\", \"cat\", \"mouse\", \"horse\"]\n", "cm = confusion_matrix(actual, predicted, labels=labels)\n", "cm" @@ -4272,9 +4313,658 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 99, "id": "0614ec05-3271-4501-a119-af413c5f68d3", "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>dog</th>\n", + " <th>cat</th>\n", + " <th>mouse</th>\n", + " <th>horse</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>dog</th>\n", + " <td>796</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>cat</th>\n", + " <td>398</td>\n", + " <td>398</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mouse</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>398</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " 
<th>horse</th>\n", + " <td>9</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " dog cat mouse horse\n", + "dog 796 0 0 0\n", + "cat 398 398 0 0\n", + "mouse 0 0 398 0\n", + "horse 9 0 0 1" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actual = [\"dog\", \"dog\", \"dog\", \"dog\", \"cat\", \"cat\", \"cat\", \"cat\", \"mouse\", \"mouse\"] * 199\n", + "predicted = [\"dog\", \"dog\", \"dog\", \"dog\", \"cat\", \"dog\", \"cat\", \"dog\", \"mouse\", \"mouse\"] * 199\n", + "actual += [\"horse\"] * 10\n", + "predicted += [\"dog\"] * 9 + [\"horse\"]\n", + "\n", + "labels = [\"dog\", \"cat\", \"mouse\", \"horse\"]\n", + "cm = confusion_matrix(actual, predicted, labels=labels)\n", + "cm = pd.DataFrame(cm, index=labels, columns=labels)\n", + "cm" + ] + }, + { + "cell_type": "markdown", + "id": "54b57291-fa91-47a3-9d5b-7c054d416314", + "metadata": {}, + "source": [ + "### Recall and balanced accuracy score\n", + "\n", + "- import statement:\n", + "```python\n", + "from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "cd33658e-b99d-4a6f-b884-c01e67104a13", + "metadata": {}, + "source": [ + "#### Recall: row-wise ratio\n", + "\n", + "- What proportion of actual positives was identified correctly?" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "68f819ad-1f04-4868-8c5e-cd656fa4804a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# what is the recall for cat?\n", + "cm.at[\"cat\", \"cat\"] / cm.loc[\"cat\", :].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "1174544a-e1fd-403e-a093-27200f015c62", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Target is multiclass but average='binary'. 
Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[101], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# There are multiple recall scores as we have multiple labels\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# So, we need to pass argument to parameter \"average\"\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[43mrecall_score\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactual\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpredicted\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/sklearn/utils/_param_validation.py:214\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 209\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 210\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 211\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 212\u001b[0m )\n\u001b[1;32m 213\u001b[0m ):\n\u001b[0;32m--> 214\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[1;32m 219\u001b[0m \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[1;32m 220\u001b[0m msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[1;32m 221\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 222\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 224\u001b[0m )\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/sklearn/metrics/_classification.py:2304\u001b[0m, in \u001b[0;36mrecall_score\u001b[0;34m(y_true, y_pred, labels, pos_label, average, sample_weight, zero_division)\u001b[0m\n\u001b[1;32m 2144\u001b[0m \u001b[38;5;129m@validate_params\u001b[39m(\n\u001b[1;32m 2145\u001b[0m {\n\u001b[1;32m 2146\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_true\u001b[39m\u001b[38;5;124m\"\u001b[39m: 
[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124marray-like\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msparse matrix\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2171\u001b[0m zero_division\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwarn\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2172\u001b[0m ):\n\u001b[1;32m 2173\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute the recall.\u001b[39;00m\n\u001b[1;32m 2174\u001b[0m \n\u001b[1;32m 2175\u001b[0m \u001b[38;5;124;03m The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2302\u001b[0m \u001b[38;5;124;03m array([1. , 1. , 0.5])\u001b[39;00m\n\u001b[1;32m 2303\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 2304\u001b[0m _, r, _, _ \u001b[38;5;241m=\u001b[39m \u001b[43mprecision_recall_fscore_support\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2305\u001b[0m \u001b[43m \u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2306\u001b[0m \u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2307\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2308\u001b[0m \u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpos_label\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2309\u001b[0m \u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2310\u001b[0m \u001b[43m \u001b[49m\u001b[43mwarn_for\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrecall\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2311\u001b[0m \u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2312\u001b[0m \u001b[43m \u001b[49m\u001b[43mzero_division\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mzero_division\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2313\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m r\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/sklearn/utils/_param_validation.py:187\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 185\u001b[0m global_skip_validation \u001b[38;5;241m=\u001b[39m get_config()[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mskip_parameter_validation\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m global_skip_validation:\n\u001b[0;32m--> 187\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 189\u001b[0m func_sig \u001b[38;5;241m=\u001b[39m signature(func)\n\u001b[1;32m 191\u001b[0m \u001b[38;5;66;03m# Map *args/**kwargs to the function signature\u001b[39;00m\n", + "File 
\u001b[0;32m~/.local/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1724\u001b[0m, in \u001b[0;36mprecision_recall_fscore_support\u001b[0;34m(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight, zero_division)\u001b[0m\n\u001b[1;32m 1566\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Compute precision, recall, F-measure and support for each class.\u001b[39;00m\n\u001b[1;32m 1567\u001b[0m \n\u001b[1;32m 1568\u001b[0m \u001b[38;5;124;03mThe precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1721\u001b[0m \u001b[38;5;124;03m array([2, 2, 2]))\u001b[39;00m\n\u001b[1;32m 1722\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1723\u001b[0m zero_division_value \u001b[38;5;241m=\u001b[39m _check_zero_division(zero_division)\n\u001b[0;32m-> 1724\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[43m_check_set_wise_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1726\u001b[0m \u001b[38;5;66;03m# Calculate tp_sum, pred_sum, true_sum ###\u001b[39;00m\n\u001b[1;32m 1727\u001b[0m samplewise \u001b[38;5;241m=\u001b[39m average \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1518\u001b[0m, in \u001b[0;36m_check_set_wise_labels\u001b[0;34m(y_true, y_pred, average, labels, pos_label)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m y_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmulticlass\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 1517\u001b[0m average_options\u001b[38;5;241m.\u001b[39mremove(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTarget is \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m but average=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. 
Please \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1520\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mchoose another average setting, one of \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (y_type, average_options)\n\u001b[1;32m 1521\u001b[0m )\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m pos_label \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m1\u001b[39m):\n\u001b[1;32m 1523\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNote that pos_label (set to \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) is ignored when \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maverage != \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m (got \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m). You may use \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1528\u001b[0m \u001b[38;5;167;01mUserWarning\u001b[39;00m,\n\u001b[1;32m 1529\u001b[0m )\n", + "\u001b[0;31mValueError\u001b[0m: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted']." + ] + } + ], + "source": [ + "# There are multiple recall scores as we have multiple labels\n", + "# So, we need to pass argument to parameter \"average\"\n", + "recall_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "84bb10a8-8704-4ea6-b5ab-571918887fd0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.5, 1. , 0.1, 1. ])" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recall_score(actual, predicted, average=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "0f5b8965-46c8-49e7-912c-9f2e130099b4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['dog', 'cat', 'mouse', 'horse']\n" + ] + }, + { + "data": { + "text/plain": [ + "array([1. , 0.5, 1. 
, 0.1])" + ] + }, + "execution_count": 103, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# How can we identify which score is for which label?\n", + "print(labels)\n", + "# We can pass a list of labels argument to parameter \"labels\"\n", + "recall_score(actual, predicted, average=None, labels=labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "id": "be2b9c51-6de6-44ff-987b-d4a625f94788", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>dog</th>\n", + " <th>cat</th>\n", + " <th>mouse</th>\n", + " <th>horse</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>dog</th>\n", + " <td>796</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>cat</th>\n", + " <td>398</td>\n", + " <td>398</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mouse</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>398</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>horse</th>\n", + " <td>9</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " dog cat mouse horse\n", + "dog 796 0 0 0\n", + "cat 398 398 0 0\n", + "mouse 0 0 398 0\n", + "horse 9 0 0 1" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cm" + ] + }, + { + "cell_type": "markdown", + "id": "2c33a4c8-e724-46fb-bd06-76a2ba5bde20", + "metadata": {}, + "source": [ + "How does average recall score compare against accuracy score?" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "id": "24e0f605-22e0-476b-a3a2-67a1752b4345", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7965" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "bde42d32-6c63-4693-be68-a8827cf42906", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.65" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# will this be bigger or smaller than accuracy?\n", + "recall_score(actual, predicted, average=None, labels=labels).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "5f143049-0f26-41a6-8a7f-e607c04e4d43", + "metadata": {}, + "source": [ + "We are saying \"horse\" is equally important as other animals by taking an average of recall. So average recall score is lower than overall accuracy." 
+ ] + }, + { + "cell_type": "markdown", + "id": "efbe97a3-da15-41a6-81c2-b0f302e103bd", + "metadata": {}, + "source": [ + "#### Average recall score is \"Balanced accuracy score\"" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "6568c39a-4c07-47ec-a2b7-7f0b6257ca6d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.65" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# average of recall scores is called balanced accuracy score\n", + "balanced_accuracy_score(actual, predicted)" + ] + }, + { + "cell_type": "markdown", + "id": "d6e88c83-2d4f-4318-beeb-609f97e451dd", + "metadata": {}, + "source": [ + "Why does \"Balanced accuracy score\" matter?\n", + "\n", + "- Imagine you are building a new covid test. We know that the majority of test results should be negative.\n", + "- So, you might get a high accuracy even if your covid test is missing actual positives.\n", + "- If you just look at accuracy, that might be misleading. So, you must also look at \"Balanced accuracy score\"." + ] + }, + { + "cell_type": "markdown", + "id": "2d253246-9111-4687-bd3e-810df3726d2b", + "metadata": {}, + "source": [ + "### Precision: column-wise ratio\n", + "\n", + "- What proportion of positive identifications was actually correct?" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "id": "f1f22bed-04a2-4eb5-b553-2ea1ed899f8c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>dog</th>\n", + " <th>cat</th>\n", + " <th>mouse</th>\n", + " <th>horse</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>dog</th>\n", + " <td>796</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>cat</th>\n", + " <td>398</td>\n", + " <td>398</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mouse</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>398</td>\n", + " <td>0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>horse</th>\n", + " <td>9</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " dog cat mouse horse\n", + "dog 796 0 0 0\n", + "cat 398 398 0 0\n", + "mouse 0 0 398 0\n", + "horse 9 0 0 1" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cm" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "id": "54a444c1-685c-43df-b9d0-d37855808ac9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['dog', 'cat', 'mouse', 'horse']\n" + ] + }, + { + "data": { + "text/plain": [ + "array([0.66167914, 1. , 1. , 1. 
])" + ] + }, + "execution_count": 109, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(labels)\n", + "precision_score(actual, predicted, average=None, labels=labels)" + ] + }, + { + "cell_type": "markdown", + "id": "3bc52a31-f47a-43fa-b67c-9dd87f4db3f7", + "metadata": {}, + "source": [ + "### Binary Classification Metrics\n", + "\n", + "Unless otherwise specified, \"precision\" and \"recall\" refer to those metrics for the positive class when we're doing binary classification.\n", + "\n", + "<div>\n", + "<img src=\"Confusion_matrix_binary.png\" width=\"350\"/>\n", + "</div>" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "id": "63d9fe09-cb50-49b2-8300-6ef5ada9b856", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1, 2],\n", + " [3, 7]])" + ] + }, + "execution_count": 110, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actual = [False, True, True, True, True, False, False, True, True, True, True, True, True]\n", + "predicted = [False, True, True, True, True, True, True, False, False, False, True, True, True]\n", + "confusion_matrix(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "id": "f452b5f2-4df4-4113-a682-9346aea24ab3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.7 , 0.33333333])" + ] + }, + "execution_count": 111, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recall_score(actual, predicted, average=None, labels=[True, False])" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "id": "d8505f1e-dbab-49ac-b325-ec2b44191c6a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7" + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# for binary classification, we have False recall and True recall\n", + "# \"recall\" is shorthand for \"True recall\"\n", + "recall_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "id": "21c2f505-5993-4b8d-bde5-2b3d524a59d6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.77777778, 0.25 ])" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "precision_score(actual, predicted, average=None, labels=[True, False])" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "id": "f536fa98-0c92-4ce0-bde8-e60364b151e8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.7777777777777778" + ] + }, + "execution_count": 114, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# for binary classification, we have False precision and True precision\n", + "# \"precision\" is shorthand for \"True precision\"\n", + "precision_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fe51d31-addf-4c49-a82c-75d1b3fcb3f3", + "metadata": {}, "outputs": [], "source": [] } diff --git a/lecture_material/21-classification/21-classification_001.ipynb b/lecture_material/21-classification/21-classification_001.ipynb index d27e29fb779edd5d94c20b2b71df07dc5540b35d..7393488c929fc9bc0672dac9b9c6c64975a465fb 100644 --- a/lecture_material/21-classification/21-classification_001.ipynb +++ b/lecture_material/21-classification/21-classification_001.ipynb @@ -1001,6 +1001,14 @@ "df[xcols]" ] }, + { + "cell_type": "markdown", + "id": 
"639036ad-cdca-4a5a-a920-caba08872ecb", + "metadata": {}, + "source": [ + "#### `predictions = F(sepl, sepw)`" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1008,6 +1016,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating range of values for sepl (X) and sepw (Y)\n", "sepl, sepw = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))\n", "sepl" ] @@ -1042,6 +1051,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Predicting setosa (True / False) labels using LogisticRegression model\n", "Z_predictions = cls_model.predict(predict_df)\n", "Z_predictions" ] @@ -1083,6 +1093,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, Z_predictions.reshape(sepl.shape))\n", "# get current axes subplot\n", "???\n", @@ -1126,6 +1137,8 @@ "metadata": {}, "outputs": [], "source": [ + "# we need numeric values for Z\n", + "# hence .index usage instead of actual values of variety predictions\n", "predictions = np.array([classes_.index(name) for name in mult_model.predict(predict_df)])\n", "predictions" ] @@ -1177,12 +1190,14 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating range of values for sepl (X) and sepw (Y)\n", "sepl, sepw = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))\n", "predict_df = pd.DataFrame({\n", " 'sepal length (cm)': sepl.reshape(-1),\n", " 'sepal width (cm)': sepw.reshape(-1),\n", " 'const': 1\n", "})\n", + "# Predicting setosa (True / False) labels using LogisticRegression model\n", "Z_predictions = model.predict(predict_df)\n", "Z_predictions" ] @@ -1194,6 +1209,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, Z_predictions.reshape(sepl.shape))\n", "ax = plt.gca()\n", "df[df[\"setosa\"]].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -1203,6 +1219,14 @@ " color=\"0.8\", label=\"Not Setosa\")" ] }, + { + "cell_type": "markdown", + "id": "560851c2-348a-4f07-991b-33522464b246", + "metadata": {}, + "source": [ + "#### Multi-classification model to predict `variety` using `Pipeline` of `PolynomialFeatures` and `LogisticRegression` models" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1266,6 +1290,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Predicting variety labels using LogisticRegression model\n", "predictions = np.array([classes_.index(name) for name in model.predict(predict_df)])\n", "predictions" ] @@ -1277,6 +1302,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, predictions.reshape(sepl.shape))\n", "ax = plt.gca() # get current axes\n", "df[df[\"variety\"] == \"setosa\"].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -1636,6 +1662,16 @@ "accuracy_score(ACTUAL, PREDICTED)" ] }, + { + "cell_type": "markdown", + "id": "c21ee7f5-267a-49c6-b3b5-2893aa72d1b5", + "metadata": {}, + "source": [ + "What are the range of values of `accuracy_score` for a classification model?\n", + "- `0 to 1` because it is a fraction of predicted labels / actual labels => remember you can either get it correct or wrong.\n", + "- Recall as opposed to classification, for regression range for R^2 score is `-infinity to 1` => that is because the model can introduce non-existent variance within the data while trying to fit it!" 
+ ] + }, { "cell_type": "markdown", "id": "1bdc164a", @@ -1644,6 +1680,23 @@ "#### Confusion Matrices" ] }, + { + "cell_type": "markdown", + "id": "682885ff-b514-4a4c-8cd5-0b5ea94d5254", + "metadata": {}, + "source": [ + "#### `confusion_matrix(y_true, y_pred)`\n", + "\n", + "- computes confusion matrix for classification:\n", + " - row dimension represents actual value\n", + " - column dimension represents predicted value\n", + "- documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html\n", + "\n", + "<div>\n", + "<img src=\"Confusion_matrix.png\" width=\"500\"/>\n", + "</div>" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1668,38 +1721,294 @@ ] }, { - "cell_type": "markdown", - "id": "3efa94b0", + "cell_type": "code", + "execution_count": null, + "id": "12f0a1ff", "metadata": {}, + "outputs": [], "source": [ - "#### `confusion_matrix(y_true, y_pred)`\n", - "\n", - "- computes confusion matrix for classification:\n", - " - row dimension represents actual value\n", - " - column dimension represents predicted value\n", - "- documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html" + "# notice that we have a \"horse\" label even though the data doesn't have any info about horse\n", + "labels = [\"dog\", \"cat\", \"mouse\", \"horse\"]\n", + "cm = confusion_matrix(actual, predicted, labels=labels)\n", + "cm" ] }, { "cell_type": "code", "execution_count": null, - "id": "12f0a1ff", + "id": "ace44e40", + "metadata": {}, + "outputs": [], + "source": [ + "pd.DataFrame(cm, index=labels, columns=labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79fb6e1d-2b62-43d1-a158-e350f0b86145", "metadata": {}, "outputs": [], "source": [ + "actual = [\"dog\", \"dog\", \"dog\", \"dog\", \"cat\", \"cat\", \"cat\", \"cat\", \"mouse\", \"mouse\"] * 199\n", + "predicted = [\"dog\", \"dog\", \"dog\", \"dog\", \"cat\", \"dog\", \"cat\", \"dog\", \"mouse\", \"mouse\"] * 199\n", + "actual += [\"horse\"] * 10\n", + "predicted += [\"dog\"] * 9 + [\"horse\"]\n", + "\n", "labels = [\"dog\", \"cat\", \"mouse\", \"horse\"]\n", "cm = confusion_matrix(actual, predicted, labels=labels)\n", + "cm = pd.DataFrame(cm, index=labels, columns=labels)\n", + "cm" + ] + }, + { + "cell_type": "markdown", + "id": "c1eb06bb-23cb-4cce-a0e0-ffa586d0bc93", + "metadata": {}, + "source": [ + "### Recall and balanced accuracy score\n", + "\n", + "- import statement:\n", + "```python\n", + "from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "4f1b1be4-e780-4011-89b8-85531cdb5314", + "metadata": {}, + "source": [ + "#### Recall: row-wise ratio\n", + "\n", + "- What proportion of actual positives was identified correctly?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93d57a38-9114-4b09-8f11-b1718e60534a", + "metadata": {}, + "outputs": [], + "source": [ + "# what is the recall for cat?\n", + "cm.at[\"cat\", \"cat\"] / cm.loc[\"cat\", :].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca2dff07-6725-4ec0-a791-861ad3b58a7c", + "metadata": {}, + "outputs": [], + "source": [ + "# There are multiple recall scores as we have multiple labels\n", + "# So, we need to pass argument to parameter \"average\"\n", + "recall_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b16930a8-1a1c-416d-8fe5-04cd45f79749", + "metadata": {}, + "outputs": [], + "source": [ + "recall_score(actual, predicted, average=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70eabe61-5dce-4558-a484-de5d32e6906e", + "metadata": {}, + "outputs": [], + "source": [ + "# How can we identify which score is for which label?\n", + "print(labels)\n", + "# We can pass a list of labels argument to parameter \"labels\"\n", + "recall_score(actual, predicted, average=None, labels=labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "394edab6-08ae-411e-bf63-b01c0a12684f", + "metadata": {}, + "outputs": [], + "source": [ "cm" ] }, { "cell_type": "code", "execution_count": null, - "id": "ace44e40", + "id": "971d96ba-d046-4843-93dd-d173804808dd", "metadata": {}, "outputs": [], "source": [ - "pd.DataFrame(cm, index=labels, columns=labels)" + "# How does average recall score compare against accuracy score?" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "50d8b289-ca1e-454c-b2da-c8c06522fc59", + "metadata": {}, + "outputs": [], + "source": [ + "accuracy_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d73672b-9f19-43cd-baad-447bca2bbe24", + "metadata": {}, + "outputs": [], + "source": [ + "# will this be bigger or smaller than accuracy?\n", + "recall_score(actual, predicted, average=None, labels=labels).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "bf98eeab-8cc2-4ab3-a7a1-cb3e96e4c96b", + "metadata": {}, + "source": [ + "We are saying \"horse\" is equally important as other animals by taking an average of recall. So average recall score is lower than overall accuracy." + ] + }, + { + "cell_type": "markdown", + "id": "926d4d94-7bc3-4614-b582-fc7333e68c40", + "metadata": {}, + "source": [ + "#### Average recall score is \"Balanced accuracy score\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8956e72b-694e-411b-ab6d-b8b2fda0ea0b", + "metadata": {}, + "outputs": [], + "source": [ + "# average of recall scores is called balanced accuracy score\n", + "balanced_accuracy_score(actual, predicted)" + ] + }, + { + "cell_type": "markdown", + "id": "cce370f6-5ba8-4ee2-983b-c511585c8de2", + "metadata": {}, + "source": [ + "Why does \"Balanced accuracy score\" matter?\n", + "\n", + "- Imagine you are building a new covid test. We know that the majority of test results should be negative.\n", + "- So, you might get a high accuracy even if your covid test is missing actual positives.\n", + "- If you just look at accuracy, that might be misleading. So, you must also look at \"Balanced accuracy score\"."
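A minimal sketch of that scenario with made-up numbers (990 actual negatives, 10 actual positives, and a hypothetical test that catches only 2 of the positives): accuracy still looks excellent while balanced accuracy exposes the missed positives:

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# hypothetical screening data: 990 healthy people, 10 infected people,
# and a test that flags only 2 of the 10 infected
actual = [False] * 990 + [True] * 10
predicted = [False] * 990 + [False] * 8 + [True] * 2

print(accuracy_score(actual, predicted))           # 0.992 => looks excellent
# balanced accuracy = mean of (False recall, True recall) = (1.0 + 0.2) / 2
print(balanced_accuracy_score(actual, predicted))  # 0.6 => reveals the missed positives
```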
+ ] + }, + { + "cell_type": "markdown", + "id": "be887a73-1731-41ba-b17f-5cc5e3eb7a81", + "metadata": {}, + "source": [ + "### Precision: column-wise ratio\n", + "\n", + "- What proportion of positive identifications was actually correct?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5206ca85-48ee-44e3-8dc1-b6e028f8aa0a", + "metadata": {}, + "outputs": [], + "source": [ + "cm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c03db37-226d-43f8-a0cf-50906bd66037", + "metadata": {}, + "outputs": [], + "source": [ + "print(labels)\n", + "precision_score(actual, predicted, average=None, labels=labels)" + ] + }, + { + "cell_type": "markdown", + "id": "b2cda510-211b-49b4-8e7b-c7216c3c5f53", + "metadata": {}, + "source": [ + "### Binary Classification Metrics\n", + "\n", + "Unless otherwise specified, \"precision\" and \"recall\" refer to those metrics for the positive class when we're doing binary classification.\n", + "\n", + "<div>\n", + "<img src=\"Confusion_matrix_binary.png\" width=\"350\"/>\n", + "</div>" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7b8595b-affa-4950-a59e-f6b607f08bb0", + "metadata": {}, + "outputs": [], + "source": [ + "actual = [False, True, True, True, True, False, False, True, True, True, True, True, True]\n", + "predicted = [False, True, True, True, True, True, True, False, False, False, True, True, True]\n", + "confusion_matrix(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "937b1197-26d4-4898-a75d-116ef5bfd1f6", + "metadata": {}, + "outputs": [], + "source": [ + "recall_score(actual, predicted, average=None, labels=[True, False])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88bbbf3c-aa9b-4ae9-bef2-74f532036057", + "metadata": {}, + "outputs": [], + "source": [ + "# for binary classification, we have False recall and True recall\n", + "# \"recall\" is shorthand for \"True recall\"\n", + "recall_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57221111-77d7-48ee-938a-c596be1fbec7", + "metadata": {}, + "outputs": [], + "source": [ + "precision_score(actual, predicted, average=None, labels=[True, False])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bb37d4a-9374-4d23-be63-ca27a83ee301", + "metadata": {}, + "outputs": [], + "source": [ + "# for binary classification, we have False precision and True precision\n", + "# \"precision\" is shorthand for \"True precision\"\n", + "precision_score(actual, predicted)" ] } ], diff --git a/lecture_material/21-classification/21-classification_002.ipynb b/lecture_material/21-classification/21-classification_002.ipynb index d27e29fb779edd5d94c20b2b71df07dc5540b35d..7393488c929fc9bc0672dac9b9c6c64975a465fb 100644 --- a/lecture_material/21-classification/21-classification_002.ipynb +++ b/lecture_material/21-classification/21-classification_002.ipynb @@ -1001,6 +1001,14 @@ "df[xcols]" ] }, + { + "cell_type": "markdown", + "id": "639036ad-cdca-4a5a-a920-caba08872ecb", + "metadata": {}, + "source": [ + "#### `predictions = F(sepl, sepw)`" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1008,6 +1016,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating range of values for sepl (X) and sepw (Y)\n", "sepl, sepw = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))\n", "sepl" ] @@ -1042,6 +1051,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Predicting setosa (True / False) labels using 
LogisticRegression model\n", "Z_predictions = cls_model.predict(predict_df)\n", "Z_predictions" ] @@ -1083,6 +1093,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, Z_predictions.reshape(sepl.shape))\n", "# get current axes subplot\n", "???\n", @@ -1126,6 +1137,8 @@ "metadata": {}, "outputs": [], "source": [ + "# we need numeric values for Z\n", + "# hence .index usage instead of actual values of variety predictions\n", "predictions = np.array([classes_.index(name) for name in mult_model.predict(predict_df)])\n", "predictions" ] @@ -1177,12 +1190,14 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating range of values for sepl (X) and sepw (Y)\n", "sepl, sepw = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))\n", "predict_df = pd.DataFrame({\n", " 'sepal length (cm)': sepl.reshape(-1),\n", " 'sepal width (cm)': sepw.reshape(-1),\n", " 'const': 1\n", "})\n", + "# Predicting setosa (True / False) labels using LogisticRegression model\n", "Z_predictions = model.predict(predict_df)\n", "Z_predictions" ] @@ -1194,6 +1209,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, Z_predictions.reshape(sepl.shape))\n", "ax = plt.gca()\n", "df[df[\"setosa\"]].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -1203,6 +1219,14 @@ " color=\"0.8\", label=\"Not Setosa\")" ] }, + { + "cell_type": "markdown", + "id": "560851c2-348a-4f07-991b-33522464b246", + "metadata": {}, + "source": [ + "#### Multi-classification model to predict `variety` using `Pipeline` of `PolynomialFeatures` and `LogisticRegression` models" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1266,6 +1290,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Predicting variety labels using LogisticRegression model\n", "predictions = np.array([classes_.index(name) for name in model.predict(predict_df)])\n", "predictions" ] @@ -1277,6 +1302,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Creating contourf plot\n", "plt.contourf(sepl, sepw, predictions.reshape(sepl.shape))\n", "ax = plt.gca() # get current axes\n", "df[df[\"variety\"] == \"setosa\"].plot.scatter(x=\"sepal length (cm)\", y=\"sepal width (cm)\", \\\n", @@ -1636,6 +1662,16 @@ "accuracy_score(ACTUAL, PREDICTED)" ] }, + { + "cell_type": "markdown", + "id": "c21ee7f5-267a-49c6-b3b5-2893aa72d1b5", + "metadata": {}, + "source": [ + "What is the range of values of `accuracy_score` for a classification model?\n", + "- `0 to 1` because it is the fraction of correctly predicted labels out of all labels => each prediction is either correct or wrong.\n", + "- Recall that for regression, in contrast, the range of the R^2 score is `-infinity to 1` => a model can fit the data worse than simply predicting the mean, which makes the score negative!"
+ ] + }, { "cell_type": "markdown", "id": "1bdc164a", @@ -1644,6 +1680,23 @@ "#### Confusion Matrices" ] }, + { + "cell_type": "markdown", + "id": "682885ff-b514-4a4c-8cd5-0b5ea94d5254", + "metadata": {}, + "source": [ + "#### `confusion_matrix(y_true, y_pred)`\n", + "\n", + "- computes confusion matrix for classification:\n", + " - row dimension represents actual value\n", + " - column dimension represents predicted value\n", + "- documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html\n", + "\n", + "<div>\n", + "<img src=\"Confusion_matrix.png\" width=\"500\"/>\n", + "</div>" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1668,38 +1721,294 @@ ] }, { - "cell_type": "markdown", - "id": "3efa94b0", + "cell_type": "code", + "execution_count": null, + "id": "12f0a1ff", "metadata": {}, + "outputs": [], "source": [ - "#### `confusion_matrix(y_true, y_pred)`\n", - "\n", - "- computes confusion matrix for classification:\n", - " - row dimension represents actual value\n", - " - column dimension represents predicted value\n", - "- documentation: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html" + "# notice that we have a \"horse\" label even though the data doesn't have any info about horse\n", + "labels = [\"dog\", \"cat\", \"mouse\", \"horse\"]\n", + "cm = confusion_matrix(actual, predicted, labels=labels)\n", + "cm" ] }, { "cell_type": "code", "execution_count": null, - "id": "12f0a1ff", + "id": "ace44e40", + "metadata": {}, + "outputs": [], + "source": [ + "pd.DataFrame(cm, index=labels, columns=labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79fb6e1d-2b62-43d1-a158-e350f0b86145", "metadata": {}, "outputs": [], "source": [ + "actual = [\"dog\", \"dog\", \"dog\", \"dog\", \"cat\", \"cat\", \"cat\", \"cat\", \"mouse\", \"mouse\"] * 199\n", + "predicted = [\"dog\", \"dog\", \"dog\", \"dog\", \"cat\", \"dog\", \"cat\", \"dog\", \"mouse\", \"mouse\"] * 199\n", + "actual += [\"horse\"] * 10\n", + "predicted += [\"dog\"] * 9 + [\"horse\"]\n", + "\n", "labels = [\"dog\", \"cat\", \"mouse\", \"horse\"]\n", "cm = confusion_matrix(actual, predicted, labels=labels)\n", + "cm = pd.DataFrame(cm, index=labels, columns=labels)\n", + "cm" + ] + }, + { + "cell_type": "markdown", + "id": "c1eb06bb-23cb-4cce-a0e0-ffa586d0bc93", + "metadata": {}, + "source": [ + "### Recall and balanced accuracy score\n", + "\n", + "- import statement:\n", + "```python\n", + "from sklearn.metrics import recall_score, precision_score, balanced_accuracy_score\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "4f1b1be4-e780-4011-89b8-85531cdb5314", + "metadata": {}, + "source": [ + "#### Recall: row-wise ratio\n", + "\n", + "- What proportion of actual positives was identified correctly?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93d57a38-9114-4b09-8f11-b1718e60534a", + "metadata": {}, + "outputs": [], + "source": [ + "# what is the recall for cat?\n", + "cm.at[\"cat\", \"cat\"] / cm.loc[\"cat\", :].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca2dff07-6725-4ec0-a791-861ad3b58a7c", + "metadata": {}, + "outputs": [], + "source": [ + "# There are multiple recall scores as we have multiple labels\n", + "# So, we need to pass argument to parameter \"average\"\n", + "recall_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b16930a8-1a1c-416d-8fe5-04cd45f79749", + "metadata": {}, + "outputs": [], + "source": [ + "recall_score(actual, predicted, average=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70eabe61-5dce-4558-a484-de5d32e6906e", + "metadata": {}, + "outputs": [], + "source": [ + "# How can we identify which score is for which label?\n", + "print(labels)\n", + "# We can pass a list of labels argument to parameter \"labels\"\n", + "recall_score(actual, predicted, average=None, labels=labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "394edab6-08ae-411e-bf63-b01c0a12684f", + "metadata": {}, + "outputs": [], + "source": [ "cm" ] }, { "cell_type": "code", "execution_count": null, - "id": "ace44e40", + "id": "971d96ba-d046-4843-93dd-d173804808dd", "metadata": {}, "outputs": [], "source": [ - "pd.DataFrame(cm, index=labels, columns=labels)" + "# How does average recall score compare against accuracy score?" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "50d8b289-ca1e-454c-b2da-c8c06522fc59", + "metadata": {}, + "outputs": [], + "source": [ + "accuracy_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d73672b-9f19-43cd-baad-447bca2bbe24", + "metadata": {}, + "outputs": [], + "source": [ + "# will this be bigger or smaller than accuracy?\n", + "recall_score(actual, predicted, average=None, labels=labels).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "bf98eeab-8cc2-4ab3-a7a1-cb3e96e4c96b", + "metadata": {}, + "source": [ + "We are saying \"horse\" is equally important as other animals by taking an average of recall. So average recall score is lower than overall accuracy." + ] + }, + { + "cell_type": "markdown", + "id": "926d4d94-7bc3-4614-b582-fc7333e68c40", + "metadata": {}, + "source": [ + "#### Average recall score is \"Balanced accuracy score\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8956e72b-694e-411b-ab6d-b8b2fda0ea0b", + "metadata": {}, + "outputs": [], + "source": [ + "# average of recall scores is called balanced accuracy score\n", + "balanced_accuracy_score(actual, predicted)" + ] + }, + { + "cell_type": "markdown", + "id": "cce370f6-5ba8-4ee2-983b-c511585c8de2", + "metadata": {}, + "source": [ + "Why does \"Balanced accuracy score\" matter?\n", + "\n", + "- Imagine you are building a new covid test. We know that the majority of test results should be negative.\n", + "- So, you might get a high accuracy even if your covid test is missing actual positives.\n", + "- If you just look at accuracy, that might be misleading. So, you must also look at \"Balanced accuracy score\"."
+ ] + }, + { + "cell_type": "markdown", + "id": "be887a73-1731-41ba-b17f-5cc5e3eb7a81", + "metadata": {}, + "source": [ + "### Precision: column-wise ratio\n", + "\n", + "- What proportion of positive identifications was actually correct?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5206ca85-48ee-44e3-8dc1-b6e028f8aa0a", + "metadata": {}, + "outputs": [], + "source": [ + "cm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c03db37-226d-43f8-a0cf-50906bd66037", + "metadata": {}, + "outputs": [], + "source": [ + "print(labels)\n", + "precision_score(actual, predicted, average=None, labels=labels)" + ] + }, + { + "cell_type": "markdown", + "id": "b2cda510-211b-49b4-8e7b-c7216c3c5f53", + "metadata": {}, + "source": [ + "### Binary Classification Metrics\n", + "\n", + "Unless otherwise specified, \"precision\" and \"recall\" refer to those metrics for the positive class when we're doing binary classification.\n", + "\n", + "<div>\n", + "<img src=\"Confusion_matrix_binary.png\" width=\"350\"/>\n", + "</div>" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7b8595b-affa-4950-a59e-f6b607f08bb0", + "metadata": {}, + "outputs": [], + "source": [ + "actual = [False, True, True, True, True, False, False, True, True, True, True, True, True]\n", + "predicted = [False, True, True, True, True, True, True, False, False, False, True, True, True]\n", + "confusion_matrix(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "937b1197-26d4-4898-a75d-116ef5bfd1f6", + "metadata": {}, + "outputs": [], + "source": [ + "recall_score(actual, predicted, average=None, labels=[True, False])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88bbbf3c-aa9b-4ae9-bef2-74f532036057", + "metadata": {}, + "outputs": [], + "source": [ + "# for binary classification, we have False recall and True recall\n", + "# \"recall\" is shorthand for \"True recall\"\n", + "recall_score(actual, predicted)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57221111-77d7-48ee-938a-c596be1fbec7", + "metadata": {}, + "outputs": [], + "source": [ + "precision_score(actual, predicted, average=None, labels=[True, False])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bb37d4a-9374-4d23-be63-ca27a83ee301", + "metadata": {}, + "outputs": [], + "source": [ + "# for binary classification, we have False precision and True precision\n", + "# \"precision\" is shorthand for \"True precision\"\n", + "precision_score(actual, predicted)" ] } ],