I'd like to find all of the PL indices in a VCF line whose genotypes contain a specific allele, at sites with an arbitrary number of alleles and arbitrary sample ploidy. For diploid samples the algorithm is straightforward, but things get more complicated for ploidies of 3 and above. The related question of finding the PL index for a given genotype already has a nice algorithm: http://gatkforums.broadinstitute.org/gatk/discussion/2157/gatk-vcf-pl-field-ordering-for-pooled-polyploid-samples.
I already have some code to do this (see below), but it's not very elegant. Does anyone have a better algorithm?
def find_pl_allele(allele, n_alleles, ploidy):
    """Yield, in increasing order, the PL indices of genotypes containing ``allele``.

    Every genotype of ``ploidy`` alleles drawn from ``0 .. n_alleles - 1`` is
    visited in PL order. The current genotype is held in ``gts`` as a
    non-decreasing list, so ``gts[-1]`` is its largest allele. ``idx`` is the
    position the next step raises first.

    Args:
        allele: allele index to look for. Values outside ``0 .. n_alleles - 1``
            never match, so nothing is yielded.
        n_alleles: number of alleles at the site.
        ploidy: number of alleles per genotype. A ploidy of 0 yields nothing.

    Yields:
        PL indices (0-based) whose genotype contains ``allele``.
    """
    gts = [0] * ploidy
    idx = ploidy - 1
    pl_pos = 0
    # Sum of the final genotype, in which every copy is allele n_alleles - 1.
    last_sum = ploidy * (n_alleles - 1)
    while True:
        total = sum(gts)
        if total > last_sum:
            return
        if allele in gts:
            yield pl_pos
        if total == last_sum:
            return
        pl_pos += 1
        if idx > 0:
            # Positions 0..idx are still 0 here, so raising gts[idx] keeps the
            # list sorted.
            gts[idx] += 1
            idx -= 1
        else:
            # Find the lowest position k that can be raised without exceeding
            # the next position (or the top position if none can). Raise it
            # and reset everything below it to allele 0.
            k = 0
            while k + 1 < ploidy and gts[k + 1] <= gts[k]:
                k += 1
            gts[k] += 1
            gts[:k] = [0] * k
            # Clamp at 0: for ploidy 1 the old code set idx to -1 and then
            # indexed the list with negative positions (IndexError for n >= 4).
            idx = max(0, k - 1)
def find_pl_recursive(allele, n_alleles, ploidy):
    """Return ``(indices, count)`` for genotypes of ``ploidy`` over ``n_alleles``.

    ``indices`` is the sorted list of PL indices whose genotype contains
    ``allele``. ``count`` is the total number of genotypes, i.e. the length
    of the PL field.

    Genotypes whose largest allele is ``i`` occupy one contiguous block of
    the PL field, for ``i = 0, 1, ...`` in turn. Within that block they are
    ordered like the ``ploidy - 1`` genotypes over alleles ``0 .. i``.
    Recursing on that smaller problem gives both the block size and the
    matching offsets inside it. If ``i == allele``, the whole block matches.

    Raises:
        ValueError: if ``ploidy`` is negative.
    """
    if ploidy < 0:
        raise ValueError("ploidy must be non-negative, got {}".format(ploidy))
    if ploidy == 0:
        # Exactly one (empty) genotype exists, and it contains no allele.
        # Previously this case recursed forever.
        return ([], 1)
    if ploidy == 1:
        return ([allele] if allele < n_alleles else [], n_alleles)
    pl_pos = 0
    res = []
    for i in range(n_alleles):
        sub_res, sub_cnt = find_pl_recursive(allele, i + 1, ploidy - 1)
        if i == allele:
            # Largest allele is `allele` itself: every genotype in the block matches.
            res.extend(range(pl_pos, pl_pos + sub_cnt))
        else:
            res.extend(x + pl_pos for x in sub_res)
        pl_pos += sub_cnt
    return (res, pl_pos)
def n_pl(n_alleles, ploidy):
    """Return the number of genotypes (PL values) for ``n_alleles`` and ``ploidy``.

    This is the multiset coefficient ``C(n_alleles + ploidy - 1, ploidy)``.
    It is computed with exact integer arithmetic. The previous float version
    truncated with ``int()``, so it lost precision once values exceeded
    2**53 and could round down.
    """
    res = 1
    for i in range(ploidy):
        # Invariant: res == C(n_alleles + i - 1, i). The next coefficient is
        # an integer, so this floor division is exact.
        res = res * (n_alleles + i) // (i + 1)
    return res
def unwind_pl_stk(stk):
    """Advance the allele stack ``stk`` in place to the next prefix; return None.

    Entries are examined from the top down. An entry that is not below the
    entry beneath it is popped. The first entry that is below its
    predecessor is incremented, and the function stops. If only the bottom
    entry remains, the bottom entry is incremented instead.
    """
    while len(stk) > 1:
        if stk[-1] < stk[-2]:
            stk[-1] += 1
            return
        stk.pop()
    stk[0] += 1
def find_pl(allele, n_alleles, ploidy):
    """Yield, in increasing order, the PL indices of genotypes containing ``allele``.

    Genotypes are explored as prefixes held in ``prefix``, largest allele
    first; ``unwind_pl_stk`` advances to the next prefix. When a prefix ends
    in ``allele``, every way of filling the remaining slots contains that
    allele. Those ``n_pl(allele + 1, remaining)`` indices are therefore
    emitted as one run. A complete prefix without ``allele`` just consumes
    one index.
    """
    prefix = [0]
    index = 0
    while prefix[0] < n_alleles:
        depth = len(prefix)
        tail = prefix[-1]
        if tail == allele:
            run = n_pl(tail + 1, ploidy - depth)
            yield from range(index, index + run)
            index += run
        elif depth < ploidy:
            prefix.append(0)
            continue
        else:
            # Complete genotype that lacks `allele`: skip its index.
            index += 1
        unwind_pl_stk(prefix)
def find_pl_q(allele, n_alleles, ploidy):
    """Yield, in increasing order, the PL indices of genotypes containing ``allele``.

    Genotype prefixes are kept in the fixed-size list ``stk``, largest allele
    first (``stk[i] <= stk[i - 1]``). ``idx`` marks the top of the prefix.

    - If the top equals ``allele``, every way of filling the remaining
      ``ploidy - idx - 1`` slots with alleles ``0 .. allele`` contains it.
      That is ``C(allele + free, free)`` consecutive indices, all yielded.
    - Otherwise the prefix is extended with allele 0 if there is room.
    - A full-length prefix without ``allele`` consumes exactly one index.

    After a yield or a skip, the prefix is advanced. Entries that cannot
    grow are popped, and the deepest one that can is incremented.

    A ploidy of 0 (or less) yields nothing; previously it raised IndexError.
    """
    if ploidy <= 0:
        return
    stk = [0] * ploidy
    pl_pos = 0
    idx = 0
    while stk[0] < n_alleles:
        top = stk[idx]
        if top == allele:
            free = ploidy - idx - 1
            # Exact integer binomial C(top + free, free); the old float
            # product truncated via int() could under-count for large inputs.
            count = 1
            for i in range(free):
                count = count * (top + 1 + i) // (i + 1)
            yield from range(pl_pos, pl_pos + count)
            pl_pos += count
        elif idx + 1 < ploidy:
            idx += 1
            stk[idx] = 0
            continue
        else:
            # Complete genotype that lacks `allele`: skip its index.
            pl_pos += 1
        # Advance to the next prefix. Drop levels that already equal the level
        # below them, then bump the first level that can still grow. When every
        # level is saturated, idx ends at 0 and the largest allele is bumped.
        while idx > 0 and stk[idx] >= stk[idx - 1]:
            idx -= 1
        stk[idx] += 1
$ python3
# Test the results #
>>> from find_pl_allele import *
>>> list(find_pl_allele(3,4,4))
[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
>>> list(find_pl_recursive(3,4,4))
[[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 35]
>>> list(find_pl(3,4,4))
[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
>>> list(find_pl_q(3,4,4))
[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
# Test the speed #
>>> import timeit
>>> timeit.Timer('list(find_pl_allele.find_pl_allele(8,9,9))', setup="import find_pl_allele").repeat(3, 5)
[0.35032107803272083, 0.3501879220129922, 0.34999616601271555]
>>> timeit.Timer('list(find_pl_allele.find_pl(8,9,9))', setup="import find_pl_allele").repeat(3, 5)
[0.17059430398512632, 0.17093267798190936, 0.17054276500130072]
>>> timeit.Timer('list(find_pl_allele.find_pl_q(8,9,9))', setup="import find_pl_allele").repeat(3, 5)
[0.12451086001237854, 0.12442409800132737, 0.12427003699121997]