Source code for aiida_restapi.services.entity

"""REST API entity repository."""

from __future__ import annotations

import typing as t

from aiida import orm
from aiida.common.exceptions import NotExistent
from aiida.common.pydantic import get_metadata

from aiida_restapi.common.exceptions import QueryBuilderException
from aiida_restapi.common.pagination import PaginatedResults
from aiida_restapi.common.query import QueryBuilderParams
from aiida_restapi.common.types import EntityModelType, EntityType


class EntityService(t.Generic[EntityType, EntityModelType]):
    """Service for managing AiiDA entities.

    This class provides methods to retrieve AiiDA entities with optional filtering, sorting, and pagination.

    :param entity_class: The AiiDA ORM entity class associated with this utility, e.g. `orm.User`, `orm.Node`, etc.
    """

    def __init__(self, entity_class: type[EntityType]) -> None:
        self.entity_class = entity_class
        # Lower-cased class name of the entity, e.g. 'user' for `orm.User`.
        self.with_key = entity_class.__name__.lower()

    @property
    def project(self) -> list[str]:
        """Get the list of projections to use when querying the AiiDA entity.

        The list is computed on first access and then cached on the instance.

        :return: The list of projections to use when querying the AiiDA entity.
        :rtype: list[str]
        """
        if not hasattr(self, '_project'):
            self._project = self._get_projections()
        return self._project

    def get_schema(self, which: t.Literal['read', 'write'] | None = None) -> dict[str, t.Any]:
        """Get JSON schema for the AiiDA entity.

        :param which: The type of schema to retrieve: 'read' or 'write'. If not given, both are returned.
        :type which: str | None
        :return: The requested JSON schema, or, when `which` is not given, a dictionary with 'read' and
            'write' keys containing the respective JSON schemas.
        :rtype: dict[str, t.Any]
        :raises ValueError: If the 'which' parameter is not 'read' or 'write'.
        """
        if not which:
            return {
                'read': self.entity_class.ReadModel.model_json_schema(),
                'write': self.entity_class.WriteModel.model_json_schema(),
            }
        if which == 'read':
            return self.entity_class.ReadModel.model_json_schema()
        if which == 'write':
            return self.entity_class.WriteModel.model_json_schema()
        # Previously any other value silently fell back to the read schema, contradicting the documented contract.
        raise ValueError(f"Invalid schema type '{which}'; expected 'read' or 'write'.")

    def get_projections(self) -> list[str]:
        """Get queryable projections for the AiiDA entity.

        :return: The list of queryable projections for the AiiDA entity.
        :rtype: list[str]
        """
        # Materialize as a list so the return value matches the annotated type regardless of what
        # `fields.keys()` itself returns.
        return list(self.entity_class.fields.keys())

    def get_one(self, identifier: str | int) -> dict[str, t.Any]:
        """Get an AiiDA entity by its identity field.

        :param identifier: The value of the entity's identity field (e.g. its id) to look up.
        :type identifier: str | int
        :return: The serialized AiiDA entity.
        :rtype: dict[str, t.Any]
        :raises NotExistent: If no entity matches the identifier.
        """
        entity = self._get_entity(identifier)
        return entity.serialize(minimal=True)

    def get_many(self, query_params: QueryBuilderParams) -> PaginatedResults[dict[str, t.Any]]:
        """Get AiiDA entities with optional filtering, sorting, and/or pagination.

        :param query_params: The query parameters for filtering, sorting, and pagination.
        :type query_params: QueryBuilderParams
        :return: The paginated results, including total count, current page, page size,
            and list of serialized entities.
        :rtype: PaginatedResults
        :raises QueryBuilderException: If counting or querying the entities fails.
        """
        try:
            total = self.entity_class.collection.count(filters=query_params.filters)
            results = self.entity_class.collection.query(
                filters=query_params.filters,
                order_by=query_params.order_by,
                limit=query_params.page_size,
                # Pages are 1-based, so page 1 starts at offset 0.
                offset=query_params.page_size * (query_params.page - 1),
                project=self.project,
            ).dict()
        except Exception as exception:
            raise QueryBuilderException(str(exception)) from exception

        return PaginatedResults(
            total=total,
            page=query_params.page,
            # NOTE(review): this reports the number of items actually returned on this page,
            # which may be smaller than the requested `query_params.page_size` on the last page.
            page_size=len(results),
            # Each result row is a dict keyed by query tag; take the projections of the single tag.
            data=[next(iter(result.values())) for result in results],
        )

    def get_field(self, identifier: str | int, field: str) -> t.Any:
        """Get a specific field of an entity.

        :param identifier: The value of the entity's identity field (e.g. its id) to look up.
        :type identifier: str | int
        :param field: The specific field to retrieve.
        :type field: str
        :return: The value of the specified field.
        :rtype: t.Any
        :raises QueryBuilderException: If building or running the query fails.
        :raises NotExistent: If no entity matches the identifier.
        """
        try:
            # Query construction is inside the `try` too, so all query errors are reported
            # uniformly as `QueryBuilderException` (consistent with `get_many`).
            result = self.entity_class.collection.query(
                filters={self.entity_class.identity_field: identifier},
                project=[field],
            ).first()
        except Exception as exception:
            raise QueryBuilderException(str(exception)) from exception

        if not result:
            raise NotExistent(f'{self.entity_class.__name__}<{identifier}> does not exist.')

        return result[0]

    def add(self, model: EntityModelType) -> dict[str, t.Any]:
        """Create new AiiDA entity from its model.

        :param model: The Pydantic model of the entity to create.
        :type model: EntityModelType
        :return: The serialized representation of the created and stored AiiDA entity.
        :rtype: dict[str, t.Any]
        """
        entity = self.entity_class.from_model(model).store()
        return entity.serialize(minimal=True)

    def update(self, identifier: str | int, model: EntityModelType) -> dict[str, t.Any]:
        """Update an existing AiiDA entity.

        :param identifier: The value of the entity's identity field (e.g. its id) to update.
        :type identifier: str | int
        :param model: The Pydantic model of the entity to update.
        :type model: EntityModelType
        :return: The serialized representation of the updated AiiDA entity.
        :rtype: dict[str, t.Any]
        :raises NotExistent: If no entity matches the identifier.
        """
        entity = self._get_entity(identifier)
        self._apply_update(entity, model)
        return entity.serialize(minimal=True)

    def _get_entity(self, identifier: str | int) -> EntityType:
        """Load an AiiDA entity by its identity field.

        :param identifier: The value of the entity's identity field (e.g. its id) to look up.
        :type identifier: str | int
        :return: The loaded AiiDA entity.
        :rtype: EntityType
        :raises NotExistent: With a descriptive message if no entity matches the identifier.
        """
        try:
            return self.entity_class.collection.get(**{self.entity_class.identity_field: identifier})
        except NotExistent as exception:
            raise NotExistent(f'{self.entity_class.__name__}<{identifier}> does not exist.') from exception

    def _get_projections(self, orm_class: type[orm.Entity] | None = None) -> list[str]:
        """Get the list of projections to use when querying the AiiDA entity.

        Exclude fields whose metadata marks them as `may_be_large`.

        :param orm_class: The AiiDA ORM entity class to get the projections for.
            Defaults to the service's own entity class.
        :type orm_class: type[orm.Entity] | None
        :return: The list of projections to use when querying the AiiDA entity.
        :rtype: list[str]
        """
        orm_class = orm_class or self.entity_class
        return [
            key
            for key, field in orm_class.ReadModel.model_fields.items()
            if not get_metadata(
                field,
                'may_be_large',
            )
        ]

    def _apply_update(self, entity: orm.Entity, model: EntityModelType) -> None:
        """Apply changes to stored entity from a model payload.

        Not implemented here; services that support updates must override this method.

        :param entity: The AiiDA entity to update.
        :type entity: orm.Entity
        :param model: The Pydantic model of the entity to update.
        :type model: EntityModelType
        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError