Source code for snafu.io

# IO functions:
#
# * Read graph from file
# * Write graph to file
# * Read fluency data from file

from . import *    # star import supplies np (NumPy), csv, Data, and flatten_list used below

# aliases for backwards compatibility

def write_graph(*args, **kwargs):
    """ Alias for write_network for backwards compatibility. """
    return write_network(*args, **kwargs)

def readX(*args, **kwargs):
    """ Alias for load_fluency_data for backwards compatibility. """
    return load_fluency_data(*args, **kwargs)

def load_graph(*args, **kwargs):
    """ Alias for load_network for backwards compatibility. """
    return load_network(*args, **kwargs)

def read_graph(*args, **kwargs):
    """ Alias for load_network for backwards compatibility. """
    return load_network(*args, **kwargs)
# reads a graph from a CSV edge list
# note: row order is not preserved; could be optimized further
def load_network(fh, cols=(0, 1), header=False, filters={}, undirected=True, sparse=False):
    """
    Load a network from a CSV file and return its adjacency matrix and node labels.

    Parameters
    ----------
    fh : str
        Filepath to a CSV file containing the network edges.
    cols : tuple of int or str, optional
        Column indices or names indicating source and target node columns.
        Default is (0, 1).
    header : bool, optional
        Whether the CSV file has a header row. Default is False.
    filters : dict, optional
        Dictionary to filter rows based on additional column values.
        Only used if `header=True`.
    undirected : bool, optional
        If True, adds symmetric edges to make the graph undirected.
        Default is True.
    sparse : bool, optional
        If True, returns a sparse adjacency matrix. Otherwise, returns a
        dense NumPy array. Default is False.

    Returns
    -------
    graph : ndarray or csr_matrix
        The adjacency matrix representing the network.
    items : dict
        Mapping from node indices to node labels.
    """
    fh = open(fh, 'rt', encoding='utf-8-sig')
    bigdict = {}    # maps each node label to the list of its successors
    if header:
        headerrow = fh.readline().split('\n')[0].split(',')
        cols = (headerrow.index(cols[0]), headerrow.index(cols[1]))
        filterrows = {}
        for i in list(filters.keys()):
            filterrows[headerrow.index(i)] = filters[i]
    else:
        filterrows = {}
    for line in fh:
        linesplit = line.rstrip().split(',')
        twoitems = [linesplit[cols[0]], linesplit[cols[1]]]
        # skip rows that fail any filter
        skiprow = False
        for i in filterrows:
            if linesplit[i] != filterrows[i]:
                skiprow = True
        if skiprow:
            continue
        if twoitems[0] not in bigdict:
            bigdict[twoitems[0]] = [twoitems[1]]
        elif twoitems[1] not in bigdict[twoitems[0]]:
            bigdict[twoitems[0]].append(twoitems[1])
        if twoitems[1] not in bigdict:
            bigdict[twoitems[1]] = []
    # map node labels to indices and back
    items_rev = {label: i for i, label in enumerate(bigdict.keys())}
    items = {i: label for i, label in enumerate(bigdict.keys())}
    if sparse:
        from scipy.sparse import csr_matrix
        rows = []
        cols = []
        numedges = 0
        for i in bigdict:
            for j in bigdict[i]:
                rows.append(items_rev[i])
                cols.append(items_rev[j])
                numedges += 1
                if undirected:
                    rows.append(items_rev[j])
                    cols.append(items_rev[i])
                    numedges += 1
        data = np.array([1] * numedges)
        graph = csr_matrix((data, (np.array(rows), np.array(cols))),
                           shape=(len(items), len(items)))
    else:
        graph = np.zeros((len(items), len(items)))
        for item1 in bigdict:
            for item2 in bigdict[item1]:
                idx1 = items_rev[item1]
                idx2 = items_rev[item2]
                graph[idx1, idx2] = 1
                if undirected:
                    graph[idx2, idx1] = 1
    fh.close()
    return graph, items
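# Example usage, as a minimal sketch (the CSV path and column names below are
# hypothetical; any edge-list CSV with matching columns works):
#
#     graph, items = load_network("edges.csv", cols=("item1", "item2"),
#                                 header=True, filters={"category": "fruits"})
#     graph.shape     # (n, n) adjacency matrix
#     items           # e.g. {0: 'apple', 1: 'banana', ...}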
def load_fluency_data(filepath, category=None, removePerseverations=False,
                      removeIntrusions=False, spell=None, scheme=None,
                      group=None, subject=None, removeNonAlphaChars=False,
                      hierarchical=False, targetletter=None):
    """
    Load verbal fluency data from a CSV file and preprocess it according to options.

    Parameters
    ----------
    filepath : str
        Path to the CSV file containing fluency data.
    category : str or list of str, optional
        Restrict to specific semantic categories.
    removePerseverations : bool, optional
        If True, remove repeated responses in the same list.
    removeIntrusions : bool, optional
        If True, remove responses not in the scheme or target letter.
    spell : str, optional
        Path to a spelling correction file.
    scheme : str, optional
        Path to a valid item list for the given category.
    group : str or list of str, optional
        Filter data by group.
    subject : str or list of str, optional
        Filter data by subject ID.
    removeNonAlphaChars : bool, optional
        If True, remove non-alphabetic characters from item names.
    hierarchical : bool, optional
        If True, marks the returned structure as hierarchical.
    targetletter : str, optional
        Restrict responses to those starting with this letter (used for
        letter fluency tasks).

    Returns
    -------
    Data
        A Data object containing structured fluency data, including:

        - 'Xs': response index lists
        - 'items': subject-specific item index mappings
        - 'irts': inter-response times (if present)
        - 'categories': category info
        - 'spell_corrected': spelling corrections
        - 'perseverations': repeated items
        - 'intrusions': invalid responses
    """
    if targetletter:
        targetletter = targetletter.lower()
    # allow single strings as shorthand for one-element lists
    if isinstance(group, str):
        group = [group]
    if isinstance(subject, str):
        subject = [subject]
    if isinstance(category, str):
        category = [category]

    # grab header column indices
    with open(filepath, 'rt', encoding='utf-8-sig') as f:
        headers = next(csv.reader(f), None)
    subj_col = headers.index('id')
    listnum_col = headers.index('listnum')
    item_col = headers.index('item')

    # check for optional columns
    try:
        category_col = headers.index('category')
        has_category_col = True
    except ValueError:
        has_category_col = False
    try:
        group_col = headers.index('group')
        has_group_col = True
    except ValueError:
        has_group_col = False
        if group:
            raise ValueError('Data file does not have grouping column, but you asked for a specific group.')
    try:
        rt_col = headers.index('rt')
        has_rt_col = True
    except ValueError:
        has_rt_col = False
    try:
        itemnum_col = headers.index('itemnum')
        has_itemnum_col = True
    except ValueError:
        has_itemnum_col = False

    Xs = dict()
    irts = dict()
    categories = dict()
    items = dict()
    spellingdict = dict()
    spell_corrected = dict()
    perseverations = dict()
    intrusions = dict()
    validitems = []

    # read in list of valid items when removeIntrusions=True
    if removeIntrusions:
        if (not scheme) and (not targetletter):
            raise ValueError('You need to provide a scheme or targetletter if you want to ignore intrusions!')
        elif scheme:
            with open(scheme, 'rt', encoding='utf-8-sig') as fh:
                for line in fh:
                    if line[0] == "#":
                        continue    # skip commented lines
                    try:
                        validitems.append(line.rstrip().split(',')[1].lower())
                    except IndexError:
                        pass        # skip malformed lines

    # read in spelling correction dictionary when spell is specified
    if spell:
        with open(spell, 'rt', encoding='utf-8-sig') as spellfile:
            for line in spellfile:
                if line[0] == "#":
                    continue        # skip commented lines
                try:
                    correct, incorrect = line.rstrip().split(',')
                    spellingdict[incorrect] = correct
                except ValueError:
                    pass            # skip malformed lines

    data_rows = []
    with open(filepath, 'rt', encoding='utf-8-sig') as f:
        f.readline()                # discard header row
        for line in f:
            if line[0] == "#":
                continue            # skip commented lines
            data_rows.append(line.rstrip().split(','))

    # if an itemnum column exists, sort rows by subject, list, and item position
    if has_itemnum_col:
        def sort_key(row):
            return (row[subj_col], int(row[listnum_col]), int(row[itemnum_col]))
        data_rows.sort(key=sort_key)

    for row in data_rows:
        # if the row meets the specified filters then load it, else skip it
        storerow = True
        if (subject is not None) and (row[subj_col] not in subject):
            storerow = False
        if (group is not None) and (row[group_col] not in group):
            storerow = False
        if (category is not None) and (row[category_col] not in category):
            storerow = False
        if not storerow:
            continue

        idx = row[subj_col]
        listnum_int = int(row[listnum_col])

        # make sure dict keys exist
        if idx not in Xs:
            Xs[idx] = dict()
            spell_corrected[idx] = dict()
            perseverations[idx] = dict()
            intrusions[idx] = dict()
            if has_rt_col:
                irts[idx] = dict()
            if has_category_col:
                categories[idx] = dict()
        if listnum_int not in Xs[idx]:
            Xs[idx][listnum_int] = []
            spell_corrected[idx][listnum_int] = []
            perseverations[idx][listnum_int] = []
            intrusions[idx][listnum_int] = []
            if has_rt_col:
                irts[idx][listnum_int] = []
            if has_category_col:
                categories[idx][listnum_int] = row[category_col]
        if idx not in items:
            items[idx] = dict()

        # basic clean-up
        item = row[item_col].lower()
        if removeNonAlphaChars:
            item = "".join([char for char in item if char.isalpha()])
        if item in spellingdict:
            newitem = spellingdict[item]
            spell_corrected[idx][listnum_int].append((item, newitem))
            item = newitem
        if has_rt_col:
            irt = row[rt_col]

        # assign the item an index for this subject, or record it as an intrusion
        if item not in list(items[idx].values()):
            if (item in validitems) or (not removeIntrusions) or (item[0] == targetletter):
                items[idx][len(items[idx])] = item
            else:
                intrusions[idx][listnum_int].append(item)

        # add item to list
        try:
            itemval = list(items[idx].values()).index(item)
            if (not removePerseverations) or (itemval not in Xs[idx][listnum_int]):
                # ignore any duplicates in same list resulting from spelling corrections
                if (item in validitems) or (not removeIntrusions) or (item[0] == targetletter):
                    Xs[idx][listnum_int].append(itemval)
                    if has_rt_col:
                        irts[idx][listnum_int].append(int(irt))
            else:
                perseverations[idx][listnum_int].append(item)    # record as perseveration
        except ValueError:
            pass    # item has no index because it was excluded as an intrusion

    return Data({'Xs': Xs, 'items': items, 'irts': irts,
                 'structure': hierarchical, 'categories': categories,
                 'spell_corrected': spell_corrected,
                 'perseverations': perseverations,
                 'intrusions': intrusions})
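# Example usage, as a minimal sketch (file paths are hypothetical; per the
# parsing above, the spell file has two columns, "correct,incorrect", and the
# scheme file lists valid items in its second column):
#
#     data = load_fluency_data("fluency_data.csv",
#                              category="animals",
#                              spell="animals_spellfile.csv",
#                              scheme="animals_scheme.csv",
#                              removePerseverations=True,
#                              removeIntrusions=True)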
def write_network(gs, fh, subj="NA", directed=False, extra_data={}, header=True, labels=None, sparse=False):
    """
    Write one or more graphs to a CSV file in edge-list format.

    Parameters
    ----------
    gs : list or networkx.Graph
        A graph or list of graphs to be written.
    fh : str
        Path to the output file.
    subj : str, optional
        Subject identifier to include in the output. Default is "NA".
    directed : bool, optional
        If True, treat the graphs as directed. Default is False.
    extra_data : dict, optional
        Optional nested dictionary of additional edge-level data.
    header : bool or str, optional
        If True, write a default header. If str, use as custom header.
        Default is True.
    labels : dict or list of dict, optional
        Optional relabeling of graph node indices to string labels.
    sparse : bool, optional
        If True, only write edges that are present. If False, write all
        pair combinations. Default is False.

    Returns
    -------
    None
        Writes data to file and returns nothing.
    """
    onezero = {True: '1', False: '0'}
    import networkx as nx
    fh = open(fh, 'w')
    # if gs is not a list of graphs, then it should be a single graph
    if not isinstance(gs, list):
        gs = [gs]
    # turn them all into networkx graphs if they aren't already
    gs = [g if type(g) == nx.Graph else nx.to_networkx_graph(g) for g in gs]
    # relabel nodes if labels are provided
    if labels is not None:
        if not isinstance(labels, list):
            labels = [labels]
        gs = [nx.relabel_nodes(g, l, copy=False) for g, l in zip(gs, labels)]
    nodes = list(set(flatten_list([list(g.nodes()) for g in gs])))
    if header is True:
        fh.write('subj,item1,item2,edge\n')               # default header
    elif isinstance(header, str):
        fh.write('subj,item1,item2,' + header + '\n')     # manual header if a string is given
    for node1 in nodes:
        for node2 in nodes:
            # write each edge once, in sorted order, unless the graph is directed
            if (node1 < node2) or (directed and (node1 != node2)):
                edge = (node1, node2)
                edgelist = ""
                write_edge = 0 if sparse else 1
                for g in gs:
                    hasedge = onezero[g.has_edge(edge[0], edge[1])]
                    edgelist = edgelist + "," + hasedge   # assumes graph is symmetrical if directed=True!
                    write_edge += int(hasedge)
                if write_edge > 0:
                    extrainfo = ""
                    if edge[0] in extra_data:
                        if edge[1] in extra_data[edge[0]]:
                            if isinstance(extra_data[edge[0]][edge[1]], list):
                                extrainfo = "," + ",".join([str(i) for i in extra_data[edge[0]][edge[1]]])
                            else:
                                extrainfo = "," + str(extra_data[edge[0]][edge[1]])
                    fh.write(subj + "," + str(edge[0]) + "," + str(edge[1]) + edgelist + extrainfo + "\n")
    fh.close()
    return
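# Example usage, as a minimal sketch (the graphs and file name are hypothetical):
#
#     import networkx as nx
#     g1 = nx.Graph([("apple", "banana"), ("banana", "cherry")])
#     g2 = nx.Graph([("apple", "banana")])
#     write_network([g1, g2], "networks.csv", subj="S1",
#                   header="g1,g2", sparse=True)
#
# With sparse=True, only node pairs that form an edge in at least one graph
# are written; each graph contributes a 0/1 column after item1 and item2.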