Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions statina/API/v2/endpoints/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,16 @@ def ratio_plot(
tris_thresholds=get_trisomy_metadata(dataset=dataset),
chromosomes=[ncv],
ncv_chrom_data={ncv: get_tris_samples(adapter=adapter, chr=ncv, batch_id=batch_id)},
normal_data={ncv: get_tris_control_normal(adapter, ncv)},
abnormal_data={ncv: get_tris_control_abnormal(adapter, ncv, 0)},
normal_data={
ncv: get_tris_control_normal(
adapter=adapter, chr=ncv, dataset_name=dataset.name
)
},
abnormal_data={
ncv: get_tris_control_abnormal(
adapter=adapter, chr=ncv, dataset_name=dataset.name, x_axis=0
)
},
),
by_alias=False,
),
Expand Down
8 changes: 6 additions & 2 deletions statina/API/v2/endpoints/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,14 @@ def sample_tris(
"""Sample view with trisomi plot."""

database_sample: DataBaseSample = find_samples.sample(sample_id=sample_id, adapter=adapter)
dataset = get_dataset(adapter=adapter, batch_id=database_sample.batch_id)
abnormal_data: Dict[str, RatioSamples] = ratio_plot_data.get_abn_for_samp_tris_plot(
adapter=adapter
adapter=adapter, dataset_name=dataset.name
)

normal_data: Ratio131821 = ratio_plot_data.get_normal_for_samp_tris_plot(
adapter=adapter, dataset_name=dataset.name
)
normal_data: Ratio131821 = ratio_plot_data.get_normal_for_samp_tris_plot(adapter=adapter)
sample_data: RatioSamples = ratio_plot_data.get_sample_for_samp_tris_plot(database_sample)

return JSONResponse(
Expand Down
13 changes: 11 additions & 2 deletions statina/crud/find/batches.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Iterable, List, Optional, Literal
from typing import Iterable, List, Optional, Literal, Any

from pydantic import parse_obj_as

from statina.adapter import StatinaAdapter
from statina.constants import sort_table
from statina.crud.utils import paginate
from statina.models.database import DatabaseBatch
from typing import List


def get_batches_text_query(query_string: str) -> dict:
"""Text search with regex, case insensitive"""
"""Text search with regex, case-insensitive"""
return {
"$or": [
{"batch_id": {"$regex": query_string, "$options": "i"}},
Expand Down Expand Up @@ -66,3 +67,11 @@ def count_query_batches(adapter: StatinaAdapter, query_string: Optional[str] = "
return adapter.batch_collection.count_documents(
filter=get_batches_text_query(query_string=query_string)
)


def get_batch_ids_by_dataset(adapter: StatinaAdapter, dataset_name: str) -> List[Any]:
batch_ids = [
doc["batch_id"]
for doc in adapter.batch_collection.find({"dataset": dataset_name}, {"batch_id": 1})
]
return batch_ids
52 changes: 32 additions & 20 deletions statina/crud/find/plots/ratio_plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,26 @@

import statina
from statina.adapter import StatinaAdapter
from statina.crud.find.batches import get_batch_ids_by_dataset
from statina.models.database import DataBaseSample
from statina.models.server.plots.ncv import Ratio131821, RatioSamples


def get_tris_control_abnormal(adapter: StatinaAdapter, chr, x_axis) -> Dict[str, RatioSamples]:
def get_tris_control_abnormal(
adapter: StatinaAdapter, chr, dataset_name: str, x_axis
) -> Dict[str, RatioSamples]:
"""Abnormal Control Samples for trisomi plots"""

plot_data = {}

batch_ids = get_batch_ids_by_dataset(adapter=adapter, dataset_name=dataset_name)

pipe = [
{
"$match": {
f"status_{chr}": {"$ne": "Normal", "$exists": "True"},
f"status_{chr}": {"$ne": "Normal", "$exists": True},
"include": {"$eq": True},
"batch_id": {"$in": batch_ids},
}
},
{
Expand All @@ -41,14 +47,16 @@ def get_tris_control_abnormal(adapter: StatinaAdapter, chr, x_axis) -> Dict[str,
return plot_data


def get_abn_for_samp_tris_plot(adapter: StatinaAdapter) -> Dict[str, RatioSamples]:
def get_abn_for_samp_tris_plot(
adapter: StatinaAdapter, dataset_name: str
) -> Dict[str, RatioSamples]:
"""Format abnormal Control Samples for Sample trisomi plot"""

plot_data = {}

for x_axis, abn in enumerate(["13", "18", "21"], start=1):
tris_control_abnormal: Dict[str, RatioSamples] = get_tris_control_abnormal(
adapter, abn, x_axis
adapter=adapter, chr=abn, dataset_name=dataset_name, x_axis=x_axis
)
for status, data in tris_control_abnormal.items():
if status not in plot_data:
Expand All @@ -64,12 +72,20 @@ def get_abn_for_samp_tris_plot(adapter: StatinaAdapter) -> Dict[str, RatioSample


def get_tris_control_normal(
adapter: StatinaAdapter, chr: str, x_axis: Optional[int] = None
adapter: StatinaAdapter, chr: str, dataset_name: str, x_axis: Optional[int] = None
) -> RatioSamples:
"""Normal Control Samples for trisomi plots"""

batch_ids = get_batch_ids_by_dataset(adapter=adapter, dataset_name=dataset_name)

pipe = [
{"$match": {f"status_{chr}": {"$eq": "Normal"}, "include": {"$eq": True}}},
{
"$match": {
f"status_{chr}": {"$eq": "Normal"},
"include": {"$eq": True},
"batch_id": {"$in": batch_ids},
}
},
{
"$group": {
"_id": {f"status_{chr}": f"$status_{chr}"},
Expand All @@ -88,26 +104,22 @@ def get_tris_control_normal(
return RatioSamples(**data)


def get_normal_for_samp_tris_plot(adapter: StatinaAdapter) -> Ratio131821:
def get_normal_for_samp_tris_plot(adapter: StatinaAdapter, dataset_name: str) -> Ratio131821:
"""Format normal Control Samples for Sample trisomi plot"""

return Ratio131821(
chr_13=get_tris_control_normal(adapter=adapter, chr="13", x_axis=1),
chr_18=get_tris_control_normal(adapter=adapter, chr="18", x_axis=2),
chr_21=get_tris_control_normal(adapter=adapter, chr="21", x_axis=3),
chr_13=get_tris_control_normal(
adapter=adapter, chr="13", dataset_name=dataset_name, x_axis=1
),
chr_18=get_tris_control_normal(
adapter=adapter, chr="18", dataset_name=dataset_name, x_axis=2
),
chr_21=get_tris_control_normal(
adapter=adapter, chr="21", dataset_name=dataset_name, x_axis=3
),
)


def get_abnormal_for_samp_tris_plot(adapter: StatinaAdapter) -> dict:
"""Format normal Control Samples for Sample trisomi plot"""

return {
"13": get_tris_control_abnormal(adapter=adapter, chr="13", x_axis=0),
"18": get_tris_control_abnormal(adapter=adapter, chr="18", x_axis=0),
"21": get_tris_control_abnormal(adapter=adapter, chr="21", x_axis=0),
}


def get_samples_for_samp_tris_plot(adapter: StatinaAdapter, batch_id: str) -> Ratio131821:
return Ratio131821(
chr_13=get_tris_samples(adapter=adapter, chr="13", batch_id=batch_id),
Expand Down
21 changes: 21 additions & 0 deletions tests/crud/find/test_batches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from mongomock import MongoClient
from statina.adapter.plugin import StatinaAdapter
from statina.crud.find.batches import get_batch_ids_by_dataset


def test_get_batch_ids_by_dataset(database):
# GIVEN a database with two batches belonging to "dataset_1" and one to "dataset_2"
adapter = StatinaAdapter(database.client, db_name="testdb")
adapter.batch_collection.insert_many(
[
{"batch_id": "batch_1", "dataset": "dataset_1"},
{"batch_id": "batch_2", "dataset": "dataset_1"},
{"batch_id": "batch_3", "dataset": "dataset_2"},
]
)

# WHEN fetching batch_ids for "dataset_1"
batch_ids = get_batch_ids_by_dataset(adapter=adapter, dataset_name="dataset_1")

# THEN only the two batch_ids belonging to "dataset_1" should be returned
assert batch_ids == ["batch_1", "batch_2"]
Loading