"""REST API entity repository."""
from __future__ import annotations
import typing as t
from aiida import orm
from aiida.common.exceptions import NotExistent
from aiida.common.pydantic import get_metadata
from aiida_restapi.common.exceptions import QueryBuilderException
from aiida_restapi.common.pagination import PaginatedResults
from aiida_restapi.common.query import QueryBuilderParams
from aiida_restapi.common.types import EntityModelType, EntityType
class EntityService(t.Generic[EntityType, EntityModelType]):
    """Service for managing AiiDA entities.

    This class provides methods to retrieve AiiDA entities (with optional filtering, sorting, and
    pagination), to create them from a Pydantic model, and to update existing ones.

    :param entity_class: The AiiDA ORM entity class associated with this service, e.g. `orm.User`, `orm.Node`, etc.
    """

    def __init__(self, entity_class: type[EntityType]) -> None:
        self.entity_class = entity_class
        # Lower-cased class name of the entity, e.g. ``'user'`` for ``orm.User``.
        self.with_key = entity_class.__name__.lower()

    @property
    def project(self) -> list[str]:
        """Get the list of projections to use when querying the AiiDA entity.

        The list is computed lazily by :meth:`_get_projections` on first access and then cached on
        the instance.

        :return: The list of projections to use when querying the AiiDA entity.
        :rtype: list[str]
        """
        if not hasattr(self, '_project'):
            self._project = self._get_projections()
        return self._project

    def get_schema(self, which: t.Literal['read', 'write'] | None = None) -> dict[str, t.Any]:
        """Get JSON schema for the AiiDA entity.

        :param which: The type of schema to retrieve: 'read' or 'write'. If omitted, both are returned.
        :type which: str | None
        :return: The JSON schema of the requested model or, if ``which`` is omitted, a dictionary
            with 'read' and 'write' keys containing the respective JSON schemas.
        :rtype: dict[str, t.Any]
        :raises ValueError: If the 'which' parameter is not 'read' or 'write'.
        """
        if not which:
            return {
                'read': self.entity_class.ReadModel.model_json_schema(),
                'write': self.entity_class.WriteModel.model_json_schema(),
            }
        if which == 'read':
            return self.entity_class.ReadModel.model_json_schema()
        if which == 'write':
            return self.entity_class.WriteModel.model_json_schema()
        # Reject unknown values explicitly instead of silently handing back the read schema.
        raise ValueError(f"Invalid schema type '{which}'; expected 'read' or 'write'.")

    def get_projections(self) -> list[str]:
        """Get queryable projections for the AiiDA entity.

        :return: The names of the queryable fields of the AiiDA entity.
        :rtype: list[str]
        """
        # Materialize into a list so callers get the annotated type rather than a keys view.
        return list(self.entity_class.fields.keys())

    def get_one(self, identifier: str | int) -> dict[str, t.Any]:
        """Get an AiiDA entity by id.

        :param identifier: The id of the entity to retrieve.
        :type identifier: str | int
        :return: The serialized AiiDA entity.
        :rtype: dict[str, t.Any]
        :raises NotExistent: If no entity with the given identifier exists.
        """
        entity = self._get_entity(identifier)
        return entity.serialize(minimal=True)

    def get_many(self, query_params: QueryBuilderParams) -> PaginatedResults[dict[str, t.Any]]:
        """Get AiiDA entities with optional filtering, sorting, and/or pagination.

        :param query_params: The query parameters for filtering, sorting, and pagination.
        :type query_params: QueryBuilderParams
        :return: The paginated results, including total count, current page, requested page size,
            and list of serialized entities.
        :rtype: PaginatedResults
        :raises QueryBuilderException: If counting or querying the entities fails.
        """
        page_size = query_params.page_size
        try:
            total = self.entity_class.collection.count(filters=query_params.filters)
            results = self.entity_class.collection.query(
                filters=query_params.filters,
                order_by=query_params.order_by,
                limit=page_size,
                # ``page`` is 1-based.
                offset=page_size * (query_params.page - 1),
                project=self.project,
            ).dict()
        except Exception as exception:
            raise QueryBuilderException(str(exception)) from exception
        return PaginatedResults(
            total=total,
            page=query_params.page,
            # Report the requested page size, not the number of rows on this page: a short last
            # page (or a page past the end, which returns no rows) would otherwise report a
            # smaller or zero page size, making the page count unrecoverable from ``total``.
            page_size=page_size,
            # Each row is keyed by query tag; the query has a single tag, so take its projections.
            data=[next(iter(result.values())) for result in results],
        )

    def get_field(self, identifier: str | int, field: str) -> t.Any:
        """Get a specific field of an entity.

        :param identifier: The id of the entity to retrieve the field for.
        :type identifier: str | int
        :param field: The specific field to retrieve.
        :type field: str
        :return: The value of the specified field.
        :rtype: t.Any
        :raises QueryBuilderException: If building or executing the query fails.
        :raises NotExistent: If no entity with the given identifier exists.
        """
        try:
            # Build the query inside the ``try`` so construction errors are wrapped just like
            # execution errors, consistent with :meth:`get_many`.
            qb = self.entity_class.collection.query(
                filters={self.entity_class.identity_field: identifier},
                project=[field],
            )
            result = qb.first()
        except Exception as exception:
            raise QueryBuilderException(str(exception)) from exception
        # An empty/``None`` result means no entity matched; otherwise the row holds one value per
        # projection, and only one field was projected.
        if not result:
            raise NotExistent(f'{self.entity_class.__name__}<{identifier}> does not exist.')
        return result[0]

    def add(self, model: EntityModelType) -> dict[str, t.Any]:
        """Create new AiiDA entity from its model.

        :param model: The Pydantic model of the entity to create.
        :type model: EntityModelType
        :return: The serialized representation of the created and stored AiiDA entity.
        :rtype: dict[str, t.Any]
        """
        entity = self.entity_class.from_model(model).store()
        return entity.serialize(minimal=True)

    def update(self, identifier: str | int, model: EntityModelType) -> dict[str, t.Any]:
        """Update an existing AiiDA entity.

        :param identifier: The id of the entity to update.
        :type identifier: str | int
        :param model: The Pydantic model of the entity to update.
        :type model: EntityModelType
        :return: The serialized representation of the updated AiiDA entity.
        :rtype: dict[str, t.Any]
        :raises NotExistent: If no entity with the given identifier exists.
        :raises NotImplementedError: If :meth:`_apply_update` is not implemented for this service.
        """
        entity = self._get_entity(identifier)
        self._apply_update(entity, model)
        return entity.serialize(minimal=True)

    def _get_entity(self, identifier: str | int) -> EntityType:
        """Load an AiiDA entity by its identity field.

        :param identifier: The id of the entity to load.
        :type identifier: str | int
        :return: The loaded AiiDA entity.
        :raises NotExistent: If no entity with the given identifier exists, with a message naming
            the entity class and identifier.
        """
        try:
            return self.entity_class.collection.get(**{self.entity_class.identity_field: identifier})
        except NotExistent as exception:
            raise NotExistent(f'{self.entity_class.__name__}<{identifier}> does not exist.') from exception

    def _get_projections(self, orm_class: type[orm.Entity] | None = None) -> list[str]:
        """Get the list of projections to use when querying the AiiDA entity.

        Fields of the read model whose metadata flags them as ``may_be_large`` are excluded.

        :param orm_class: The AiiDA ORM entity class to get the projections for. Defaults to the
            entity class of this service.
        :type orm_class: type[orm.Entity] | None
        :return: The list of projections to use when querying the AiiDA entity.
        :rtype: list[str]
        """
        orm_class = orm_class or self.entity_class
        return [
            key
            for key, field in orm_class.ReadModel.model_fields.items()
            if not get_metadata(field, 'may_be_large')
        ]

    def _apply_update(self, entity: orm.Entity, model: EntityModelType) -> None:
        """Apply changes to stored entity from a model payload.

        This base implementation always raises; services that support updates must override it.

        :param entity: The AiiDA entity to update.
        :type entity: orm.Entity
        :param model: The Pydantic model of the entity to update.
        :type model: EntityModelType
        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError