kmeans.py 4.92 KB
Newer Older
sim-baz's avatar
sim-baz committed
1
2
3
from cassandra.cluster import Cluster
from datetime import datetime

4
5
6
from sklearn.cluster import KMeans
import numpy as np

sim-baz's avatar
sim-baz committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import loading as l
import history as h


def getDatasForPeriod(startPeriod, endPeriod, indicators):
	"""Fetch the rows of every year touched by the period.

	One query is issued per calendar year, from the year of ``startPeriod``
	to the year of ``endPeriod`` inclusive, through the module-level
	``session``. Rows are NOT filtered by month/day here, so the result may
	contain dates outside the exact period.

	startPeriod, endPeriod: strings whose first four characters are the year.
	indicators: comma-separated column names appended to the SELECT list.

	Returns a list of rows laid out as (year, month, day, station, <indicators...>).
	"""
	first_year = int(startPeriod[0:4])
	last_year = int(endPeriod[0:4])

	rows = []
	for i in range(first_year, last_year + 1):
		query = f"SELECT year, month, day, station, {indicators} FROM {l.table_name_date} where year = {i}"
		rows.extend(session.execute(query))

	return rows

def verifyDateInPeriod(startPeriod, endPeriod, year, month, day):
	"""Tell whether the date year-month-day lies within the period.

	startPeriod, endPeriod: "YYYY-MM-DD" strings; both bounds are inclusive.
	year, month, day: strings forming the date to test.

	Returns True when startPeriod <= date <= endPeriod, False otherwise.
	"""
	fmt = "%Y-%m-%d"
	current = datetime.strptime("-".join((year, month, day)), fmt)
	lower = datetime.strptime(startPeriod, fmt)
	upper = datetime.strptime(endPeriod, fmt)

	return lower <= current <= upper

def getDecileForAllStations(startPeriod, endPeriod, table, nb_indicators, indicators_list):
	"""Compute, for every station, the deciles of each indicator over the period.

	startPeriod, endPeriod: "YYYY-MM-DD" strings, both bounds inclusive.
	table: iterable of rows indexed as
		(year, month, day, station, indicator_0, ..., indicator_{n-1}).
	nb_indicators: number of indicator columns following the station column.
	indicators_list: names of those indicators, in column order.

	Returns a dict mapping each station to a list (one entry per indicator, in
	column order) of single-key dicts {indicator_name: [11 values]}; the 11
	values are the deciles 0..10, i.e. they include the min and the max.
	Example for 2 stations with 2 indicators:
	{'EFKI': [{'tmpf': [-23.8, ..., 91.4]}, {'dwpf': [-31.0, ..., 69.8]}],
	 'EFHA': [{'tmpf': [-23.8, ..., 91.4]}, {'dwpf': [-31.0, ..., 69.8]}]}

	Rows outside the period are ignored, as are None indicator values.
	A station for which at least one indicator has no value at all in the
	period is left out of the result, since its deciles cannot be computed.
	"""
	# The period bounds are the same for every row: parse them only once.
	dateStart = datetime.strptime(startPeriod, "%Y-%m-%d")
	dateEnd = datetime.strptime(endPeriod, "%Y-%m-%d")

	# station -> one list of values per indicator, position i <-> indicators_list[i]
	values_by_station = {}
	for row in table:
		date = datetime.strptime(f"{row[0]}-{row[1]}-{row[2]}", "%Y-%m-%d")
		if date < dateStart or date > dateEnd:
			continue

		station_values = values_by_station.get(row[3])
		if station_values is None:
			# Create a list for EVERY indicator up front, so positions stay
			# aligned even when the first row of a station has missing values.
			station_values = [[] for _ in range(nb_indicators)]
			values_by_station[row[3]] = station_values

		for i in range(nb_indicators):
			value = row[4 + i]
			if value is not None:
				station_values[i].append(float(value))

	deciles = {}
	for station, station_values in values_by_station.items():
		# An empty list means an indicator was never measured for this station.
		if not all(station_values):
			continue

		station_deciles = []
		for name, values in zip(indicators_list, station_values):
			values.sort()
			n = len(values)
			# Multiply before dividing: n // 10 * d would truncate first and
			# squeeze deciles 1..9 into the lowest values when n is not a
			# multiple of 10.
			cuts = [values[n * d // 10] for d in range(10)]
			# The 10th decile is the maximum.
			cuts.append(values[-1])
			station_deciles.append({name: cuts})
		deciles[station] = station_deciles

	return deciles

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def applyKmeans(deciles, nb_indicators, indicators_list, startPeriod, endPeriod, nb_clusters = 3):
	"""Cluster the stations with k-means, using their deciles as features.

	deciles: dict station -> list of {indicator_name: [decile values]}, one
		entry per indicator, in the same order as indicators_list.
	nb_indicators: number of indicators to read for each station.
	indicators_list: indicator names, in order.
	startPeriod, endPeriod: only used in the message printed when there are
		too few stations.
	nb_clusters: number of clusters to build (defaults to 3).

	Returns a dict mapping each station to its cluster label, or None when
	there are fewer stations than clusters.
	"""
	stations_name = list(deciles)

	# k-means needs at least as many samples as clusters.
	if len(stations_name) < nb_clusters:
		print(f"Le nombre de villes ayant des données est trop inférieur ({len(stations_name)}) pour appliquer les kmeans pour la période du {startPeriod} au {endPeriod}")
		return None

	# Flatten the per-indicator maps into one feature vector per station:
	# the decile lists of all indicators, concatenated in indicator order.
	table = []
	for station in stations_name:
		features = []
		for i in range(nb_indicators):
			features += deciles[station][i][indicators_list[i]]
		table.append(features)

	kmeans = KMeans(n_clusters = nb_clusters, max_iter = 100).fit(table)

	# labels_ is ordered like the rows of table, hence like stations_name.
	return dict(zip(stations_name, kmeans.labels_))

sim-baz's avatar
sim-baz committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

def kmeans(startPeriod, endPeriod, indicators_list):
	"""Cluster the stations over a period and print the resulting assignment.

	startPeriod, endPeriod: "YYYY-MM-DD" strings; both must lie between
		l.FIRST_DAY and l.LAST_DAY, and their years must be accepted by
		h.verifyYearValidity.
	indicators_list: names of the indicators to use; only those present in
		l.numeric_columns are kept.

	Prints the station -> cluster mapping. Returns None in every case; when
	the inputs are rejected, or no numeric indicator remains, nothing is
	queried.
	"""
	startDate = datetime.strptime(startPeriod, "%Y-%m-%d")
	endDate = datetime.strptime(endPeriod, "%Y-%m-%d")

	firstDate = datetime.strptime(l.FIRST_DAY, "%Y-%m-%d")
	lastDate = datetime.strptime(l.LAST_DAY, "%Y-%m-%d")

	if startDate < firstDate or startDate > lastDate or endDate < firstDate or endDate > lastDate:
		print(f"Les dates doivent être comprises entre {l.FIRST_DAY} et {l.LAST_DAY}")
		return

	if not h.verifyYearValidity(int(startPeriod[0:4]), int(endPeriod[0:4])):
		return

	# Keep only the numeric indicators, preserving the requested order.
	indicators_list_numeric = [ind for ind in indicators_list if ind in l.numeric_columns]
	nb_indicators = len(indicators_list_numeric)

	# Without any indicator the SELECT list would end with a dangling comma
	# (invalid query) and there would be nothing to cluster on.
	if nb_indicators == 0:
		print("Aucun indicateur numérique parmi ceux demandés, impossible d'appliquer les kmeans")
		return

	# Column list to append to the SELECT clause
	indicators = ",".join(indicators_list_numeric)

	table = list(getDatasForPeriod(startPeriod, endPeriod, indicators))

	# Get the map with all deciles for all stations and indicators
	table_deciles = getDecileForAllStations(startPeriod, endPeriod, table, nb_indicators, indicators_list_numeric)

	station_with_center = applyKmeans(table_deciles, nb_indicators, indicators_list_numeric, startPeriod, endPeriod)
	if station_with_center is not None:
		print("Voici les villes et le cluster auxquelles elles appartiennent:")
		print(f"{station_with_center}")
sim-baz's avatar
sim-baz committed
134
135
136
137
138
139
140

if __name__ == '__main__':
	# Script entry point: connect to Cassandra, run one clustering, then
	# release the connection.
	cluster = Cluster()
	try:
		# session must stay a module-level name: getDatasForPeriod reads it
		# as a global.
		session = cluster.connect()
		session.set_keyspace("bazinsim_roisinos_metar")

		print()
		# Alternative call: kmeans("2001-01-01", "2010-12-31", ["tmpf", "skyc1"])
		kmeans("2001-01-01", "2010-12-31", ["tmpf", "dwpf", "skyc1"])
		print()
	finally:
		# Close the driver's connections even if the clustering fails.
		cluster.shutdown()