diff --git a/statina/API/v2/endpoints/batches.py b/statina/API/v2/endpoints/batches.py index 3f573b9..c9352d0 100644 --- a/statina/API/v2/endpoints/batches.py +++ b/statina/API/v2/endpoints/batches.py @@ -167,8 +167,16 @@ def ratio_plot( tris_thresholds=get_trisomy_metadata(dataset=dataset), chromosomes=[ncv], ncv_chrom_data={ncv: get_tris_samples(adapter=adapter, chr=ncv, batch_id=batch_id)}, - normal_data={ncv: get_tris_control_normal(adapter, ncv)}, - abnormal_data={ncv: get_tris_control_abnormal(adapter, ncv, 0)}, + normal_data={ + ncv: get_tris_control_normal( + adapter=adapter, chr=ncv, dataset_name=dataset.name + ) + }, + abnormal_data={ + ncv: get_tris_control_abnormal( + adapter=adapter, chr=ncv, dataset_name=dataset.name, x_axis=0 + ) + }, ), by_alias=False, ), diff --git a/statina/API/v2/endpoints/sample.py b/statina/API/v2/endpoints/sample.py index 45cfb6a..8289202 100644 --- a/statina/API/v2/endpoints/sample.py +++ b/statina/API/v2/endpoints/sample.py @@ -100,10 +100,14 @@ def sample_tris( """Sample view with trisomi plot.""" database_sample: DataBaseSample = find_samples.sample(sample_id=sample_id, adapter=adapter) + dataset = get_dataset(adapter=adapter, batch_id=database_sample.batch_id) abnormal_data: Dict[str, RatioSamples] = ratio_plot_data.get_abn_for_samp_tris_plot( - adapter=adapter + adapter=adapter, dataset_name=dataset.name + ) + + normal_data: Ratio131821 = ratio_plot_data.get_normal_for_samp_tris_plot( + adapter=adapter, dataset_name=dataset.name ) - normal_data: Ratio131821 = ratio_plot_data.get_normal_for_samp_tris_plot(adapter=adapter) sample_data: RatioSamples = ratio_plot_data.get_sample_for_samp_tris_plot(database_sample) return JSONResponse( diff --git a/statina/crud/find/batches.py b/statina/crud/find/batches.py index 9f5e191..63a2c50 100644 --- a/statina/crud/find/batches.py +++ b/statina/crud/find/batches.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Literal +from typing import Iterable, List, Optional, Literal, Any from pydantic import parse_obj_as @@ -6,10 +6,11 @@ from statina.constants import sort_table from statina.crud.utils import paginate from statina.models.database import DatabaseBatch +from typing import List def get_batches_text_query(query_string: str) -> dict: - """Text search with regex, case insensitive""" + """Text search with regex, case-insensitive""" return { "$or": [ {"batch_id": {"$regex": query_string, "$options": "i"}}, @@ -66,3 +67,11 @@ def count_query_batches(adapter: StatinaAdapter, query_string: Optional[str] = " return adapter.batch_collection.count_documents( filter=get_batches_text_query(query_string=query_string) ) + + +def get_batch_ids_by_dataset(adapter: StatinaAdapter, dataset_name: str) -> List[Any]: + batch_ids = [ + doc["batch_id"] + for doc in adapter.batch_collection.find({"dataset": dataset_name}, {"batch_id": 1}) + ] + return batch_ids diff --git a/statina/crud/find/plots/ratio_plot_data.py b/statina/crud/find/plots/ratio_plot_data.py index 4abc389..f111d9d 100644 --- a/statina/crud/find/plots/ratio_plot_data.py +++ b/statina/crud/find/plots/ratio_plot_data.py @@ -3,20 +3,26 @@ import statina from statina.adapter import StatinaAdapter +from statina.crud.find.batches import get_batch_ids_by_dataset from statina.models.database import DataBaseSample from statina.models.server.plots.ncv import Ratio131821, RatioSamples -def get_tris_control_abnormal(adapter: StatinaAdapter, chr, x_axis) -> Dict[str, RatioSamples]: +def get_tris_control_abnormal( + adapter: StatinaAdapter, chr, dataset_name: str, x_axis +) -> Dict[str, RatioSamples]: """Abnormal Control Samples for trisomi plots""" plot_data = {} + batch_ids = get_batch_ids_by_dataset(adapter=adapter, dataset_name=dataset_name) + pipe = [ { "$match": { - f"status_{chr}": {"$ne": "Normal", "$exists": "True"}, + f"status_{chr}": {"$ne": "Normal", "$exists": True}, "include": {"$eq": True}, + "batch_id": {"$in": batch_ids}, } }, { @@ -41,14 +47,16 @@ def get_tris_control_abnormal(adapter: StatinaAdapter, chr, x_axis) -> Dict[str, return plot_data -def get_abn_for_samp_tris_plot(adapter: StatinaAdapter) -> Dict[str, RatioSamples]: +def get_abn_for_samp_tris_plot( + adapter: StatinaAdapter, dataset_name: str +) -> Dict[str, RatioSamples]: """Format abnormal Control Samples for Sample trisomi plot""" plot_data = {} for x_axis, abn in enumerate(["13", "18", "21"], start=1): tris_control_abnormal: Dict[str, RatioSamples] = get_tris_control_abnormal( - adapter, abn, x_axis + adapter=adapter, chr=abn, dataset_name=dataset_name, x_axis=x_axis ) for status, data in tris_control_abnormal.items(): if status not in plot_data: @@ -64,12 +72,20 @@ def get_abn_for_samp_tris_plot(adapter: StatinaAdapter) -> Dict[str, RatioSample def get_tris_control_normal( - adapter: StatinaAdapter, chr: str, x_axis: Optional[int] = None + adapter: StatinaAdapter, chr: str, dataset_name: str, x_axis: Optional[int] = None ) -> RatioSamples: """Normal Control Samples for trisomi plots""" + batch_ids = get_batch_ids_by_dataset(adapter=adapter, dataset_name=dataset_name) + pipe = [ - {"$match": {f"status_{chr}": {"$eq": "Normal"}, "include": {"$eq": True}}}, + { + "$match": { + f"status_{chr}": {"$eq": "Normal"}, + "include": {"$eq": True}, + "batch_id": {"$in": batch_ids}, + } + }, { "$group": { "_id": {f"status_{chr}": f"$status_{chr}"}, @@ -88,26 +104,22 @@ def get_tris_control_normal( return RatioSamples(**data) -def get_normal_for_samp_tris_plot(adapter: StatinaAdapter) -> Ratio131821: +def get_normal_for_samp_tris_plot(adapter: StatinaAdapter, dataset_name: str) -> Ratio131821: """Format normal Control Samples for Sample trisomi plot""" return Ratio131821( - chr_13=get_tris_control_normal(adapter=adapter, chr="13", x_axis=1), - chr_18=get_tris_control_normal(adapter=adapter, chr="18", x_axis=2), - chr_21=get_tris_control_normal(adapter=adapter, chr="21", x_axis=3), + chr_13=get_tris_control_normal( + adapter=adapter, chr="13", dataset_name=dataset_name, x_axis=1 + ), + chr_18=get_tris_control_normal( + adapter=adapter, chr="18", dataset_name=dataset_name, x_axis=2 + ), + chr_21=get_tris_control_normal( + adapter=adapter, chr="21", dataset_name=dataset_name, x_axis=3 + ), ) -def get_abnormal_for_samp_tris_plot(adapter: StatinaAdapter) -> dict: - """Format normal Control Samples for Sample trisomi plot""" - - return { - "13": get_tris_control_abnormal(adapter=adapter, chr="13", x_axis=0), - "18": get_tris_control_abnormal(adapter=adapter, chr="18", x_axis=0), - "21": get_tris_control_abnormal(adapter=adapter, chr="21", x_axis=0), - } - - def get_samples_for_samp_tris_plot(adapter: StatinaAdapter, batch_id: str) -> Ratio131821: return Ratio131821( chr_13=get_tris_samples(adapter=adapter, chr="13", batch_id=batch_id), diff --git a/tests/crud/find/test_batches.py b/tests/crud/find/test_batches.py new file mode 100644 index 0000000..874cd01 --- /dev/null +++ b/tests/crud/find/test_batches.py @@ -0,0 +1,21 @@ +from mongomock import MongoClient +from statina.adapter.plugin import StatinaAdapter +from statina.crud.find.batches import get_batch_ids_by_dataset + + +def test_get_batch_ids_by_dataset(database): + # GIVEN a database with two batches belonging to "dataset_1" and one to "dataset_2" + adapter = StatinaAdapter(database.client, db_name="testdb") + adapter.batch_collection.insert_many( + [ + {"batch_id": "batch_1", "dataset": "dataset_1"}, + {"batch_id": "batch_2", "dataset": "dataset_1"}, + {"batch_id": "batch_3", "dataset": "dataset_2"}, + ] + ) + + # WHEN fetching batch_ids for "dataset_1" + batch_ids = get_batch_ids_by_dataset(adapter=adapter, dataset_name="dataset_1") + + # THEN only the two batch_ids belonging to "dataset_1" should be returned + assert batch_ids == ["batch_1", "batch_2"]