"""
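  This is a utility module which provides the functionality required for
  loading data from cluster (ground truth and prediction), similarity and
  constraint plausibility files, for summarising the clusters produced by a
  record linkage algorithm, and for writing predicted clusters to a file.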
  
  Since the optional comments in the cluster files are not used in the record 
  linkage process at the moment, those optional comments are ignored when 
  loading the data. 
"""

import csv
import numpy
import gzip
import time

def load_cluster_file(cluster_file):
  '''
  This function converts a cluster file of the following format to a list of 
  tuples, where each tuple contains the sorted record IDs of a single cluster.
    <num_clusters>
    c1_a, c1_b, ...c1_x, <optional comment>
    c2_a, c2_b, ...c2_y, <optional comment>

  The last field of every cluster row is treated as the (possibly empty)
  comment and is discarded. Completely blank rows are skipped.

  input
  -----
    cluster_file - Path of a CSV file (must end with '.csv') of the predefined
                   format.

  output
  ------
    cluster_list - List of clusters, each a tuple of sorted record IDs, in the
                   order in which the clusters appear in the file.

  raises
  ------
    IOError        - If the file cannot be opened. A message is printed and
                     the original error is re-raised.
    AssertionError - If the file name or content does not match the format:
                     missing/malformed header, empty record IDs, a cluster
                     count different from the header, or no cluster with more
                     than two records.
  '''

  # Assert file is CSV
  assert (cluster_file.endswith('.csv')), (cluster_file)

  try:
    csv_file = open(cluster_file, 'r')
  except IOError:
    print('File "%s" not found.' % (cluster_file))
    raise  # Keep the original error (and its message) for the caller.

  cluster_list = []  # List of clusters
  largest_cluster_size = 0  # Size of largest cluster

  # The 'with' block closes the file even when a format assertion fails.
  with csv_file:
    csv_reader = csv.reader(csv_file)

    # An empty file yields [] here, so the length assertion reports it
    # instead of a bare StopIteration.
    first_line = next(csv_reader, [])
    assert len(first_line) == 1, len(first_line)
    assert (first_line[0].strip().isdigit()), (first_line[0])

    # The number of clusters reported to appear in the file
    reported_num_clusters = int(first_line[0].strip())

    for row in csv_reader:
      if not row:  # csv yields [] for a blank line; it is not a cluster.
        continue

      record_ids = row[:-1]  # The last element is the optional comment.
      assert '' not in record_ids  # No empty record IDs in a cluster.

      cluster_list.append(tuple(sorted(record_ids)))
      largest_cluster_size = max(largest_cluster_size, len(record_ids))

  # Check whether the number of clusters claimed to be in the file is correct.
  actual_num_clusters = len(cluster_list)
  assert actual_num_clusters == reported_num_clusters, (actual_num_clusters,
                                                        reported_num_clusters)

  # The largest cluster size should be greater than two because we are
  # considering groups of records, not record pairs.
  assert largest_cluster_size > 2, (largest_cluster_size)

  return cluster_list

# -----------------------------------------------------------------------------

def load_sim_file(sim_file):
  '''
  This function converts a similarity file of the following format to two sets 
  and a dictionary where the first contains all records in the data set, the 
  second contains all records in the data set which do not occur in the pairwise 
  similarity calculation, and the third contains record pairs and the 
  corresponding pairwise similarities.
    <num_rec>
    <num_links>
    rec_id_1, <optional attribute about record with rec_id_1>
    rec_id_2, <optional attribute about record with rec_id_2>
    ...
    rec_id_[num_rec], <optional attribute about record with rec_id_[num_rec]>
    rec_id_a,rec_id_b,sim_val, <optional attribute about link rec_a to rec_b>
    ...
    rec_id_x,rec_id_y,sim_val, <optional attribute about link rec_x to rec_y>

  Record rows must have exactly two fields and link rows exactly four fields;
  the trailing field is ignored (it may be empty). Completely blank rows are
  skipped.

  input
  -----
    sim_file - Path of a CSV file (must end with '.csv') of the predefined
               format.

  output
  ------
    all_record_set - Set containing all records from the data set.

    not_in_sim_dict_set - Set of record IDs in the data set which do not
                          occur in the similarity dictionary.

    sim_dict - Dictionary to contain record pair and their corresponding 
               similarity where the sorted record pair (tuple) -> key,
               similarity (float) -> value.

  raises
  ------
    IOError        - If the file cannot be opened. A message is printed and
                     the original error is re-raised.
    AssertionError - If the file name or content does not match the format,
                     including record/link counts that differ from the header.
  '''
  print('')
  print('Loading the similarity file...')
  print('')

  # Assert file is CSV
  assert (sim_file.endswith('.csv')), (sim_file)

  try:
    csv_file = open(sim_file, 'r')
  except IOError:
    print('File "%s" not found.' % (sim_file))
    raise  # Keep the original error (and its message) for the caller.

  all_record_set = set()  # Set of all records.
  rec_in_sim_dict_set = set()  # Set of records in the similarity dictionary.
  sim_dict = {}  # Dictionary of pairwise similarities.

  actual_num_records = 0  # Actual number of record rows
  actual_num_pairs = 0  # Actual number of record pair rows

  # The 'with' block closes the file even when a format assertion fails.
  with csv_file:
    csv_reader = csv.reader(csv_file)

    # A missing header line yields [] so the length assertions report it
    # instead of a bare StopIteration.
    first_line = next(csv_reader, [])
    second_line = next(csv_reader, [])
    assert len(first_line) == 1, len(first_line)
    assert len(second_line) == 1, len(second_line)
    assert (first_line[0].strip().isdigit()), (first_line[0])
    assert (second_line[0].strip().isdigit()), (second_line[0])

    # The number of records reported to appear in the file
    reported_num_records = int(first_line[0].strip())
    # The number of record pairs reported to appear in the file
    reported_num_pairs = int(second_line[0].strip())

    for row in csv_reader:
      if not row:  # csv yields [] for a blank line.
        continue

      if (len(row) == 2):
        # Record row: record ID followed by an ignored attribute field.
        all_record_set.add(row[0])
        actual_num_records += 1

      else:
        # Link row: two record IDs, the similarity, and an ignored field.
        assert len(row) == 4, len(row)
        rec_id1, rec_id2 = row[0], row[1]
        rec_in_sim_dict_set.add(rec_id1)
        rec_in_sim_dict_set.add(rec_id2)
        sim_dict[tuple(sorted((rec_id1, rec_id2)))] = float(row[2])
        actual_num_pairs += 1

  # Check if the number of records/pairs claimed to be in the file is correct.
  assert reported_num_records == actual_num_records, (reported_num_records,
                                                      actual_num_records)
  assert reported_num_pairs == actual_num_pairs, (reported_num_pairs,
                                                  actual_num_pairs)

  not_in_sim_dict_set = all_record_set - rec_in_sim_dict_set
  return all_record_set, not_in_sim_dict_set, sim_dict

# -----------------------------------------------------------------------------
  
def load_constraints_file(constr_file):
  '''
  This function converts a gzip-compressed constraint file of the following
  format to a dictionary.
    Rec_ID    <rec_id_z, ..., rec_id_c, rec_id_b, rec_id_a>
    rec_id_a          1,  ...,         ,        1
    rec_id_b           ,  ...,       1
    rec_id_c          1,  ...
    ...               ...
    rec_id_z

  The header row lists the record IDs in reverse order after a leading name
  (e.g. 'Rec_ID'), which is ignored. The data rows must list the same record
  IDs in forward order. Data row k (0-based) holds exactly
  (number of columns - (k + 1)) values, aligned with the leading header
  columns. A value that is exactly the empty string marks the pair as not
  plausible; any other value marks it as plausible. Completely blank rows are
  skipped.

  input
  -----
    constr_file - Path of a gzip-compressed CSV file (must end with
                  '.csv.gz') of the predefined format.

  output
  ------
    not_plaus_dict - A dictionary containing an individual ID as the key,
        and a set of IDs of other records which are not plausible to occur with
        the key ID according to some constraint, as the value. The relation is
        stored symmetrically.

  raises
  ------
    IOError        - If the file cannot be opened. A message is printed and
                     the original error is re-raised.
    AssertionError - If the file name or content does not match the format.
  '''
  import sys  # Only needed to choose the gzip open mode below.

  print('Loading the constraints file...')
  print('')

  # Assert file is a compressed gz file CSV
  assert (constr_file.endswith('.csv.gz')), (constr_file)

  try:
    # The csv module reads byte strings on Python 2 but requires text on
    # Python 3.
    if sys.version_info[0] < 3:
      csv_file = gzip.open(constr_file, 'rb')
    else:
      csv_file = gzip.open(constr_file, 'rt', newline='')
  except IOError:
    print('File "%s" not found.' % (constr_file))
    raise  # Keep the original error (and its message) for the caller.

  not_plaus_dict = {}  # Output dictionary
  row_rec_id_list = []  # Records in the data set as per row

  # The 'with' block closes the file even when a format assertion fails.
  with csv_file:
    csv_reader = csv.reader(csv_file)

    header = next(csv_reader, [])
    assert header, constr_file  # An empty file has no header row.
    # Records as per column. The first header element is a name such as
    # 'Rec_ID' and is exempted.
    column_rec_id_list = header[1:]
    num_columns = len(column_rec_id_list)

    for row in csv_reader:
      if not row:  # csv yields [] for a blank line.
        continue

      ind_id = row[0]  # The record ID for which plausibilities are stored
      plaus_list = row[1:]  # The plausibility values for ind_id.
      index = len(row_rec_id_list)  # Row index of ind_id among data rows

      # Assert that the correct number of plausibility values appear
      # according to the format of the constr_file.
      assert len(plaus_list) == num_columns - (index + 1)

      for (col_index, plaus_val) in enumerate(plaus_list):
        if (plaus_val == ''):  # Empty string indicates non plausibility.
          # Record that is not possible to occur with ind_id.
          not_plaus_rec_id = column_rec_id_list[col_index]

          # Store the relation in both directions.
          not_plaus_dict.setdefault(ind_id, set()).add(not_plaus_rec_id)
          not_plaus_dict.setdefault(not_plaus_rec_id, set()).add(ind_id)

      row_rec_id_list.append(ind_id)

  row_rec_id_list.reverse()

  assert row_rec_id_list == column_rec_id_list, "Row IDs are not " + \
                                                "in reverse order of column IDs"

  return not_plaus_dict
    
# -----------------------------------------------------------------------------

def _claim_record(rec_id, records_in_prediction):
  '''
  Add rec_id to records_in_prediction, asserting that the record has not
  already been placed in another cluster.
  '''
  assert rec_id not in records_in_prediction, rec_id
  records_in_prediction.add(rec_id)


def _cluster_similarity_stats(sorted_cluster_id_list, ind_pair_sim_dict):
  '''
  Return the (minimum, average, median, maximum, density) quintuple of a
  cluster with at least two records.

  Every record pair of the cluster contributes to the similarity statistics.
  Pairs missing from ind_pair_sim_dict contribute a similarity of 0.0. The
  density is the number of pairs found in ind_pair_sim_dict divided by the
  number of possible pairs in the cluster.
  '''
  cluster_size = len(sorted_cluster_id_list)
  cluster_sim_list = []  # One similarity per record pair of the cluster.
  num_edges = 0  # Number of record pairs present in ind_pair_sim_dict.

  for (i, ind_id1) in enumerate(sorted_cluster_id_list[:-1]):
    for ind_id2 in sorted_cluster_id_list[i + 1:]:
      # The IDs come from a sorted list, so the pair matches the sorted-pair
      # keys of ind_pair_sim_dict.
      ind_id_pair = (ind_id1, ind_id2)

      if (ind_id_pair in ind_pair_sim_dict):
        cluster_sim_list.append(ind_pair_sim_dict[ind_id_pair])
        num_edges += 1
      else:
        cluster_sim_list.append(0.0)

  num_possible_edges = cluster_size * (cluster_size - 1) * 0.5

  return (min(cluster_sim_list), numpy.mean(cluster_sim_list),
          numpy.median(cluster_sim_list), max(cluster_sim_list),
          num_edges / num_possible_edges)


def _print_cluster_statistics(final_num_cluster, cluster_size_list,
                              cluster_size_dist):
  '''
  Print the number of clusters and singletons, summary statistics of the
  non-singleton cluster sizes, and the cluster size distribution.
  '''
  num_singleton = cluster_size_dist.get(1, 0)

  print('')
  print('Statistics of the clusters generated by the algorithm >>>')
  print('Number of clusters:    %d' % (final_num_cluster))
  print('Number of singletons:  %d' % (num_singleton))
  if cluster_size_list:  # At least one non-singleton cluster exists.
    # The median can be fractional (e.g. 2.5), so it is printed as a float.
    print('Minimum, average, median and maximum sizes of ' +
          'non-singleton clusters: %d / %.1f / %.1f / %d' %
          (min(cluster_size_list), numpy.mean(cluster_size_list),
           numpy.median(cluster_size_list), max(cluster_size_list)))
  print('')

  print('Size distribution of the final clusters')
  print('size  cluster-count')
  for cluster_size in sorted(cluster_size_dist):
    print('%4d  %4d' % (cluster_size, cluster_size_dist[cluster_size]))
  print('')

  print('')


def generate_final_cluster_tuples(cluster_list, ind_id_never_compared_set, 
                                  not_assigned_ind_set, all_record_set, 
                                  ind_pair_sim_dict):
  '''
  This function generates the final cluster tuples containing a tuple of its 
  identifiers, the cluster size, and a quintuplet with the minimum, average, 
  median, maximum and the density (ratio of the number of edges in each cluster 
  from the similarity dictionary, divided by the total number of possible edges).
  Record pairs of a cluster that are missing from the similarity dictionary
  count as a similarity of 0.0 in the minimum/average/median/maximum, but not
  as edges in the density.

  Singletons get the quintuplet (0.0, 0.0, 0.0, 0.0, 1.0). Records in
  not_assigned_ind_set and ind_id_never_compared_set are appended as extra
  singleton clusters. Summary statistics are printed to standard output.

  input
  -----
    cluster_list - A list of clusters generated by a record linkage clustering
                   algorithm.

    ind_id_never_compared_set - Set of record IDs in the data set which do not
                                occur in the similarity dictionary.

    not_assigned_ind_set - Records which were not assigned to a cluster by the
                           algorithm.

    all_record_set - Set containing all records from the data set.

    ind_pair_sim_dict - A pairwise similarity dictionary containing a sorted 
                        record pair as the key, and the corresponding similarity
                        as the value.

  output
  ------
  final_cluster_list - A list of clusters, each as a tuple made of the 
                       individual identifiers in the cluster, the cluster size, 
                       and a quintuplet with the minimum, average, median, 
                       maximum and the density (ratio of the number of edges 
                       in each cluster from the similarity dictionary, divided 
                       by the total number of possible edges).

  raises
  ------
  AssertionError - If a cluster is empty, a record appears in more than one
                   cluster, or the clusters do not cover exactly
                   all_record_set.
  '''
  singleton_stats = (0.0, 0.0, 0.0, 0.0, 1.0)

  final_cluster_list = []
  cluster_size_list = []  # Sizes of non-singleton clusters.
  records_in_prediction = set()  # Records included in clusters
  cluster_size_dist = {}  # Cluster size -> number of clusters of that size.

  for cluster in cluster_list:
    assert (len(cluster) > 0)
    sorted_cluster_id_list = sorted(cluster)
    cluster_size = len(sorted_cluster_id_list)

    # Each record may belong to one cluster only.
    for rec_id in sorted_cluster_id_list:
      _claim_record(rec_id, records_in_prediction)

    if (cluster_size == 1):
      final_cluster_list.append((tuple(cluster), 1, singleton_stats))
    else:
      cluster_sim_tuple = _cluster_similarity_stats(sorted_cluster_id_list,
                                                    ind_pair_sim_dict)
      final_cluster_list.append((tuple(sorted_cluster_id_list), cluster_size,
                                 cluster_sim_tuple))
      cluster_size_list.append(cluster_size)

    cluster_size_dist[cluster_size] = cluster_size_dist.get(cluster_size, 0) + 1

  # Append singletons to cluster list from the never compared set
  # and the individuals not assigned to clusters set.
  for ind_id in not_assigned_ind_set | ind_id_never_compared_set:
    final_cluster_list.append(((ind_id,), 1, singleton_stats))
    _claim_record(ind_id, records_in_prediction)
    cluster_size_dist[1] = cluster_size_dist.get(1, 0) + 1

  # Assert that all records in the data set are included in the predicted
  # clusters.
  assert records_in_prediction == all_record_set, (len(all_record_set) -
                                                   len(records_in_prediction))

  _print_cluster_statistics(len(final_cluster_list), cluster_size_list,
                            cluster_size_dist)

  return final_cluster_list

# -----------------------------------------------------------------------------

def print_cluster_prediction(final_cluster_list, min_cluster_sim, algo_name):
  '''
  This function prints the final clusters generated from a given algorithm into
  a CSV file in the current working directory. The file is named
  '<algo_name>_<YYYYmmdd_HHMM>_<min_cluster_sim>.csv'. Its first row holds the
  number of clusters. Each following row holds the record IDs of one cluster
  plus an empty trailing field for the optional comment.

  input
  -----
  final_cluster_list - A list of clusters, each as a tuple made of the 
                       individual identifiers in the cluster, the cluster size, 
                       and a quintuplet with the minimum, average, median, 
                       maximum and the density (ratio of the number of edges 
                       in each cluster from the similarity dictionary, divided 
                       by the total number of possible edges). Only the
                       identifiers are written.

  min_cluster_sim - The similarity threshold to consider for filtering 
                    record pairs (only used in the file name).

  algo_name - The name of the algorithm used to generate the clusters

  output
  ------
  file_name - The name of the file that was written. Returning it is
              backward-compatible, since callers previously received None.
  '''

  now_str = time.strftime("%Y%m%d_%H%M", time.localtime())

  file_name = algo_name + '_' + now_str + '_' + str(min_cluster_sim) + '.csv'
  print('Printing predicted clusters to file %s' % (file_name))
  print('')

  # The 'with' block closes the file even if a write fails.
  with open(file_name, 'w') as output_file:
    csv_writer = csv.writer(output_file)

    csv_writer.writerow([len(final_cluster_list)])  # Number of clusters
    for cluster_and_stat_tuple in final_cluster_list:
      cluster_row = list(cluster_and_stat_tuple[0])
      # The last element is reserved for optional attributes. load_cluster_file
      # drops the last field of every row.
      cluster_row.append('')
      csv_writer.writerow(cluster_row)

  return file_name