Commit 8ad19b40bfbd2af7cb21b714260abe1241a399e2

Authored by Louis Baetens
1 parent 4a0ff7fe
Exists in master

SAVE

ALCD/all_run_alcd.py
... ... @@ -8,6 +8,7 @@ import json
8 8 import shutil
9 9 import json
10 10 import argparse
  11 +import tempfile
11 12  
12 13 import OTB_workflow
13 14 import masks_preprocessing
... ... @@ -104,7 +105,7 @@ def invitation_to_copy(global_parameters, first_iteration = False):
104 105 print('scp -r {source} {server}{destination}'.format(server = current_server, source = source_dir, destination = dest_dir))
105 106  
106 107  
107   -def run_all(part, first_iteration = False, location=None, wanted_date=None, clear_date=None):
  108 +def run_all(part, first_iteration = False, location=None, wanted_date=None, clear_date=None, k_fold_step=None, k_fold_dir=None):
108 109 if part == 1:
109 110 # Define the main parameters for the algorithm
110 111 # If all is filled, will update the JSON file
... ... @@ -164,8 +165,11 @@ def run_all(part, first_iteration = False, location=None, wanted_date=None, clea
164 165  
165 166 elif part == 2:
166 167 # Merge the layers, split in training and validation data, augment the data
167   - masks_preprocessing.masks_preprocess(global_parameters)
168   -
  168 + if k_fold_step is None or k_fold_dir is None:
  169 + masks_preprocessing.masks_preprocess(global_parameters)
  170 + else:
  171 + masks_preprocessing.masks_preprocess(global_parameters, k_fold_step, k_fold_dir)
  172 +
169 173 elif part == 3:
170 174 # Compute the statistics of the image and samples, and extract the later
171 175 if first_iteration == True:
... ... @@ -225,6 +229,7 @@ def main():
225 229 parser.add_argument('-d', action='store', default=None, dest='wanted_date', help='Date, The desired date to process (e.g. 20170702)')
226 230 parser.add_argument('-c', action='store', default=None, dest='clear_date', help='Date, The nearest clear date (e.g. 20170704)')
227 231 parser.add_argument('-dates', action='store', default='false', dest='get_dates', help='Bool, Get the available dates')
  232 + parser.add_argument('-kfold', action='store', default='false', dest='kfold', help='Bool, Do a K-fold cross-validation')
228 233  
229 234 results = parser.parse_args()
230 235 location = results.location
... ... @@ -236,6 +241,8 @@ def main():
236 241 print([str(d) for d in available_dates])
237 242 return
238 243  
  244 +
  245 +
239 246 if results.first_iteration == None:
240 247 print('Please enter a boolean for the first iteration')
241 248 return
... ... @@ -253,6 +260,29 @@ def main():
253 260 clear_date = results.clear_date
254 261  
255 262  
  263 +
  264 + kfold = str2bool(results.kfold)
  265 + if kfold:
  266 + tmp_name = next(tempfile._get_candidate_names())  # note: private tempfile API
  267 + k_fold_dir = op.join('tmp', 'kfold_{}'.format(tmp_name))
  268 + if not op.exists(k_fold_dir):
  269 + os.makedirs(k_fold_dir)
  270 + print(k_fold_dir + ' created')
  271 +
  272 +
  273 + #~ run_all(part, first_iteration = False, location=None, wanted_date=None, clear_date=None, k_fold_step=None, k_fold_dir=None)
  274 + global_parameters = json.load(open(op.join('parameters_files','global_parameters.json')))
  275 + K = int(global_parameters["training_parameters"]["Kfold"])
  276 +
  277 + for k_fold_step in range(K):
  278 + run_all(part = 2, first_iteration = first_iteration, k_fold_step=k_fold_step, k_fold_dir=k_fold_dir)
  279 + run_all(part = 3, first_iteration = first_iteration, k_fold_step=k_fold_step, k_fold_dir=k_fold_dir)
  280 + run_all(part = 4, first_iteration = first_iteration, k_fold_step=k_fold_step, k_fold_dir=k_fold_dir)
  281 + run_all(part = 5, first_iteration = first_iteration, k_fold_step=k_fold_step, k_fold_dir=k_fold_dir)
  282 + #~ run_all(part = 6, first_iteration = first_iteration, k_fold_step=k_fold_step, k_fold_dir=k_fold_dir)
  283 +
  284 + return
  285 +
256 286  
257 287 if user_input == 0:
258 288 run_all(part = 1, first_iteration = first_iteration, location=location, wanted_date=wanted_date, clear_date=clear_date)
... ...
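
Editor's note: the -kfold branch above derives its working directory from tempfile._get_candidate_names(), a private CPython helper rather than a stable API. A minimal sketch of the same setup built on the public tempfile.mkdtemp (the make_kfold_dir helper name is illustrative, not part of the commit):

    import os
    import os.path as op
    import tempfile

    def make_kfold_dir(base='tmp'):
        """Create a unique working directory for one K-fold run."""
        if not op.exists(base):
            os.makedirs(base)
        # mkdtemp picks a unique name and creates the directory atomically,
        # so there is no window between choosing the name and makedirs
        k_fold_dir = tempfile.mkdtemp(prefix='kfold_', dir=base)
        print(k_fold_dir + ' created')
        return k_fold_dir
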
ALCD/layers_creation.py
... ... @@ -11,6 +11,7 @@ import json
11 11 import otbApplication
12 12 import L1C_band_composition
13 13 import subprocess
  14 +import tempfile
14 15  
15 16 def empty_shapefile_creation(in_tif, out_shp_list, geometry_type = 'point'):
16 17 '''
... ...
ALCD/masks_preprocessing.py
1 1 #!/usr/bin/python
2 2 # -*- coding: utf-8 -*-
3 3  
  4 +import os
  5 +import shutil
4 6 import os.path as op
5 7 from osgeo import ogr, gdal
6 8  
... ... @@ -8,8 +10,9 @@ from osgeo import ogr, gdal
8 10 import expand_point_region
9 11 import split_samples
10 12 import merge_shapefiles
  13 +import glob
11 14  
12   -def split_and_augment(global_parameters):
  15 +def split_and_augment(global_parameters, k_fold_step=None, k_fold_dir=None):
13 16 '''
14 17 Split the 'merged.shp' file in two dataset
15 18 Then augment the data
... ... @@ -21,12 +24,22 @@ def split_and_augment(global_parameters):
21 24 validation_shp_extended = op.join(main_dir, 'Intermediate', global_parameters["general"]["validation_shp_extended"])
22 25 training_shp_extended = op.join(main_dir, 'Intermediate', global_parameters["general"]["training_shp_extended"])
23 26  
24   - # training proportion
25   - proportion = float(global_parameters["training_parameters"]["training_proportion"])
26   -
27   - # split into 2 datasets
28   - split_samples.split_points_sample(in_shp = merged_shp, train_shp = training_shp, validation_shp = validation_shp, proportion = proportion)
29   -
  27 + if k_fold_step is not None and k_fold_dir is not None:
  28 + # if not done before, create the split
  29 + if k_fold_step == 0:
  30 + K = global_parameters["training_parameters"]["Kfold"]
  31 + split_samples.k_split(merged_shp, k_fold_dir, K)
  32 +
  33 + # copy directly the k fold
  34 + load_kfold(training_shp, validation_shp, k_fold_step, k_fold_dir)
  35 +
  36 + else:
  37 + # training proportion
  38 + proportion = float(global_parameters["training_parameters"]["training_proportion"])
  39 +
  40 + # split into 2 datasets
  41 + split_samples.split_points_sample(in_shp = merged_shp, train_shp = training_shp, validation_shp = validation_shp, proportion = proportion)
  42 +
30 43 # set the distance of the zone around each point
31 44 max_dist_X = float(global_parameters["training_parameters"]["expansion_distance"])
32 45 max_dist_Y = float(global_parameters["training_parameters"]["expansion_distance"])
... ... @@ -36,6 +49,33 @@ def split_and_augment(global_parameters):
36 49 expand_point_region.create_squares(validation_shp, validation_shp_extended, max_dist_X, max_dist_Y)
37 50  
38 51  
  52 +def load_kfold(train_shp, validation_shp, k_fold_step, k_fold_dir):
  53 + '''
  54 + Copy the k-th fold's train and validation shapefiles to the
  55 + default train and validation paths used by the rest of the pipeline
  56 + '''
  57 + validation_files = glob.glob(op.join(k_fold_dir, 'validation*_k_*{}*'.format(k_fold_step)))
  58 + train_files = glob.glob(op.join(k_fold_dir, 'train*_k_*{}*'.format(k_fold_step)))
  59 +
  60 + # Shapefiles are accompanied by sidecar files (prj, dbf, shx)
  61 + # that have to be copied as well,
  62 + # so we go through all the matching names and copy each one
  63 + for valid_f in validation_files:
  64 + # get the extension of the file
  65 + _, extension = op.splitext(valid_f)
  66 + dst_basename, _ = op.splitext(validation_shp)
  67 + dst = dst_basename + extension
  68 + shutil.copy(valid_f, dst)
  69 +
  70 + for train_f in train_files:
  71 + # get the extension of the file
  72 + _, extension = op.splitext(train_f)
  73 + dst_basename, _ = op.splitext(train_shp)
  74 + dst = dst_basename + extension
  75 + shutil.copy(train_f, dst)
  76 +
  77 +
  78 +
39 79 def rasterize_shp(input_shp, out_tif, reference_tif):
40 80 '''
41 81 From a shapefile, rasterize it
... ... @@ -68,7 +108,7 @@ def rasterize_shp(input_shp, out_tif, reference_tif):
68 108 image = None
69 109 shapefile = None
70 110  
71   -def masks_preprocess(global_parameters):
  111 +def masks_preprocess(global_parameters, k_fold_step=None, k_fold_dir=None):
72 112 '''
73 113 Global preprocessing of the masks
74 114 '''
... ... @@ -88,9 +128,11 @@ def masks_preprocess(global_parameters):
88 128 merge_shapefiles.merge_shapefiles(in_shp_list = layers_to_merge, class_list = layers_classes, out_shp = merged_layers)
89 129 print('Done')
90 130  
91   -
92   - print(' Split into two datasets and augment the data')
93   - split_and_augment(global_parameters)
  131 + if k_fold_step is not None and k_fold_dir is not None:
  132 + print(' Copy the dataset for fold {} and augment the data'.format(k_fold_step))
  133 + else:
  134 + print(' Split into two datasets and augment the data')
  135 + split_and_augment(global_parameters, k_fold_step=k_fold_step, k_fold_dir=k_fold_dir)
94 136 print('Done')
95 137  
96 138 print(' Transform the no-data mask to raster')
... ... @@ -104,6 +146,15 @@ def masks_preprocess(global_parameters):
104 146  
105 147  
106 148 def main():
  149 + kfold_out = '/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Arles_31TFJ_20171002/kfold_out'
  150 + train_shp = op.join(kfold_out, 'train_test.shp')
  151 + validation_shp = op.join(kfold_out, 'validation_test.shp')
  152 + k_step = 2
  153 + k_fold_dir = '/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Arles_31TFJ_20171002/kfold'
  154 + load_kfold(train_shp, validation_shp, k_step, k_fold_dir)
  155 +
  156 + return  # debug: stop here so masks_preprocess() below is skipped
  157 +
107 158 masks_preprocess()
108 159 #~ split_and_augment()
109 160  
... ...
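
Editor's note: load_kfold has to copy each shapefile together with its sidecar files, matching them one extension at a time. A compact sketch of the same idea, assuming all of a fold's files share one basename (copy_shapefile is a hypothetical helper, not in the commit):

    import glob
    import os.path as op
    import shutil

    def copy_shapefile(src_shp, dst_shp):
        """Copy a .shp plus its sidecars (.dbf, .shx, .prj, ...)."""
        src_base, _ = op.splitext(src_shp)
        dst_base, _ = op.splitext(dst_shp)
        # every file sharing the basename belongs to the same shapefile
        for src in glob.glob(src_base + '.*'):
            _, ext = op.splitext(src)
            shutil.copy(src, dst_base + ext)
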
ALCD/metrics_exploitation.py
... ... @@ -246,19 +246,52 @@ def save_model_metrics(global_parameters):
246 246  
247 247 # copy the interesting files into it
248 248 files_of_interest = []
  249 + src_dirs = []
  250 + src_dirs.append(statistics_dir)
249 251 files_of_interest.append(global_parameters["postprocessing"]["confusion_matrix"])
  252 + src_dirs.append(statistics_dir)
250 253 files_of_interest.append(global_parameters["postprocessing"]["binary_confusion_matrix"])
  254 + src_dirs.append(statistics_dir)
251 255 files_of_interest.append(global_parameters["postprocessing"]["model_metrics"])
  256 + src_dirs.append(statistics_dir)
252 257 files_of_interest.append(global_parameters["general"]["class_stats"])
253   -
254   - for file_name in files_of_interest:
255   - src = op.join(statistics_dir, file_name)
  258 + src_dirs.append(statistics_dir)
  259 +
  260 + for n in range(len(files_of_interest)):
  261 + file_name = files_of_interest[n]
  262 + src_dir = src_dirs[n]
  263 + src = op.join(src_dir, file_name)
256 264 dst = op.join(K_fold_dir, file_name)
257 265 if op.exists(src):
258 266 shutil.copyfile(src, dst)
  267 +
  268 +
  269 + samples_dir = op.join(main_dir, 'Intermediate')
  270 + samples_files = []
  271 +
  272 + samples_files.append(global_parameters["general"]["validation_shp"])
  273 + samples_files.append(global_parameters["general"]["training_shp"])
  274 +
  275 + samples_files = [op.join(samples_dir, s) for s in samples_files]
  276 +
  277 + for sample_f in samples_files:
  278 + # copy every file that shares the shapefile basename,
  279 + # whatever its extension (shp, dbf, shx, prj)
  280 + src_basename, _ = op.splitext(sample_f)
  281 + all_src = glob.glob(src_basename + '*')
  282 +
  283 + for src in all_src:
  284 + if '_ext' not in op.basename(src):
  285 + dst = op.join(K_fold_dir, op.basename(src))
  286 + shutil.copy(src, dst)
  287 +
  288 +
  289 +
  290 +
  291 +
259 292  
260 293  
261   -def retrieve_Kfold_data(global_parameters, metrics_plotting = True, location = '', date = ''):
  294 +def retrieve_Kfold_data(global_parameters, metrics_plotting = False, location = '', date = ''):
262 295 '''
263 296 After having run the model K times, this function is used to do some
264 297 stats on all the runs
... ... @@ -267,8 +300,11 @@ def retrieve_Kfold_data(global_parameters, metrics_plotting = True, location = '
267 300 paths_configuration = json.load(open(op.join('..', 'paths_configuration.json')))
268 301 Data_ALCD_dir = paths_configuration["data_paths"]["data_alcd"]
269 302 main_dir = glob.glob(op.join(Data_ALCD_dir, '{}_*_{}'.format(location, date)))[0]
  303 +
270 304 else:
271 305 main_dir = global_parameters["user_choices"]["main_dir"]
  306 + location = global_parameters["user_choices"]["location"]
  307 + date = global_parameters["user_choices"]["current_date"]
272 308 statistics_dir = op.join(main_dir, 'Statistics')
273 309 model_metrics_csv_basename = op.join(global_parameters["postprocessing"]["model_metrics"])
274 310  
... ... @@ -289,29 +325,62 @@ def retrieve_Kfold_data(global_parameters, metrics_plotting = True, location = '
289 325 metrics_names.append(row[0])
290 326 current_metrics.append(float(row[1]))
291 327 all_metrics.append(current_metrics)
292   - #~ print('Stats computed on {}-fold cross validation'.format(len(all_metrics)))
293   - #~ print(metrics_names)
294   -
  328 + #~ print(all_metrics)
295 329 # Compute the mean and standard deviation for each metric
296 330 means = np.mean(all_metrics, axis = 0)
297 331 stds = np.std(all_metrics, axis = 0)
298   - #~ print('Means')
299   - #~ print(means)
300   - #~ print('Standard deviations')
301   - #~ print(stds)
302 332  
303   - #~ if metrics_plotting == True:
304   - #~ plot_metrics(all_metrics, metrics_names)
  333 + # save the stats
  334 + out_json = op.join(statistics_dir, 'k_fold_summary.json')
  335 + data = {}
  336 + data["means"] = list(means)
  337 + data["stds"] = list(stds)
  338 + data["metrics_names"] = metrics_names
  339 + data["all_metrics"] = all_metrics
  340 + data["K"] = len(all_metrics)
  341 +
  342 + jsonFile = open(out_json, "w+")
  343 + jsonFile.write(json.dumps(data, indent=3, sort_keys=True))
  344 + jsonFile.close()
  345 +
  346 +
  347 + if metrics_plotting:
  348 + indices = [0,1,2,3]
  349 + accuracies = [m[0] for m in all_metrics]
  350 + f1scores = [m[1] for m in all_metrics]
  351 + recalls = [m[2] for m in all_metrics]
  352 + precisions = [m[3] for m in all_metrics]
  353 +
  354 +
  355 + plt.errorbar(indices, means[0:4], stds[0:4], linestyle='',
  356 + marker='o', color = 'b')
  357 + met_nb = 0
  358 +
  359 + for metric in [accuracies, f1scores, recalls, precisions]:
  360 + rnd = [(indices[met_nb] - 0.1 + 0.2*(float(k)/len(accuracies))) for k in range(len(accuracies))]
  361 + plt.scatter(rnd, metric, color='k', marker='.', alpha = 0.2)
  362 + met_nb += 1
  363 +
  364 + plt.ylim(0.5,1)
  365 + metrics_names = ['Accuracy\n{:.1f}%'.format(means[0]*100),
  366 + 'F1-score\n{:.1f}%'.format(means[1]*100),
  367 + 'Recall\n{:.1f}%'.format(means[2]*100),
  368 + 'Precision\n{:.1f}%'.format(means[3]*100)]
  369 + plt.xticks(indices, metrics_names)
  370 +
  371 + nb_dates = float(len(accuracies))/11  # note: unused in this plot
  372 + plt.title('Metrics of a {}-fold random cross-validation\n{}, {}'.format(len(accuracies), location, date))
  373 + plt.xlabel('Score type')
  374 + plt.ylabel('Scores')
  375 + #~ plt.show()
  376 +
  377 + out_fig = op.join(statistics_dir, 'kfold_metrics.png')
  378 + plt.savefig(out_fig, bbox_inches='tight')
  379 + #~ plt.close()
305 380  
306 381 return means, stds, all_metrics
307 382  
308   -#def plot_metrics(all_metrics, metrics_names):
309   -# k_range = np.arange(0,len(all_metrics))
310   -# accuracies = [m[0] for m in all_metrics]
311   -# f1_scores = [m[1] for m in all_metrics]
312   -# precisions = [m[2] for m in all_metrics]
313   -# recalls = [m[3] for m in all_metrics]
314   -# specificities = [m[4] for m in all_metrics]
  383 +
315 384  
316 385 def load_previous_global_parameters(location, date):
317 386 paths_configuration = json.load(open(op.join('..', 'paths_configuration.json')))
... ... @@ -344,13 +413,22 @@ def plot_statistics_all_sites():
344 413 locations, _, dates = get_all_locations_dates(csv_file)
345 414  
346 415 all_metrics = []
  416 +
  417 + low_accuracies = []
  418 + low_accuracies_scenes = []
  419 +
347 420 for j in range(len(locations)):
348 421 location = locations[j]
349 422 date = dates[j]
350 423  
351 424 try:
352   - _, _, temp_metrics = retrieve_Kfold_data(global_parameters, metrics_plotting = True, location = location, date = date)
  425 + _, _, temp_metrics = retrieve_Kfold_data(global_parameters, metrics_plotting = False, location = location, date = date)
353 426 all_metrics.extend(temp_metrics)
  427 + accuracies_tmp = [t[0] for t in temp_metrics]
  428 + if any(a < 0.8 for a in accuracies_tmp):
  429 + low_accuracies.append(accuracies_tmp)
  430 + low_accuracies_scenes.append(location + date)
  431 +
354 432 except:
355 433 print('Error on {}, {}'.format(location, date))
356 434  
... ... @@ -373,9 +451,8 @@ def plot_statistics_all_sites():
373 451 met_nb = 0
374 452 for metric in [accuracies, f1scores, recalls, precisions]:
375 453  
376   - #~ rnd = [(indices[met_nb] + (random.random()-0.5)/4) for k in range(len(accuracies))]
377 454 rnd = [(indices[met_nb] - 0.1 + 0.2*(float(k)/len(accuracies))) for k in range(len(accuracies))]
378   - plt.scatter(rnd, accuracies, color='k', marker='.', alpha = 0.2)
  455 + plt.scatter(rnd, metric, color='k', marker='.', alpha = 0.2)
379 456 met_nb += 1
380 457 plt.errorbar(indices, means[0:4], stds[0:4], linestyle='',
381 458 marker='o', lw=2, elinewidth = 2, capsize = 8, capthick = 1, color = 'b')
... ... @@ -388,7 +465,7 @@ def plot_statistics_all_sites():
388 465 plt.xticks(indices, metrics_names)
389 466  
390 467 nb_dates = float(len(accuracies))/11  # note: divisor still 11, but K is now 10
391   - plt.title('Metrics of a 11-fold random cross-validation \n on {:.0f} dates'.format(nb_dates))
  468 + plt.title('Metrics of a 10-fold random cross-validation \n on {:.0f} dates'.format(nb_dates))
392 469 plt.xlabel('Score type')
393 470 plt.ylabel('Scores')
394 471  
... ... @@ -405,13 +482,18 @@ def plot_statistics_all_sites():
405 482 out_fig = op.join('tmp', 'kfold_synthese.png')
406 483 plt.savefig(out_fig, bbox_inches='tight')
407 484 plt.close()
408   -
  485 + print('Scenes with low accuracies:')
  486 + print(low_accuracies_scenes)
409 487  
410 488 def main():
411 489 global_parameters = json.load(open(op.join('parameters_files','global_parameters.json')))
412 490 plot_statistics_all_sites()
413   -
  491 + return  # everything below is single-scene debug code
  492 + location = 'Ispra'
  493 + date = '20171009'
  494 + a = retrieve_Kfold_data(global_parameters, metrics_plotting = True, location = location, date = date)
414 495 #~ retrieve_Kfold_data(global_parameters, metrics_plotting = True)
  496 +
415 497 #~ location = 'Arles'
416 498 #~ date = 20171221
417 499 #~ load_previous_global_parameters(location, date)
... ...
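
Editor's note: retrieve_Kfold_data now aggregates the per-fold metric rows column-wise and writes a JSON summary. A self-contained sketch of that aggregation step, under the assumption that all_metrics is a list of equal-length rows (summarize_folds is an illustrative name):

    import json
    import numpy as np

    def summarize_folds(all_metrics, metrics_names, out_json):
        """Reduce one metrics row per fold to per-metric means and stds."""
        means = np.mean(all_metrics, axis=0)  # column-wise: one value per metric
        stds = np.std(all_metrics, axis=0)
        data = {
            "metrics_names": metrics_names,
            "means": [float(m) for m in means],
            "stds": [float(s) for s in stds],
            "all_metrics": all_metrics,
            "K": len(all_metrics),
        }
        with open(out_json, "w") as f:
            json.dump(data, f, indent=3, sort_keys=True)
        return means, stds
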
ALCD/parameters_files/global_parameters.json
... ... @@ -60,17 +60,18 @@
60 60 "model_metrics": "model_metrics.csv"
61 61 },
62 62 "training_parameters": {
  63 + "Kfold": "10",
63 64 "dilatation_radius": "1",
64 65 "expansion_distance": "100",
65 66 "regularization_radius": "1",
66 67 "training_proportion": "0.9"
67 68 },
68 69 "user_choices": {
69   - "clear_date": "20170820",
70   - "current_date": "20171002",
71   - "location": "Arles",
72   - "main_dir": "/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Arles_31TFJ_20171002",
73   - "raw_img": "Arles_bands.tif",
74   - "tile": "31TFJ"
  70 + "clear_date": "20180213",
  71 + "current_date": "20180213",
  72 + "location": "RailroadValley",
  73 + "main_dir": "/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/RailroadValley_11SPC_20180213",
  74 + "raw_img": "RailroadValley_bands.tif",
  75 + "tile": "11SPC"
75 76 }
76 77 -}
  78 +}
77 79 \ No newline at end of file
... ...
ALCD/split_samples.py
... ... @@ -105,6 +105,7 @@ def split_points_sample(in_shp, train_shp, validation_shp, proportion, proportio
105 105 for class_name in list(set(points_classes_list)):
106 106 # get all the indexes of the points belonging to that class
107 107 class_indexes = [index for index, value in enumerate(points_classes_list) if value == class_name]
  108 + shuffle(class_indexes) # shuffle so the per-class split is random
108 109  
109 110 # set the max number of points with the command below
110 111 cutoff = int(np.ceil(proportion*len(class_indexes)))
... ... @@ -161,6 +162,136 @@ def split_points_sample(in_shp, train_shp, validation_shp, proportion, proportio
161 162  
162 163 return
163 164  
  165 +def k_split(in_shp, out_dir, K):
  166 + '''
  167 + Split the in_shp into K different train/validation sets
  168 + They will be saved in the out_dir folder
  169 + '''
  170 + # Create the output dir
  171 + if not os.path.exists(out_dir):
  172 + os.makedirs(out_dir)
  173 + print(out_dir + ' created')
  174 +
  175 + # Open the input shapefile and get its layer
  176 + inDriver = ogr.GetDriverByName("ESRI Shapefile")
  177 + inDataSource = inDriver.Open(in_shp, 0)
  178 + inLayer = inDataSource.GetLayer()
  179 +
  180 + layerDefinition = inLayer.GetLayerDefn()
  181 + srs = inLayer.GetSpatialRef()
  182 +
  183 + # get the field names
  184 + field_names = []
  185 + for i in range(layerDefinition.GetFieldCount()):
  186 + field_names.append(layerDefinition.GetFieldDefn(i).GetName())
  187 +
  188 + shpDriver = ogr.GetDriverByName("ESRI Shapefile")
  189 +
  190 +
  191 +
  192 + # each class will be split evenly across the K folds
  193 + points_classes_list = []
  194 + points_FID_list = []
  195 +
  196 + # Get a list of all the classes and FID
  197 + for point in inLayer:
  198 + points_classes_list.append(point.GetField("class"))
  199 + points_FID_list.append(point.GetFID())
  200 +
  201 + # Shuffle the two lists in the same order
  202 + points_classes_list, points_FID_list = shuffle_two_lists(points_classes_list, points_FID_list)
  203 +
  204 + # Get the indexes to respect the quota
  205 + train_idx = []
  206 + validation_idx = []
  207 + # for each class
  208 + for class_name in list(set(points_classes_list)):
  209 + # get all the indexes of the points belonging to that class
  210 + class_indexes = [index for index, value in enumerate(points_classes_list) if value == class_name]
  211 + shuffle(class_indexes)
  212 +
  213 + # split it into K roughly equal chunks
  214 + K = int(K)
  215 + splitted_class_indexes = np.array_split(class_indexes, K)
  216 +
  217 + # prepare all the k lists
  218 + for k in range(K):
  219 + train_k_idx = np.concatenate([x for i,x in enumerate(splitted_class_indexes) if i!=k])
  220 + validation_k_idx = np.concatenate([x for i,x in enumerate(splitted_class_indexes) if i==k])
  221 + train_idx.append(train_k_idx)
  222 + validation_idx.append(validation_k_idx)
  223 +
  224 + # here, train_idx and validation_idx contain K*nb_of_classes lists
  225 + # (K per class); regroup them into K lists by concatenating
  226 + # every K-th entry, i.e. one chunk per class for each fold k
  227 + nb_classes = len(list(set(points_classes_list)))  # note: not used below
  228 +
  229 + train_idx_all_K = []
  230 + validation_idx_all_K = []
  231 + for k in range(K):
  232 + train_idx_all_K.append(np.concatenate(train_idx[k::K]))
  233 + validation_idx_all_K.append(np.concatenate(validation_idx[k::K]))
  234 +
  235 + for k in range(K):
  236 + train_idx = train_idx_all_K[k]
  237 + validation_idx = validation_idx_all_K[k]
  238 +
  239 + train_shp = op.join(out_dir, 'train_k_{}.shp'.format(k))
  240 + validation_shp = op.join(out_dir, 'validation_k_{}.shp'.format(k))
  241 +
  242 + # Associate the indexes to the FIDs
  243 + train_FID = [points_FID_list[int(idx)] for idx in train_idx]
  244 + validation_FID = [points_FID_list[int(idx)] for idx in validation_idx]
  245 + #~ print(train_FID)
  246 + print('{} training points will be taken'.format(len(train_FID)))
  247 + print('{} validation points will be taken'.format(len(validation_FID)))
  248 +
  249 + inLayer.ResetReading() # reset the cursor so the layer can be iterated again
  250 +
  251 +
  252 + # Remove output shapefile if it already exists
  253 + for dire in [train_shp, validation_shp]:
  254 + if os.path.exists(dire):
  255 + shpDriver.DeleteDataSource(dire)
  256 +
  257 + # Create the output shapefiles
  258 + trainDataSource = shpDriver.CreateDataSource(train_shp)
  259 + trainLayer = trainDataSource.CreateLayer("buff_layer", srs, geom_type=ogr.wkbPoint)
  260 +
  261 + validationDataSource = shpDriver.CreateDataSource(validation_shp)
  262 + validationLayer = validationDataSource.CreateLayer("buff_layer", srs, geom_type=ogr.wkbPoint)
  263 +
  264 +
  265 + # Add all the fields
  266 + for field_name in field_names:
  267 + newField = ogr.FieldDefn(field_name, ogr.OFTInteger)  # assumes every field is integer-typed
  268 + trainLayer.CreateField(newField)
  269 + validationLayer.CreateField(newField)
  270 +
  271 +
  272 + # Dispatch each input feature to the train or validation layer
  273 + for point in inLayer:
  274 + current_FID = point.GetFID()
  275 + if current_FID in train_FID:
  276 + trainLayer.CreateFeature(point)
  277 + elif current_FID in validation_FID:
  278 + validationLayer.CreateFeature(point)
  279 + else:
  280 + print('FID {} not in any list'.format(current_FID))
  281 +
  282 +
  283 +
  284 + # Close DataSources
  285 + trainDataSource.Destroy()
  286 + validationDataSource.Destroy()
  287 +
  288 +
  289 + inDataSource.Destroy()
  290 +
  291 + return
  292 +
  293 +
  294 +
164 295  
165 296  
166 297 def main():
... ... @@ -169,7 +300,15 @@ def main():
169 300 #~ validation_shp = '/mnt/data/home/baetensl/OTB_codes/OTB_commands/Full_orleans/In_data/Masks/points_clouds_validation.shp'
170 301  
171 302 shp_dir = '/mnt/data/home/baetensl/classification_clouds/Data/Orleans_all/Intermediate'
  303 + shp_dir = '/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Arles_31TFJ_20171002/Intermediate'  # overrides the path above
172 304 in_shp = op.join(shp_dir, 'merged.shp')
  305 +
  306 + out_dir = '/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Arles_31TFJ_20171002/kfold'
  307 + K = 10
  308 + k_split(in_shp, out_dir, K)
  309 +
  310 + return  # debug: only exercise k_split above
  311 +
173 312 train_shp = op.join(shp_dir, 'train_points.shp')
174 313 validation_shp = op.join(shp_dir, 'validation_points.shp')
175 314  
... ...
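
Editor's note: the heart of k_split is a stratified leave-one-chunk-out split: shuffle each class's point indexes, cut them into K chunks, and for fold k keep chunk k for validation and the remaining K-1 chunks for training. A minimal sketch of that index logic, detached from the OGR plumbing (stratified_kfold_indexes is an illustrative name):

    import numpy as np
    from random import shuffle

    def stratified_kfold_indexes(classes, K):
        """Yield (train_idx, validation_idx) per fold, stratified by class.

        classes: one label per point, parallel to the FID list.
        """
        folds = [[] for _ in range(K)]
        for class_name in set(classes):
            idx = [i for i, c in enumerate(classes) if c == class_name]
            shuffle(idx)  # randomize before chunking
            # np.array_split tolerates len(idx) not being divisible by K
            for k, chunk in enumerate(np.array_split(idx, K)):
                folds[k].extend(int(i) for i in chunk)
        for k in range(K):
            validation_idx = folds[k]
            train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
            yield train_idx, validation_idx

One practical note on the feature-copy loop in k_split: current_FID in train_FID is a linear scan per feature, so converting train_FID and validation_FID to sets before that loop makes the dispatch O(1) per point.
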
PCC/parameters_files/comparison_parameters.json
... ... @@ -3,7 +3,7 @@
3 3 "cirrus_threshold": "200",
4 4 "dilatation_radius": "4",
5 5 "labeled_img_name": "labeled_img_regular.tif",
6   - "main_dir": "/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Gobabeb_33KWP_20171014"
  6 + "main_dir": "/mnt/data/home/baetensl/clouds_detection_git/Data_ALCD/Ispra_32TMR_20171009"
7 7 },
8 8 "processing": {
9 9 "alcd_cirrus": {
... ... @@ -42,10 +42,10 @@
42 42 }
43 43 },
44 44 "user_choices": {
45   - "current_date": "20171014",
46   - "location": "Gobabeb",
47   - "main_dir": "/mnt/data/home/baetensl/clouds_detection_git/Data_PCC/Gobabeb_33KWP_20171014",
48   - "raw_img": "Gobabeb_bands.tif",
49   - "tile": "33KWP"
  45 + "current_date": "20171009",
  46 + "location": "Ispra",
  47 + "main_dir": "/mnt/data/home/baetensl/clouds_detection_git/Data_PCC/Ispra_32TMR_20171009",
  48 + "raw_img": "Ispra_bands.tif",
  49 + "tile": "32TMR"
50 50 }
51 51 }
52 52 \ No newline at end of file
... ...
Tools/kfold_all 0 → 100644
... ... @@ -0,0 +1,8 @@
  1 +location=Arles
  2 +for date in 20170917 20171002 20171221
  3 +do
  4 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  5 + python all_run_alcd.py -f false -kfold true
  6 + python metrics_exploitation.py
  7 +done
  8 +
... ...
Tools/kfold_all_true 0 → 100644
... ... @@ -0,0 +1,42 @@
  1 +location=Gobabeb
  2 +for date in 20161221 20170909 20171014 20180209
  3 +do
  4 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  5 + python all_run_alcd.py -f false -kfold true
  6 + python metrics_exploitation.py
  7 +done
  8 +location=Marrakech
  9 +for date in 20160417 20170621 20171218
  10 +do
  11 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  12 + python all_run_alcd.py -f false -kfold true
  13 + python metrics_exploitation.py
  14 +done
  15 +location=Mongu
  16 +for date in 20161112 20170804 20171013
  17 +do
  18 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  19 + python all_run_alcd.py -f false -kfold true
  20 + python metrics_exploitation.py
  21 +done
  22 +location=Orleans
  23 +for date in 20170516 20170819 20180218
  24 +do
  25 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  26 + python all_run_alcd.py -f false -kfold true
  27 + python metrics_exploitation.py
  28 +done
  29 +location=Pretoria
  30 +for date in 20170313 20170820 20171014 20171213
  31 +do
  32 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  33 + python all_run_alcd.py -f false -kfold true
  34 + python metrics_exploitation.py
  35 +done
  36 +location=RailroadValley
  37 +for date in 20170501 20170827 20180213
  38 +do
  39 + python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  40 + python all_run_alcd.py -f false -kfold true
  41 + python metrics_exploitation.py
  42 +done
... ...
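
Editor's note: each Tools script repeats the same three-step recipe per scene: run step 0 to point global_parameters.json at the scene, run the K-fold, then aggregate the metrics. A hedged Python equivalent of that loop (scene list abridged; assumes it is launched from the ALCD directory, like the shell scripts):

    import subprocess

    # (location, dates) pairs mirroring Tools/kfold_all_true (abridged)
    SCENES = {
        'Gobabeb': ['20161221', '20170909', '20171014', '20180209'],
        'Marrakech': ['20160417', '20170621', '20171218'],
    }

    for location, dates in SCENES.items():
        for date in dates:
            # step 0: update global_parameters.json for this scene
            subprocess.check_call(['python', 'all_run_alcd.py', '-f', 'false',
                                   '-s', '0', '-l', location, '-d', date, '-c', date])
            # run the K-fold cross-validation (parts 2 to 5, K times)
            subprocess.check_call(['python', 'all_run_alcd.py', '-f', 'false',
                                   '-kfold', 'true'])
            # aggregate and save the per-fold metrics
            subprocess.check_call(['python', 'metrics_exploitation.py'])
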
Tools/kfold_launch 0 → 100644
... ... @@ -0,0 +1,8 @@
  1 +location=Pretoria
  2 +date=20171213
  3 +python all_run_alcd.py -f false -s 0 -l $location -d $date -c $date
  4 +python all_run_alcd.py -f false -kfold true
  5 +python metrics_exploitation.py
  6 +#python all_run_alcd.py -f false -s 0 -l Pretoria -d 20171014 -c 20171014
  7 +#python all_run_alcd.py -f false -kfold true
  8 +#python metrics_exploitation.py
... ...