kmeans.py 8.15 KB
Newer Older
sim-baz's avatar
sim-baz committed
1 2 3
from cassandra.cluster import Cluster
from datetime import datetime

4 5
from sklearn.cluster import KMeans
import numpy as np
6
import folium
7

sim-baz's avatar
sim-baz committed
8 9 10
import loading as l
import history as h

11
colours = ['blue', 'red', 'green', 'orange', 'pink', 'white', 'purple', 'gray']
sim-baz's avatar
sim-baz committed
12

sim-baz's avatar
sim-baz committed
13 14 15 16 17 18 19 20 21
'''
name: getDatasForPeriod
description: Query the database and get the values for a period for indicators
parameters:
    * startPeriod: beginning date of period to select
    * endPeriod: ending date of period to select
    * indicators: list of indicators to select
return: result of the query
'''
sim-baz's avatar
sim-baz committed
22
def getDatasForPeriod(startPeriod, endPeriod, indicators):
    '''
    Query the database, one request per year, for every year touched by the period.

    The query filters on the year only: rows of the first and last year that
    fall outside [startPeriod, endPeriod] are returned too.

    parameters:
        * startPeriod: beginning date of period, its first 4 characters are the year
        * endPeriod: ending date of period, its first 4 characters are the year
        * indicators: comma-separated column names to select (may be empty)
    return: list of rows (year, month, day, station, lat, lon, then the indicators)

    Relies on the module-level `session` created by the script entry point and
    on `l.table_name_date` from the loading module.
    '''
    # Only append the indicator columns when there are some: an empty string
    # would otherwise leave a dangling comma and make the query invalid.
    columns = "year, month, day, station, lat, lon"
    if indicators:
        columns += f", {indicators}"

    datas = []
    for year in range(int(startPeriod[0:4]), int(endPeriod[0:4]) + 1):
        query = f"SELECT {columns} FROM {l.table_name_date} where year = {year}"
        datas.extend(session.execute(query))

    return datas
sim-baz's avatar
sim-baz committed
28

sim-baz's avatar
sim-baz committed
29 30 31 32 33 34 35 36 37 38 39
'''
name: verifyDateInPeriod
description: Verify that the date given is within the period of study
parameters:
    * startPeriod: beginning date of period to select
    * endPeriod: ending date of period to select
    * year: year given by user
    * month: month given by user
    * day: day given by user
return: boolean indicating the validity
'''
sim-baz's avatar
sim-baz committed
40
def verifyDateInPeriod(startPeriod, endPeriod, year, month, day):
    '''
    Tell whether the date year-month-day lies within [startPeriod, endPeriod].

    parameters:
        * startPeriod: beginning date of period, formatted "%Y-%m-%d"
        * endPeriod: ending date of period, formatted "%Y-%m-%d"
        * year: year as a string
        * month: month as a string
        * day: day as a string
    return: True when the date is a real calendar date inside the period
            (both bounds included), False otherwise
    '''
    if not (year.isdigit() and month.isdigit() and day.isdigit()):
        return False

    try:
        date = datetime.strptime(f"{year}-{month}-{day}", "%Y-%m-%d")
    except ValueError:
        # Digits only, but not an existing date (e.g. month 13 or February 30)
        return False

    dateStart = datetime.strptime(startPeriod, "%Y-%m-%d")
    dateEnd = datetime.strptime(endPeriod, "%Y-%m-%d")
    return dateStart <= date <= dateEnd
sim-baz's avatar
sim-baz committed
49

sim-baz's avatar
sim-baz committed
50 51 52 53 54 55 56 57 58 59 60
'''
name: getDecileForAllStations
description: Compute the decile of a list for different stations and indicators
parameters:
    * startPeriod: beginning date of period to select
    * endPeriod: ending date of period to select
    * table: list of lists with all values
    * nb_indicators: number of indicators to compute
    * indicators_list: list of names of indicators to compute
return: a dictionary with lists of dictionaries of lists containing the deciles for indicators for stations
'''
sim-baz's avatar
sim-baz committed
61
def getDecileForAllStations(startPeriod, endPeriod, table, nb_indicators, indicators_list):
    '''
    Compute, for every station, the deciles (0 to 10, i.e. min and max included)
    of each indicator over the rows whose date lies within the period.

    parameters:
        * startPeriod: beginning date of period to select
        * endPeriod: ending date of period to select
        * table: rows laid out as (year, month, day, station, lat, lon, indicator values...)
        * nb_indicators: number of indicator columns following the 6 fixed columns
        * indicators_list: names of the indicators, in column order
    return: a dictionary mapping each station to a list with one dictionary per
            indicator, {indicator name: [11 decile values]}
            example for 1 station with 2 indicators:
            {'EFKI': [{'tmpf': [-23.8, ..., 91.4]}, {'dwpf': [-31.0, ..., 69.8]}]}

    Stations with no value at all for at least one indicator are left out,
    because no decile can be computed for them.
    '''
    # station -> one list of values per indicator, in the order of indicators_list.
    # Every slot is created up front so a missing (None) value in the first row
    # of a station cannot shift the positions of the other indicators.
    values_by_station = {}
    for row in table:
        if not verifyDateInPeriod(startPeriod, endPeriod, str(row[0]), str(row[1]), str(row[2])):
            continue
        station = row[3]
        if station not in values_by_station:
            values_by_station[station] = [[] for _ in range(nb_indicators)]
        per_indicator = values_by_station[station]
        for i in range(nb_indicators):
            value = row[6 + i]
            if value is not None:
                per_indicator[i].append(float(value))

    deciles = {}
    for station, per_indicator in values_by_station.items():
        # An empty list means the station has no data for that indicator
        if not all(per_indicator):
            continue
        station_deciles = []
        for i, values in enumerate(per_indicator):
            values.sort()
            last = len(values) - 1
            # Spread the 11 picks evenly over the sorted values: d = 0 is the
            # minimum, d = 10 the maximum, also with fewer than 10 values.
            station_deciles.append({indicators_list[i]: [values[last * d // 10] for d in range(11)]})
        deciles[station] = station_deciles

    return deciles
sim-baz's avatar
sim-baz committed
99

sim-baz's avatar
sim-baz committed
100 101 102 103 104 105 106 107 108 109 110
'''
name: applyKmeans
description: Apply k-means algorithm to clusterize space
parameters:
    * deciles: a dictionary with lists of dictionaries of lists containing the deciles for indicators for stations
    * nb_indicators: number of indicators in the deciles
    * indicators_list: list of names of indicators in the deciles
    * startPeriod: beginning date of period to select
    * endPeriod: ending date of period to select
return: a dictionary with the station and its associated cluster
'''
111
def applyKmeans(deciles, nb_indicators, indicators_list, startPeriod, endPeriod, nb_clusters=4):
    '''
    Apply the k-means algorithm to the stations, each one described by the
    concatenation of the deciles of all its indicators.

    parameters:
        * deciles: dictionary mapping a station to a list with one dictionary
          per indicator, {indicator name: [decile values]}
        * nb_indicators: number of indicators in the deciles
        * indicators_list: list of names of indicators in the deciles
        * startPeriod: beginning date of period (only used in the message)
        * endPeriod: ending date of period (only used in the message)
        * nb_clusters: number of clusters to build (4 by default)
    return: a dictionary with the station and its associated cluster label,
            or None when there are fewer stations than clusters
    '''
    stations_name = list(deciles)

    # k-means cannot build more clusters than there are samples
    if len(stations_name) < nb_clusters:
        print(f"Le nombre de villes ayant des données est trop inférieur ({len(stations_name)}) pour appliquer les kmeans pour la période du {startPeriod} au {endPeriod}")
        return None

    # One flat feature vector per station, in the same order as stations_name
    table = []
    for station in stations_name:
        features = []
        for i in range(nb_indicators):
            features += deciles[station][i][indicators_list[i]]
        table.append(features)

    model = KMeans(n_clusters=nb_clusters, max_iter=100).fit(table)

    # labels_ follows the order of the fitted rows, hence of stations_name
    return dict(zip(stations_name, model.labels_))
138

sim-baz's avatar
sim-baz committed
139

sim-baz's avatar
sim-baz committed
140 141 142 143 144 145 146 147
'''
name: kmeans
description: Clusterize space for a period depending on deciles and create a map of the country
parameters:
    * startPeriod: beginning date of period to select
    * endPeriod: ending date of period to select
    * indicators_list: list of names of indicators to take in account
'''
sim-baz's avatar
sim-baz committed
148
def kmeans(startPeriod, endPeriod, indicators_list):
    '''
    Clusterize the stations for a period from the deciles of the numeric
    indicators, and save a folium map with one coloured marker per station.

    parameters:
        * startPeriod: beginning date of period, formatted "%Y-%m-%d"
        * endPeriod: ending date of period, formatted "%Y-%m-%d"
        * indicators_list: names of indicators to take in account; those not
          listed in l.numeric_columns are ignored
    return: None; the outcome is printed and, on success, the map is written to
            map_kmeans_<startPeriod>_to_<endPeriod>.html
    '''
    date_format = "%Y-%m-%d"
    startDate = datetime.strptime(startPeriod, date_format)
    endDate = datetime.strptime(endPeriod, date_format)
    firstDate = datetime.strptime(l.FIRST_DAY, date_format)
    lastDate = datetime.strptime(l.LAST_DAY, date_format)

    if not (firstDate <= startDate <= lastDate and firstDate <= endDate <= lastDate):
        print(f"Les dates doivent être comprises entre {l.FIRST_DAY} et {l.LAST_DAY}")
        return

    if not h.verifyYearValidity(int(startPeriod[0:4]), int(endPeriod[0:4])):
        return

    # Keep the numeric indicators only, in the order given by the caller
    indicators_list_numeric = [ind for ind in indicators_list if ind in l.numeric_columns]
    nb_indicators = len(indicators_list_numeric)
    if nb_indicators == 0:
        # Nothing to cluster on, and the query would have an empty column list
        print("Aucun indicateur numérique sélectionné")
        return
    indicators = ",".join(indicators_list_numeric)

    table = getDatasForPeriod(startPeriod, endPeriod, indicators)

    # Coordinates (lat, lon) of each station, taken from the rows
    coord = {row[3]: (row[4], row[5]) for row in table}

    # Get the map with all deciles for all stations and indicators
    table_deciles = getDecileForAllStations(startPeriod, endPeriod, table, nb_indicators, indicators_list_numeric)

    station_with_center = applyKmeans(table_deciles, nb_indicators, indicators_list_numeric, startPeriod, endPeriod)
    if station_with_center is None:
        print("Aucune clusterisation déterminée")
        return

    file_name = f"map_kmeans_{startPeriod}_to_{endPeriod}.html"
    # Create map
    m = folium.Map(location=[64.2815, 27.6753], zoom_start=5)
    # Add Marker for each station, coloured by its cluster
    for station, cluster_id in station_with_center.items():
        lat, lon = coord[station]
        folium.Marker([lat, lon], popup=f"<b>{station}</b>", icon=folium.Icon(color=colours[cluster_id])).add_to(m)
    # Save map
    m.save(file_name)
    print(f"La carte a été enregistrée à {file_name}")
sim-baz's avatar
sim-baz committed
199 200

if __name__ == '__main__':
    # `session` is assigned at module level on purpose: getDatasForPeriod reads it
    # as a global.
    cluster = Cluster()
    try:
        session = cluster.connect()
        session.set_keyspace("bazinsim_roisinos_metar")

        print()
        # kmeans("2001-01-01", "2010-12-31", ["tmpf", "skyc1"])
        kmeans("2001-01-01", "2010-12-31", ["tmpf", "dwpf", "skyc1"])
        print()
    finally:
        # Close the driver's connections even when the run fails
        cluster.shutdown()