# IO functions:
#
# * Read graph from file
# * Write graph to file
# * Read fluency data from file
import csv

import numpy as np

from . import *
# alias for backwards compatibility
def write_graph(*args, **kwargs):
"""
Alias for write_network for backwards compatibility.
"""
return write_network(*args, **kwargs)
# alias for backwards compatibility
def readX(*args, **kwargs):
"""
Alias for load_fluency_data for backwards compatibility.
"""
return load_fluency_data(*args, **kwargs)
# alias for backwards compatibility
def load_graph(*args, **kwargs):
"""
Alias for load_network for backwards compatibility.
"""
return load_network(*args, **kwargs)
# alias for backwards compatibility
def read_graph(*args, **kwargs):
"""
Alias for load_network for backwards compatibility.
"""
return load_network(*args, **kwargs)
# reads in graph from CSV
# row order not preserved; could be optimized more
def load_network(fh, cols=(0, 1), header=False, filters=None, undirected=True, sparse=False):
"""
    Load a network from a CSV file and return its adjacency matrix and node labels.

    Parameters
----------
fh : str
Filepath to a CSV file containing the network edges.
cols : tuple of int or str, optional
Column indices or names indicating source and target node columns. Default is (0, 1).
header : bool, optional
Whether the CSV file has a header row. Default is False.
    filters : dict, optional
        Mapping from column name to required value; rows whose column does not match are skipped. Only used if `header=True`. Default is None.
undirected : bool, optional
If True, adds symmetric edges to make the graph undirected. Default is True.
sparse : bool, optional
If True, returns a sparse adjacency matrix. Otherwise, returns a dense NumPy array. Default is False.

    Returns
-------
graph : ndarray or csr_matrix
The adjacency matrix representing the network.
items : dict
Mapping from node indices to node labels.
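
    Examples
    --------
    A minimal sketch, assuming a hypothetical file ``edges.csv`` whose rows
    are comma-separated edges such as ``apple,banana``::

        graph, items = load_network('edges.csv', undirected=True)
        graph.shape == (len(items), len(items))   # True: square adjacency matrix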
"""
    if filters is None:
        filters = {}
    fh = open(fh, 'rt', encoding='utf-8-sig')
    bigdict = {}
    if header:
        headerrow = fh.readline().rstrip('\n').split(',')
        cols = (headerrow.index(cols[0]), headerrow.index(cols[1]))
        # map filter column names to their column indices
        filterrows = {}
        for name in filters:
            filterrows[headerrow.index(name)] = filters[name]
    else:
        filterrows = {}
    for line in fh:
        linesplit = line.rstrip().split(',')
        twoitems = [linesplit[cols[0]], linesplit[cols[1]]]
        # skip rows that fail any column filter
        skiprow = False
        for i in filterrows:
            if linesplit[i] != filterrows[i]:
                skiprow = True
        if skiprow:
            continue
        try:
            if twoitems[1] not in bigdict[twoitems[0]]:
                bigdict[twoitems[0]].append(twoitems[1])
        except KeyError:
            bigdict[twoitems[0]] = [twoitems[1]]
        if twoitems[1] not in bigdict:   # dict membership is O(1) on average, so this does not scale with dictionary size
            bigdict[twoitems[1]] = []
    fh.close()
    items_rev = {item: i for i, item in enumerate(bigdict)}
    items = {i: item for i, item in enumerate(bigdict)}
    if sparse:
        from scipy.sparse import csr_matrix
        # collect COO-style triplets, then convert to a CSR matrix
        row_ind = []
        col_ind = []
        numedges = 0
        for i in bigdict:
            for j in bigdict[i]:
                row_ind.append(items_rev[i])
                col_ind.append(items_rev[j])
                numedges += 1
                if undirected:
                    row_ind.append(items_rev[j])
                    col_ind.append(items_rev[i])
                    numedges += 1
        data = np.array([1] * numedges)
        graph = csr_matrix((data, (np.array(row_ind), np.array(col_ind))), shape=(len(items), len(items)))
    else:
        graph = np.zeros((len(items), len(items)))
        for item1 in bigdict:
            for item2 in bigdict[item1]:
                idx1 = items_rev[item1]
                idx2 = items_rev[item2]
                graph[idx1, idx2] = 1
                if undirected:
                    graph[idx2, idx1] = 1
return graph, items
def load_fluency_data(filepath,category=None,removePerseverations=False,removeIntrusions=False,spell=None,scheme=None,group=None,subject=None,removeNonAlphaChars=False,hierarchical=False,targetletter=None):
"""
    Load verbal fluency data from a CSV file and preprocess it according to options.

    Parameters
----------
filepath : str
Path to the CSV file containing fluency data.
category : str or list of str, optional
Restrict to specific semantic categories.
removePerseverations : bool, optional
If True, remove repeated responses in the same list.
removeIntrusions : bool, optional
If True, remove responses not in the scheme or target letter.
    spell : str, optional
        Path to a spelling correction file with one ``correct,incorrect`` pair per line.
    scheme : str, optional
        Path to a category scheme file listing valid items for the given category; the second comma-separated column of each line is read as a valid item.
group : str or list of str, optional
Filter data by group.
subject : str or list of str, optional
Filter data by subject ID.
removeNonAlphaChars : bool, optional
If True, remove non-alphabetic characters from item names.
hierarchical : bool, optional
If True, marks the returned structure as hierarchical.
targetletter : str, optional
Restrict responses to those starting with this letter (used for letter fluency tasks).

    Returns
-------
Data
A Data object containing structured fluency data including:
- 'Xs': response index lists,
- 'items': subject-specific item index mappings,
- 'irts': inter-response times (if present),
- 'categories': category info,
- 'spell_corrected': spelling corrections,
- 'perseverations': repeated items,
- 'intrusions': invalid responses.
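
    Examples
    --------
    A minimal sketch, assuming a hypothetical file ``fluency.csv`` with
    ``id``, ``listnum``, and ``item`` columns::

        data = load_fluency_data('fluency.csv', removePerseverations=True)
        # the returned Data object wraps the 'Xs', 'items', and related
        # dictionaries described above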
"""
    if targetletter:
        targetletter = targetletter.lower()
    if isinstance(group, str):
        group = [group]
    if isinstance(subject, str):
        subject = [subject]
    if isinstance(category, str):
        category = [category]
    # grab header column indices
    with open(filepath, 'rt', encoding='utf-8-sig') as headerfile:
        headers = next(csv.reader(headerfile), None)
    subj_col = headers.index('id')
    listnum_col = headers.index('listnum')
    item_col = headers.index('item')
# check for optional columns
    try:
        category_col = headers.index('category')
        has_category_col = True
    except ValueError:
        has_category_col = False
        if category:
            raise ValueError('Data file does not have a category column, but you asked for a specific category.')
    try:
        group_col = headers.index('group')
        has_group_col = True
    except ValueError:
        has_group_col = False
        if group:
            raise ValueError('Data file does not have a group column, but you asked for a specific group.')
    try:
        rt_col = headers.index('rt')
        has_rt_col = True
    except ValueError:
        has_rt_col = False
    try:
        itemnum_col = headers.index('itemnum')
        has_itemnum_col = True
    except ValueError:
        has_itemnum_col = False
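    # per-subject accumulators; most are keyed as [subject][listnum],
    # while `items` maps [subject][index] -> item label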
Xs=dict()
irts=dict()
categories=dict()
items=dict()
spellingdict=dict()
spell_corrected = dict()
perseverations = dict()
intrusions = dict()
validitems=[]
# read in list of valid items when removeIntrusions = True
    if removeIntrusions:
        if (not scheme) and (not targetletter):
            raise ValueError('You need to provide a scheme or targetletter if you want to ignore intrusions!')
        elif scheme:
            with open(scheme, 'rt', encoding='utf-8-sig') as fh:
                for line in fh:
                    if line.startswith("#"):
                        continue   # skip commented lines
                    try:
                        validitems.append(line.rstrip().split(',')[1].lower())
                    except IndexError:
                        pass   # skip lines that are not in two-column format
# read in spelling correction dictionary when spell is specified
    if spell:
        with open(spell, 'rt', encoding='utf-8-sig') as spellfile:
            for line in spellfile:
                if line.startswith("#"):
                    continue   # skip commented lines
                try:
                    correct, incorrect = line.rstrip().split(',')
                    spellingdict[incorrect] = correct
                except ValueError:
                    pass   # skip lines that are not correct,incorrect pairs
data_rows = []
    with open(filepath, 'rt', encoding='utf-8-sig') as f:
        f.readline()   # discard header row
        for line in f:
            if line.startswith("#") or not line.strip():
                continue   # skip commented and blank lines
            data_rows.append(line.rstrip().split(','))
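    # when an itemnum column exists, sort rows by (id, listnum, itemnum) so
    # responses are processed in within-list order even if the file rows are shuffled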
if has_itemnum_col:
def sort_key(row):
id_val = row[subj_col]
listnum_val = int(row[listnum_col])
itemnum_val = int(row[itemnum_col])
return (id_val, listnum_val, itemnum_val)
data_rows.sort(key=sort_key)
    for row in data_rows:
        storerow = True   # keep the row only if it meets all specified filters
        if (subject is not None) and (row[subj_col] not in subject):
            storerow = False
        if (group is not None) and (row[group_col] not in group):
            storerow = False
        if (category is not None) and (row[category_col] not in category):
            storerow = False
        if storerow:
idx = row[subj_col]
listnum_int = int(row[listnum_col])
# make sure dict keys exist
if idx not in Xs:
Xs[idx] = dict()
spell_corrected[idx] = dict()
perseverations[idx] = dict()
intrusions[idx] = dict()
if has_rt_col:
irts[idx] = dict()
if has_category_col:
categories[idx] = dict()
if listnum_int not in Xs[idx]:
Xs[idx][listnum_int] = []
spell_corrected[idx][listnum_int] = []
perseverations[idx][listnum_int] = []
intrusions[idx][listnum_int] = []
if has_rt_col:
irts[idx][listnum_int] = []
if has_category_col:
categories[idx][listnum_int] = row[category_col]
if idx not in items:
items[idx] = dict()
            # basic clean-up
            item = row[item_col].lower()
            if removeNonAlphaChars:
                item = "".join(char for char in item if char.isalpha())
            if item in spellingdict:
                newitem = spellingdict[item]
                spell_corrected[idx][listnum_int].append((item, newitem))
                item = newitem
if has_rt_col:
irt=row[rt_col]
            if item not in items[idx].values():
                if (item in validitems) or (not removeIntrusions) or (targetletter is not None and item.startswith(targetletter)):
                    item_count = len(items[idx])
                    items[idx][item_count] = item
                else:
                    intrusions[idx][listnum_int].append(item)   # record as intrusion
            # add item to list
            try:
                itemval = list(items[idx].values()).index(item)
                if (not removePerseverations) or (itemval not in Xs[idx][listnum_int]):   # ignore duplicates in the same list caused by spelling corrections
                    if (item in validitems) or (not removeIntrusions) or (targetletter is not None and item.startswith(targetletter)):
                        Xs[idx][listnum_int].append(itemval)
                        if has_rt_col:
                            irts[idx][listnum_int].append(int(irt))
                else:
                    perseverations[idx][listnum_int].append(item)   # record as perseveration
            except ValueError:
                pass   # item was an intrusion, so it was never added to the item index
return Data({'Xs': Xs, 'items': items, 'irts': irts, 'structure': hierarchical, 'categories': categories,
'spell_corrected': spell_corrected, 'perseverations': perseverations, 'intrusions': intrusions})
def write_network(gs, fh, subj="NA", directed=False, extra_data=None, header=True, labels=None, sparse=False):
"""
    Write one or more graphs to a CSV file in edge-list format.

    Parameters
----------
gs : list or networkx.Graph
A graph or list of graphs to be written.
fh : str
Path to the output file.
subj : str, optional
Subject identifier to include in the output. Default is "NA".
directed : bool, optional
If True, treat the graphs as directed. Default is False.
extra_data : dict, optional
Optional nested dictionary of additional edge-level data.
header : bool or str, optional
If True, write a default header. If str, use as custom header. Default is True.
labels : dict or list of dict, optional
Optional relabeling of graph node indices to string labels.
sparse : bool, optional
If True, only write edges that are present. If False, write all pair combinations. Default is False.

    Returns
-------
None
Writes data to file and returns nothing.
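
    Examples
    --------
    A minimal sketch, assuming a hypothetical adjacency matrix ``graph`` and
    index-to-label dictionary ``items``, e.g. as returned by `load_network`::

        import networkx as nx
        g = nx.to_networkx_graph(graph)
        write_network(g, 'edges_out.csv', subj="S1", labels=items)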
"""
    onezero = {True: '1', False: '0'}
    import networkx as nx
    if extra_data is None:
        extra_data = {}
    fh = open(fh, 'w')
    # if gs is not a list of graphs, then it should be a single graph
    if not isinstance(gs, list):
        gs = [gs]
    # turn them all into networkx graphs if they aren't already
    gs = [g if isinstance(g, nx.Graph) else nx.to_networkx_graph(g) for g in gs]
    # label nodes if labels are provided
    if labels is not None:
        if not isinstance(labels, list):
            labels = [labels]
        gs = [nx.relabel_nodes(g, lab, copy=False) for g, lab in zip(gs, labels)]
nodes = list(set(flatten_list([list(gs[i].nodes()) for i in range(len(gs))])))
    if header is True:
        fh.write('subj,item1,item2,edge,\n')   # default header
    elif isinstance(header, str):
        fh.write('subj,item1,item2,' + header + '\n')   # custom header if a string is given
for node1 in nodes:
for node2 in nodes:
if (node1 < node2) or ((directed) and (node1 != node2)): # write edges in alphabetical order unless directed graph
edge = (node1,node2)
edgelist=""
if sparse:
write_edge = 0
else:
write_edge = 1
for g in gs:
hasedge = onezero[g.has_edge(edge[0],edge[1])]
edgelist=edgelist+"," + hasedge # assumes graph is symmetrical if directed=True !!
write_edge += int(hasedge)
                if write_edge > 0:
                    extrainfo = ""
                    if edge[0] in extra_data:
                        if edge[1] in extra_data[edge[0]]:
                            if isinstance(extra_data[edge[0]][edge[1]], list):
                                extrainfo = "," + ",".join([str(i) for i in extra_data[edge[0]][edge[1]]])
                            else:
                                extrainfo = "," + str(extra_data[edge[0]][edge[1]])
                    fh.write(subj + "," +
                             str(edge[0]) + "," +
                             str(edge[1]) +
                             edgelist + "," +
                             extrainfo + "\n")
    fh.close()
    return