Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deployments/api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ requires-python = ">=3.12.12"
dependencies = [
"fastapi[standard]>=0.124.0",
"greenlet>=3.3.0",
"pydantic>=2.12.5",
"pydantic-settings>=2.12.0",
"sqlalchemy>=2.0.44",
"stitch-core",
"stitch-resources",
]

[project.scripts]
Expand Down Expand Up @@ -40,3 +42,4 @@ addopts = ["-v", "--strict-markers", "--tb=short"]

[tool.uv.sources]
stitch-core = { workspace = true }
stitch-resources = { workspace = true }
32 changes: 13 additions & 19 deletions deployments/api/src/stitch/api/db/init_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sqlalchemy.orm import Session

from stitch.api.db.model import (
CCReservoirsSourceModel,
GemSourceModel,
MembershipModel,
RMIManualSourceModel,
Expand Down Expand Up @@ -280,34 +279,33 @@ def create_dev_user() -> UserModel:
def create_seed_sources():
gem_sources = [
GemSourceModel.from_entity(
GemData(name="Permian Basin Field", country="USA", lat=31.8, lon=-102.3)
GemData(
name="Permian Basin Field",
country="USA",
latitude=31.8,
longitude=-102.3,
)
),
GemSourceModel.from_entity(
GemData(name="North Sea Platform", country="GBR", lat=57.5, lon=1.5)
GemData(
name="North Sea Platform", country="GBR", latitude=57.5, longitude=1.5
)
),
]
for i, src in enumerate(gem_sources, start=1):
src.id = i

wm_sources = [
WMSourceModel.from_entity(
WMData(
field_name="Eagle Ford Shale", field_country="USA", production=125000.5
)
),
WMSourceModel.from_entity(
WMData(field_name="Ghawar Field", field_country="SAU", production=500000.0)
),
WMSourceModel.from_entity(WMData(name="Eagle Ford Shale", country="USA")),
WMSourceModel.from_entity(WMData(name="Ghawar Field", country="SAU")),
]
for i, src in enumerate(wm_sources, start=1):
src.id = i

rmi_sources = [
RMIManualSourceModel.from_entity(
RMIManualData(
name_override="Custom Override Name",
gwp=25.5,
gor=0.45,
name="Custom Override Name",
country="CAN",
latitude=56.7,
longitude=-111.4,
Expand All @@ -317,11 +315,7 @@ def create_seed_sources():
for i, src in enumerate(rmi_sources, start=1):
src.id = i

# CC Reservoir sources are intentionally omitted from the dev seed profile;
# the CCReservoirsSourceModel table is still created from SQLAlchemy metadata.
cc_sources: list[CCReservoirsSourceModel] = []

return gem_sources, wm_sources, rmi_sources, cc_sources
return gem_sources, wm_sources, rmi_sources


def create_seed_resources(user: UserEntity) -> list[ResourceModel]:
Expand Down
2 changes: 0 additions & 2 deletions deployments/api/src/stitch/api/db/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from .sources import (
GemSourceModel,
RMIManualSourceModel,
CCReservoirsSourceModel,
WMSourceModel,
)
from .resource import MembershipStatus, MembershipModel, ResourceModel
from .user import User as UserModel

__all__ = [
"CCReservoirsSourceModel",
"GemSourceModel",
"MembershipModel",
"MembershipStatus",
Expand Down
7 changes: 0 additions & 7 deletions deployments/api/src/stitch/api/db/model/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,6 @@ def copy(self):
# def rmi(self) -> list[RMIManualSourceModel]:
# return self._owner._rmi_sources
#
# @property
# def cc(self) -> list[CCReservoirsSourceModel]:
# return self._owner._cc_sources
#
#
# class SourcesDescriptor:
# def __get__(self, obj: "ResourceModel | None", objtype: Any = None) -> SourceModels:
Expand Down Expand Up @@ -158,9 +154,6 @@ class ResourceModel(TimestampMixin, UserAuditMixin, Base):
# _rmi_sources: Mapped[list[RMIManualSourceModel]] = src_relationship(
# model=RMIManualSourceModel, source="rmi"
# )
# _cc_sources: Mapped[list[CCReservoirsSourceModel]] = src_relationship(
# model=CCReservoirsSourceModel, source="cc"
# )
#
# sources: SourceModels = SourcesDescriptor()

Expand Down
98 changes: 48 additions & 50 deletions deployments/api/src/stitch/api/db/model/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,16 @@
from collections.abc import Mapping, MutableMapping
from typing import Final, Generic, TypeVar, TypedDict, get_args, get_origin
from pydantic import BaseModel
from sqlalchemy import CheckConstraint, inspect
from sqlalchemy import CheckConstraint, DateTime, Float, Integer, String, inspect
from sqlalchemy.orm import Mapped, mapped_column
from .common import Base
from .types import PORTABLE_BIGINT, StitchJson
from .types import PORTABLE_BIGINT, PORTABLE_JSON
from stitch.api.entities import (
CCReservoirsSource,
GemSource,
IdType,
RMIManualSource,
SourceKey,
WMData,
GemData,
RMIManualData,
CCReservoirsData,
WMSource,
)
from stitch.api.sources import OilAndGasFieldSource
from stitch.resources.ogsi import OilAndGasFieldSourceData


def float_constraint(
Expand All @@ -38,6 +32,10 @@ def lon_constraints(colname: str):
return float_constraint(colname, -180, 180)


def year_constraints(colname: str):
return float_constraint(colname, 1800, 2100)


TModelIn = TypeVar("TModelIn", bound=BaseModel)
TModelOut = TypeVar("TModelOut", bound=BaseModel)

Expand All @@ -51,6 +49,41 @@ class SourceBase(Base, Generic[TModelIn, TModelOut]):
PORTABLE_BIGINT, primary_key=True, autoincrement=True
)

# All OilAndGasFieldSourceData columns
name: Mapped[str | None] = mapped_column(String, nullable=True)
country: Mapped[str | None] = mapped_column(String(3), nullable=True)
latitude: Mapped[float | None] = mapped_column(
Float, lat_constraints("latitude"), nullable=True
)
longitude: Mapped[float | None] = mapped_column(
Float, lon_constraints("longitude"), nullable=True
)
last_updated: Mapped[str | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
name_local: Mapped[str | None] = mapped_column(String, nullable=True)
state_province: Mapped[str | None] = mapped_column(String, nullable=True)
region: Mapped[str | None] = mapped_column(String, nullable=True)
basin: Mapped[str | None] = mapped_column(String, nullable=True)
owners: Mapped[list | None] = mapped_column(PORTABLE_JSON, nullable=True)
operators: Mapped[list | None] = mapped_column(PORTABLE_JSON, nullable=True)
location_type: Mapped[str | None] = mapped_column(String, nullable=True)
production_conventionality: Mapped[str | None] = mapped_column(
String, nullable=True
)
primary_hydrocarbon_group: Mapped[str | None] = mapped_column(String, nullable=True)
reservoir_formation: Mapped[str | None] = mapped_column(String, nullable=True)
discovery_year: Mapped[int | None] = mapped_column(
Integer, year_constraints("discovery_year"), nullable=True
)
production_start_year: Mapped[int | None] = mapped_column(
Integer, year_constraints("production_start_year"), nullable=True
)
fid_year: Mapped[int | None] = mapped_column(
Integer, year_constraints("fid_year"), nullable=True
)
field_status: Mapped[str | None] = mapped_column(String, nullable=True)

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
for base in getattr(cls, "__orig_bases__", ()):
Expand All @@ -74,68 +107,33 @@ def from_entity(cls, entity: TModelIn) -> "SourceBase":
return cls(**filtered)


class GemSourceModel(SourceBase[GemData, GemSource]):
class GemSourceModel(SourceBase[OilAndGasFieldSourceData, OilAndGasFieldSource]):
__tablename__ = "gem_sources"

name: Mapped[str]
country: Mapped[str]
lat: Mapped[float] = mapped_column(lat_constraints("lat"))
lon: Mapped[float] = mapped_column(lon_constraints("lon"))


class WMSourceModel(SourceBase[WMData, WMSource]):
class WMSourceModel(SourceBase[OilAndGasFieldSourceData, OilAndGasFieldSource]):
__tablename__ = "wm_sources"

field_name: Mapped[str]
field_country: Mapped[str]
production: Mapped[float]


class RMIManualSourceModel(SourceBase[RMIManualData, RMIManualSource]):
class RMIManualSourceModel(SourceBase[OilAndGasFieldSourceData, OilAndGasFieldSource]):
__tablename__ = "rmi_manual_sources"

name_override: Mapped[str | None]
gwp: Mapped[float | None]
gor: Mapped[float | None | None] = mapped_column(
float_constraint("gor", 0, 1), nullable=True
)
country: Mapped[str | None]
latitude: Mapped[float | None] = mapped_column(
lat_constraints("latitude"), nullable=True
)
longitude: Mapped[float | None] = mapped_column(
lon_constraints("longitude"), nullable=True
)


class CCReservoirsSourceModel(SourceBase[CCReservoirsData, CCReservoirsSource]):
__tablename__ = "cc_reservoirs_sources"

name: Mapped[str]
basin: Mapped[str]
depth: Mapped[float]
geofence: Mapped[list[tuple[float, float]]] = mapped_column(StitchJson())


SourceModel = (
GemSourceModel | WMSourceModel | RMIManualSourceModel | CCReservoirsSourceModel
)
SourceModel = GemSourceModel | WMSourceModel | RMIManualSourceModel
SourceModelCls = type[SourceModel]

SOURCE_TABLES: Final[Mapping[SourceKey, SourceModelCls]] = {
"gem": GemSourceModel,
"wm": WMSourceModel,
"rmi": RMIManualSourceModel,
"cc": CCReservoirsSourceModel,
}


class SourceModelData(TypedDict, total=False):
gem: MutableMapping[IdType, GemSourceModel]
wm: MutableMapping[IdType, WMSourceModel]
cc: MutableMapping[IdType, CCReservoirsSourceModel]
rmi: MutableMapping[IdType, RMIManualSourceModel]


def empty_source_model_data():
return SourceModelData(gem={}, wm={}, cc={}, rmi={})
return SourceModelData(gem={}, wm={}, rmi={})
5 changes: 1 addition & 4 deletions deployments/api/src/stitch/api/db/resource_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)

from .model import (
CCReservoirsSourceModel,
GemSourceModel,
MembershipModel,
RMIManualSourceModel,
Expand Down Expand Up @@ -144,13 +143,11 @@ async def create_source_data(session: AsyncSession, data: CreateSourceData):
gems = tuple(GemSourceModel.from_entity(gem) for gem in data.gem)
wms = tuple(WMSourceModel.from_entity(wm) for wm in data.wm)
rmis = tuple(RMIManualSourceModel.from_entity(rmi) for rmi in data.rmi)
ccs = tuple(CCReservoirsSourceModel.from_entity(cc) for cc in data.cc)

session.add_all(gems + wms + rmis + ccs)
session.add_all(gems + wms + rmis)
await session.flush()
return SourceData(
gem={g.id: g.as_entity() for g in gems},
wm={wm.id: wm.as_entity() for wm in wms},
rmi={rmi.id: rmi.as_entity() for rmi in rmis},
cc={cc.id: cc.as_entity() for cc in ccs},
)
Loading