diff --git a/pori_python/ipr/content.spec.json b/pori_python/ipr/content.spec.json index 5a1793a..9f03189 100644 --- a/pori_python/ipr/content.spec.json +++ b/pori_python/ipr/content.spec.json @@ -202,6 +202,16 @@ "number", "null" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ @@ -475,6 +485,16 @@ "null", "string" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ @@ -1106,6 +1126,16 @@ "string", "null" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ @@ -1161,8 +1191,7 @@ "description": "the type of underlying structural variant", "example": "deletion", "type": "string" - }, - "exon1": { + }, "exon1": { "description": "the 5' (n-terminal) exon", "example": 1, "type": [ @@ -1290,6 +1319,16 @@ "integer", "null" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ diff --git a/pori_python/ipr/inputs.py b/pori_python/ipr/inputs.py index f14fc69..568571b 100644 --- a/pori_python/ipr/inputs.py +++ b/pori_python/ipr/inputs.py @@ -60,6 +60,7 @@ 'comments', 'library', 'germline', + 'flags', ] SMALL_MUT_REQ = ['gene', 'proteinChange'] @@ -98,6 +99,7 @@ 'tumourRefCount', 'tumourRefCopies', 'zygosity', + 'flags', ] EXP_REQ = ['gene', 'kbCategory'] @@ -130,6 +132,7 @@ 'rnaReads', 'rpkm', 'tpm', + 'flags', ] SV_REQ = [ @@ -162,12 +165,13 @@ 'tumourDepth', 'germline', 'mavis_product_id', + 'flags', ] SIGV_REQ = ['signatureName', 'variantTypeName'] SIGV_COSMIC = ['signature'] # 1st element used as signatureName key SIGV_HLA = ['a1', 'a2', 'b1', 'b2', 'c1', 'c2'] -SIGV_OPTIONAL = ['displayName'] +SIGV_OPTIONAL = ['displayName', 'flags'] SIGV_KEY = SIGV_REQ[:] @@ -278,6 +282,7 @@ def row_key(row: 
def ensure_str_list(val):
    """Normalize a flags value to a list of non-empty strings.

    Args:
        val: either a comma-separated string (items are split on ',' and
            whitespace-stripped, empty items dropped) or a list of strings.

    Returns:
        A list of strings. NOTE: when ``val`` is already a list it is
        returned as the SAME object, so callers may mutate it in place.

    Raises:
        TypeError: if ``val`` is neither a string nor a list of strings.
    """
    if isinstance(val, str):
        return [f.strip() for f in val.split(',') if f.strip()]
    if isinstance(val, list):
        if not all(isinstance(item, str) for item in val):
            raise TypeError('All items in flags must be strings')
        return val
    raise TypeError(f'Unexpected type in flags field: {type(val).__name__}')


def add_transcript_flags(variant_sources, transcript_flags_df):
    """Annotate variant records in place with flags looked up by transcript id.

    Args:
        variant_sources: list of variant dicts; plain variants carry a
            'transcript' key, fusions carry 'ctermTranscript'/'ntermTranscript'.
        transcript_flags_df: DataFrame with 'transcript' and 'flags' columns
            ('flags' is a comma-separated string per transcript).

    Returns:
        The same ``variant_sources`` list, with each record's 'flags'
        normalized to a list of unique strings. Fusion-end flags are tagged
        with '(cterm)'/'(nterm)' so both ends can share one record.
    """
    lookup = dict(zip(transcript_flags_df['transcript'], transcript_flags_df['flags']))

    def _lookup_flags(transcript):
        # Resolve a transcript to its flag list; [] when absent or empty.
        # Guard against NaN: pandas stores a missing flags cell as
        # float('nan'), which is truthy and would otherwise pass the
        # emptiness check and become a literal 'nan' flag.
        flags_val = lookup.get(transcript)
        if flags_val is None or (isinstance(flags_val, float) and flags_val != flags_val):
            return []
        if not flags_val:
            return []
        # str() cast tolerates non-string cells (e.g. numeric flags).
        return ensure_str_list(str(flags_val))

    # Plain variants: merge transcript flags into any pre-existing flags.
    for record in variant_sources:
        new_flags = _lookup_flags(record.get('transcript'))
        if not new_flags:
            continue
        flags = ensure_str_list(record.setdefault('flags', []))
        for new_flag in new_flags:
            if new_flag not in flags:
                flags.append(new_flag)
        record['flags'] = flags

    # Fusions: check both end transcripts and add labelled flags to the
    # same record.
    label_map = {'ctermTranscript': 'cterm', 'ntermTranscript': 'nterm'}
    for record in variant_sources:
        flags = ensure_str_list(record.setdefault('flags', []))
        for key, label in label_map.items():
            for flag in _lookup_flags(record.get(key)):
                new_flag = f'{flag} ({label})'
                if new_flag not in flags:
                    flags.append(new_flag)
        # Always write back so a comma-separated string value is normalized
        # to a list even when no transcript matched, keeping the record
        # consistent with the flags: Optional[List[str]] contract.
        record['flags'] = flags
    return variant_sources


def get_variant_flags(variant_sources):
    """Extract per-variant flag records and strip 'flags' from the inputs.

    Args:
        variant_sources: list of variant dicts; each needs 'key' and
            'variantType', and may carry a 'flags' string or list.

    Returns:
        A list of ``{'variant', 'variantType', 'flags'}`` dicts, one per
        input record that had non-empty flags; flags are de-duplicated
        preserving first-seen order. Records with falsy flags (None, '',
        []) are skipped. The 'flags' key is removed from every processed
        input record as a side effect.
    """
    flags = []
    for item in variant_sources:
        raw_flags = item.get('flags')
        if not raw_flags:  # skips None, '' and []
            continue
        # dict.fromkeys de-duplicates while keeping insertion order, unlike
        # list(set(...)) which yields a nondeterministic ordering.
        deduped = list(dict.fromkeys(f for f in ensure_str_list(raw_flags) if f))
        flags.append(
            {
                'variant': item['key'],
                'variantType': item['variantType'],
                'flags': deduped,
            }
        )
        item.pop('flags', None)  # remove after extraction
    return flags
+ type=file_path, + help='TSV without header, with columns: gene, transcript, comma-separated list of flags', + ) args = parser.parse_args() with open(args.content, 'r') as fh: @@ -181,6 +190,7 @@ def command_interface() -> None: upload_json=args.upload_json, validate_json=args.validate_json, ignore_extra_fields=args.ignore_extra_fields, + transcript_flags=args.transcript_flags, ) @@ -234,7 +244,7 @@ def clean_unsupported_content(upload_content: Dict, ipr_spec: Dict = {}) -> Dict for key, count in removed_keys.items(): logger.warning(f"IPR unsupported property '{key}' removed from {count} genes.") - drop_columns = ['variant', 'variantType', 'histogramImage'] + drop_columns = ['variant', 'variantType', 'histogramImage', 'flags'] # DEVSU-2034 - use a 'displayName' VARIANT_LIST_KEYS = [ 'expressionVariants', @@ -281,7 +291,6 @@ def clean_unsupported_content(upload_content: Dict, ipr_spec: Dict = {}) -> Dict # Removing cosmicSignatures. Temporary upload_content.pop('cosmicSignatures', None) - return upload_content @@ -318,6 +327,7 @@ def ipr_report( validate_json: bool = False, ignore_extra_fields: bool = False, tmb_high: float = TMB_SIGNATURE_HIGH_THRESHOLD, + transcript_flags: str = '', ) -> Dict: """Run the matching and create the report JSON for upload to IPR. 
@@ -386,6 +396,12 @@ def ipr_report( logger.error('Failed schema check - report variants may be corrupted or unmatched.') logger.error(f'Failed schema check: {err}') + transcript_flags_df = None + if transcript_flags: + transcript_flags_df = pd.read_csv( + transcript_flags, sep='\t', names=['gene', 'transcript', 'flags'] + ) + # INPUT VARIANTS VALIDATION & PREPROCESSING (OBSERVED BIOMARKERS) signature_variants: List[IprSignatureVariant] = preprocess_signature_variants( [ @@ -410,6 +426,7 @@ def ipr_report( expression_variants: List[IprExprVariant] = preprocess_expression_variants( content.get('expressionVariants', []) ) + # Additional checks if expression_variants: check_comparators(content, expression_variants) @@ -459,6 +476,10 @@ def ipr_report( *structural_variants, ] # type: ignore + # ANNOTATING VARIANTS WITH TRANSCRIPT FLAGS + if transcript_flags_df is not None and not transcript_flags_df.empty: + all_variants = add_transcript_flags(all_variants, transcript_flags_df) + # GKB_MATCHES FILTERING if match_germline: # verify germline kb statements matched germline observed variants, not somatic variants @@ -527,10 +548,24 @@ def ipr_report( gkb_matches, all_variants, kb_matched_sections['kbMatches'] ) + variant_sources = [ + v + for source in [ + [v for v in small_mutations if v['gene'] in genes_with_variants], + [v for v in copy_variants if v['gene'] in genes_with_variants], + [v for v in expression_variants if v['gene'] in genes_with_variants], + signature_variants, + filter_structural_variants(structural_variants, gkb_matches, gene_information), + ] + for v in source + ] + observed_vars_section = get_variant_flags(variant_sources) + # OUTPUT CONTENT # thread safe deep-copy the original content output = json.loads(json.dumps(content)) output.update(kb_matched_sections) + output.update( { 'copyVariants': [ @@ -550,15 +585,20 @@ def ipr_report( for s in filter_structural_variants( structural_variants, gkb_matches, gene_information ) - ], + ], # TODO NB are we 
omitting non-matched sv's? 'signatureVariants': [trim_empty_values(s) for s in signature_variants], 'genes': gene_information, 'genomicAlterationsIdentified': key_alterations, 'variantCounts': variant_counts, 'analystComments': comments, 'therapeuticTarget': targets, + 'observedVariantAnnotations': observed_vars_section, } ) + + # TODO there are 13 outliers in the test data; if even only three are matched, why are only those three + # shown in the expression section? shouldn't we be seeing the non-kbmatched vars there as well? + output.setdefault('images', []).extend(select_expression_plots(gkb_matches, all_variants)) # if input includes hrdScore field, that is ok to pass to db @@ -577,6 +617,7 @@ def ipr_report( if not ipr_conn: raise ValueError('ipr_url required to upload report') ipr_spec = ipr_conn.get_spec() + output = clean_unsupported_content(output, ipr_spec) try: logger.info(f'Uploading to IPR {ipr_conn.url}') diff --git a/pori_python/types.py b/pori_python/types.py index dd1ab7e..3840cfc 100644 --- a/pori_python/types.py +++ b/pori_python/types.py @@ -134,11 +134,12 @@ def __hash__(self): class IprVariantBase(TypedDict): - """Required properties of all variants for IPR.""" + """Required or possible properties of all variants for IPR.""" key: str variantType: str variant: str + flags: Optional[List[str]] class IprGeneVariant(IprVariantBase): diff --git a/tests/test_ipr/test_ipr.py b/tests/test_ipr/test_ipr.py index 3e9b01a..68ea2cc 100644 --- a/tests/test_ipr/test_ipr.py +++ b/tests/test_ipr/test_ipr.py @@ -1,4 +1,5 @@ import pytest +import pandas as pd from unittest.mock import Mock, patch from pori_python.graphkb import statement as gkb_statement @@ -12,7 +13,11 @@ get_kb_variants, get_kb_matches_sections, create_key_alterations, + ensure_str_list, + add_transcript_flags, + get_variant_flags, ) + from pori_python.types import Statement DISEASE_RIDS = ['#138:12', '#138:13'] @@ -415,6 +420,133 @@ def test_approved_therapeutic(self, 
mock_get_evidencelevel_mapping, graphkb_conn assert row['category'] == 'therapeutic' +class TestFlagUtilities: + def test_ensure_str_list_accepts_string(self): + assert ensure_str_list('abc') == ['abc'] + + def test_ensure_str_list_splits_comma_separated_string(self): + assert ensure_str_list('a, b , c') == ['a', 'b', 'c'] + + def test_ensure_str_list_accepts_list_of_strings(self): + assert ensure_str_list(['a', 'b']) == ['a', 'b'] + + def test_ensure_str_list_rejects_bad_types(self): + with pytest.raises(TypeError): + ensure_str_list([1, 'a']) + with pytest.raises(TypeError): + ensure_str_list(123) + + def test_add_transcript_flags_basic_adds_flags_from_comma_separated_string(self): + variant_sources = [ + {'transcript': 'T1', 'key': 'k1', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T1'], 'flags': ['flag_a,flag_b']}) + result = add_transcript_flags(variant_sources, df) + assert set(result[0]['flags']) == {'flag_a', 'flag_b'} + + def test_add_transcript_flags_basic_converts_string_flag_to_list_avoiding_duplicates(self): + variant_sources = [ + {'transcript': 'T2', 'flags': 'existing', 'key': 'k2', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T2'], 'flags': ['existing']}) + result = add_transcript_flags(variant_sources, df) + assert result[0]['flags'] == ['existing'] + + def test_add_transcript_flags_basic_leaves_unmatched_transcripts_unaffected(self): + variant_sources = [ + {'transcript': 'T3', 'flags': ['present'], 'key': 'k3', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T1', 'T2'], 'flags': ['flag_a,flag_b', 'existing']}) + result = add_transcript_flags(variant_sources, df) + assert result[0]['flags'] == ['present'] + + def test_add_transcript_flags_basic_strips_whitespace_from_comma_separated_flags(self): + variant_sources = [ + {'transcript': 'T4', 'key': 'k4', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T4'], 'flags': ['flag_c, flag_d']}) + result = 
add_transcript_flags(variant_sources, df) + assert set(result[0]['flags']) == {'flag_c', 'flag_d'} + + def test_add_transcript_flags_fusions_tags_cterm_flags(self): + variant_sources = [ + { + 'key': 'f1', + 'variantType': 'fusion', + 'ctermTranscript': 'CT1', + 'ntermTranscript': 'NT1', + } + ] + df = pd.DataFrame( + { + 'transcript': ['CT1'], + 'flags': ['cterm_flag'], + } + ) + result = add_transcript_flags(variant_sources, df) + flags = result[0]['flags'] + assert 'cterm_flag (cterm)' in flags + + def test_add_transcript_flags_fusions_tags_nterm_flags(self): + variant_sources = [ + { + 'key': 'f1', + 'variantType': 'fusion', + 'ctermTranscript': 'CT1', + 'ntermTranscript': 'NT1', + } + ] + df = pd.DataFrame( + { + 'transcript': ['NT1'], + 'flags': ['nterm_flag'], + } + ) + result = add_transcript_flags(variant_sources, df) + flags = result[0]['flags'] + assert 'nterm_flag (nterm)' in flags + + def test_get_variant_flags_converts_string_flags_to_records(self): + variants = [ + {'key': 'k1', 'variantType': 'mut', 'flags': 'foo'}, + ] + out = get_variant_flags(variants) + assert any(item['variant'] == 'k1' and item['flags'] == ['foo'] for item in out) + assert len(out) == 1 + + def test_get_variant_flags_deduplicates_and_removes_empty_strings(self): + variants = [ + {'key': 'k2', 'variantType': 'mut', 'flags': ['bar', 'bar', '']}, + ] + out = get_variant_flags(variants) + assert any(item['variant'] == 'k2' and set(item['flags']) == {'bar'} for item in out) + + def test_get_variant_flags_skips_null_flags(self): + variants = [ + {'key': 'k3', 'variantType': 'mut', 'flags': None}, + ] + out = get_variant_flags(variants) + assert not any(item['variant'] == 'k3' for item in out) + assert len(out) == 0 + + def test_get_variant_flags_skips_empty_list_flags(self): + variants = [ + {'key': 'k4', 'variantType': 'mut', 'flags': []}, + ] + out = get_variant_flags(variants) + assert not any(item['variant'] == 'k4' for item in out) + assert len(out) == 0 + + def 
def get_test_transcript_flags(json_contents) -> pd.DataFrame:
    """Build a transcript-flags DataFrame from the variants in the test JSON.

    Both fusion-end transcripts of every structural variant and the
    transcript of every small mutation are tagged with the same
    'TRANSCRIPT FLAG' marker so uploads can be checked for flag propagation.
    """
    rows = []
    for sv in json_contents['structuralVariants']:
        rows.append((sv['gene1'], sv['ntermTranscript'], 'TRANSCRIPT FLAG'))
        rows.append((sv['gene2'], sv['ctermTranscript'], 'TRANSCRIPT FLAG'))
    for mut in json_contents['smallMutations']:
        rows.append((mut['gene'], mut['transcript'], 'TRANSCRIPT FLAG'))
    frame = pd.DataFrame(rows, columns=['gene', 'transcript', 'flags'])
    return frame.drop_duplicates()


def add_test_variant_flags_to_input_data(json_contents) -> dict:
    """Attach a fixed ['TEST FLAG'] list to every variant record, in place."""
    variant_sections = (
        'structuralVariants',
        'smallMutations',
        'copyVariants',
        'expressionVariants',
    )
    for vtype in variant_sections:
        for variant in json_contents[vtype]:
            variant['flags'] = ['TEST FLAG']
    return json_contents